1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-09 01:03:24 +00:00

Don't drop passed-in before_record_response

This is an actual bugfix. If before_record_response is passed into VCR as
an iterable then it won't be included in filter_functions. This commit
repairs the logic to separate the tests (just as it's already done for
before_record_request).

Also use .extend() rather than looping on .append()
This commit is contained in:
Aron Griffis
2015-08-23 10:25:23 -04:00
parent a569dd4dc8
commit 3e5553c56a
2 changed files with 31 additions and 10 deletions

View File

@@ -72,6 +72,32 @@ def test_vcr_before_record_request_params():
assert cassette.filter_request(Request('GET', base_path + 'get', '', {})) is not None assert cassette.filter_request(Request('GET', base_path + 'get', '', {})) is not None
def test_vcr_before_record_response_iterable():
# Regression test for #191
request = Request('GET', '/', '', {})
response = object() # just can't be None
# Prevent actually saving the cassette
with mock.patch('vcr.cassette.save_cassette'):
# Baseline: non-iterable before_record_response should work
mock_filter = mock.Mock()
vcr = VCR(before_record_response=mock_filter)
with vcr.use_cassette('test') as cassette:
assert mock_filter.call_count == 0
cassette.append(request, response)
assert mock_filter.call_count == 1
# Regression test: iterable before_record_response should work too
mock_filter = mock.Mock()
vcr = VCR(before_record_response=(mock_filter,))
with vcr.use_cassette('test') as cassette:
assert mock_filter.call_count == 0
cassette.append(request, response)
assert mock_filter.call_count == 1
@pytest.fixture @pytest.fixture
def random_fixture(): def random_fixture():
return 1 return 1

View File

@@ -158,12 +158,10 @@ class VCR(object):
'before_record_response', self.before_record_response 'before_record_response', self.before_record_response
) )
filter_functions = [] filter_functions = []
if before_record_response and not isinstance(before_record_response, if before_record_response:
collections.Iterable): if not isinstance(before_record_response, collections.Iterable):
before_record_response = (before_record_response,) before_record_response = (before_record_response,)
for function in before_record_response: filter_functions.extend(before_record_response)
filter_functions.append(function)
def before_record_response(response): def before_record_response(response):
for function in filter_functions: for function in filter_functions:
if response is None: if response is None:
@@ -224,9 +222,7 @@ class VCR(object):
if before_record_request: if before_record_request:
if not isinstance(before_record_request, collections.Iterable): if not isinstance(before_record_request, collections.Iterable):
before_record_request = (before_record_request,) before_record_request = (before_record_request,)
for function in before_record_request: filter_functions.extend(before_record_request)
filter_functions.append(function)
def before_record_request(request): def before_record_request(request):
request = copy.copy(request) request = copy.copy(request)
for function in filter_functions: for function in filter_functions:
@@ -234,7 +230,6 @@ class VCR(object):
break break
request = function(request) request = function(request)
return request return request
return before_record_request return before_record_request
@staticmethod @staticmethod