diff --git a/tests/unit/test_vcr.py b/tests/unit/test_vcr.py index 1029de6..cbdcde4 100644 --- a/tests/unit/test_vcr.py +++ b/tests/unit/test_vcr.py @@ -72,6 +72,32 @@ def test_vcr_before_record_request_params(): assert cassette.filter_request(Request('GET', base_path + 'get', '', {})) is not None +def test_vcr_before_record_response_iterable(): + # Regression test for #191 + + request = Request('GET', '/', '', {}) + response = object() # just can't be None + + # Prevent actually saving the cassette + with mock.patch('vcr.cassette.save_cassette'): + + # Baseline: non-iterable before_record_response should work + mock_filter = mock.Mock() + vcr = VCR(before_record_response=mock_filter) + with vcr.use_cassette('test') as cassette: + assert mock_filter.call_count == 0 + cassette.append(request, response) + assert mock_filter.call_count == 1 + + # Regression test: iterable before_record_response should work too + mock_filter = mock.Mock() + vcr = VCR(before_record_response=(mock_filter,)) + with vcr.use_cassette('test') as cassette: + assert mock_filter.call_count == 0 + cassette.append(request, response) + assert mock_filter.call_count == 1 + + @pytest.fixture def random_fixture(): return 1 diff --git a/vcr/config.py b/vcr/config.py index b5319c2..030462a 100644 --- a/vcr/config.py +++ b/vcr/config.py @@ -158,12 +158,10 @@ class VCR(object): 'before_record_response', self.before_record_response ) filter_functions = [] - if before_record_response and not isinstance(before_record_response, - collections.Iterable): - before_record_response = (before_record_response,) - for function in before_record_response: - filter_functions.append(function) - + if before_record_response: + if not isinstance(before_record_response, collections.Iterable): + before_record_response = (before_record_response,) + filter_functions.extend(before_record_response) def before_record_response(response): for function in filter_functions: if response is None: @@ -224,9 +222,7 @@ class VCR(object): if before_record_request: if not isinstance(before_record_request, collections.Iterable): before_record_request = (before_record_request,) - for function in before_record_request: - filter_functions.append(function) - + filter_functions.extend(before_record_request) def before_record_request(request): request = copy.copy(request) for function in filter_functions: @@ -234,7 +230,6 @@ class VCR(object): break request = function(request) return request - return before_record_request @staticmethod