diff --git a/tests/unit/test_vcr.py b/tests/unit/test_vcr.py index 1029de6..cbdcde4 100644 --- a/tests/unit/test_vcr.py +++ b/tests/unit/test_vcr.py @@ -72,6 +72,32 @@ def test_vcr_before_record_request_params(): assert cassette.filter_request(Request('GET', base_path + 'get', '', {})) is not None +def test_vcr_before_record_response_iterable(): + # Regression test for #191 + + request = Request('GET', '/', '', {}) + response = object() # just can't be None + + # Prevent actually saving the cassette + with mock.patch('vcr.cassette.save_cassette'): + + # Baseline: non-iterable before_record_response should work + mock_filter = mock.Mock() + vcr = VCR(before_record_response=mock_filter) + with vcr.use_cassette('test') as cassette: + assert mock_filter.call_count == 0 + cassette.append(request, response) + assert mock_filter.call_count == 1 + + # Regression test: iterable before_record_response should work too + mock_filter = mock.Mock() + vcr = VCR(before_record_response=(mock_filter,)) + with vcr.use_cassette('test') as cassette: + assert mock_filter.call_count == 0 + cassette.append(request, response) + assert mock_filter.call_count == 1 + + @pytest.fixture def random_fixture(): return 1 diff --git a/vcr/cassette.py b/vcr/cassette.py index ff63492..bc7410f 100644 --- a/vcr/cassette.py +++ b/vcr/cassette.py @@ -164,7 +164,7 @@ class Cassette(object): return CassetteContextDecorator.from_args(cls, **kwargs) def __init__(self, path, serializer=yamlserializer, record_mode='once', - match_on=(uri, method), before_record_request=None, + match_on=(uri, method), before_record_request=None, before_record_response=None, custom_patches=(), inject=False): @@ -210,8 +210,7 @@ class Cassette(object): request = self._before_record_request(request) if not request: return - if self._before_record_response: - response = self._before_record_response(response) + response = self._before_record_response(response) self.data.append((request, response)) self.dirty = True diff --git a/vcr/config.py b/vcr/config.py index 0c45844..a94520f 100644 --- a/vcr/config.py +++ b/vcr/config.py @@ -67,10 +67,11 @@ class VCR(object): try: serializer = self.serializers[serializer_name] except KeyError: - print("Serializer {0} doesn't exist or isn't registered".format( - serializer_name - )) - raise KeyError + raise KeyError( + "Serializer {0} doesn't exist or isn't registered".format( + serializer_name + ) + ) return serializer def _get_matchers(self, matcher_names): @@ -157,12 +158,10 @@ class VCR(object): 'before_record_response', self.before_record_response ) filter_functions = [] - if before_record_response and not isinstance(before_record_response, - collections.Iterable): - before_record_response = (before_record_response,) - for function in before_record_response: - filter_functions.append(function) - + if before_record_response: + if not isinstance(before_record_response, collections.Iterable): + before_record_response = (before_record_response,) + filter_functions.extend(before_record_response) def before_record_response(response): for function in filter_functions: if response is None: @@ -212,20 +211,16 @@ class VCR(object): ) ) - hosts_to_ignore = list(ignore_hosts) + hosts_to_ignore = set(ignore_hosts) if ignore_localhost: - hosts_to_ignore.extend(('localhost', '0.0.0.0', '127.0.0.1')) - + hosts_to_ignore.update(('localhost', '0.0.0.0', '127.0.0.1')) if hosts_to_ignore: - hosts_to_ignore = set(hosts_to_ignore) filter_functions.append(self._build_ignore_hosts(hosts_to_ignore)) if before_record_request: if not isinstance(before_record_request, collections.Iterable): before_record_request = (before_record_request,) - for function in before_record_request: - filter_functions.append(function) - + filter_functions.extend(before_record_request) def before_record_request(request): request = copy.copy(request) for function in filter_functions: @@ -233,7 +228,6 @@ class VCR(object): break request = function(request) return request - return before_record_request @staticmethod diff --git a/vcr/errors.py b/vcr/errors.py index 1b40f91..bdc9701 100644 --- a/vcr/errors.py +++ b/vcr/errors.py @@ -3,8 +3,5 @@ class CannotOverwriteExistingCassetteException(Exception): class UnhandledHTTPRequestError(KeyError): - ''' - Raised when a cassette does not c - ontain the request we want - ''' + """Raised when a cassette does not contain the request we want.""" pass