1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-10 17:45:35 +00:00

Merge pull request #191 from agriffis/trivial-fixes

Trivial cleanups and one bugfix
This commit is contained in:
Ivan 'Goat' Malison
2015-08-23 13:33:34 -07:00
4 changed files with 41 additions and 25 deletions

View File

@@ -72,6 +72,32 @@ def test_vcr_before_record_request_params():
assert cassette.filter_request(Request('GET', base_path + 'get', '', {})) is not None assert cassette.filter_request(Request('GET', base_path + 'get', '', {})) is not None
def test_vcr_before_record_response_iterable():
# Regression test for #191
request = Request('GET', '/', '', {})
response = object() # just can't be None
# Prevent actually saving the cassette
with mock.patch('vcr.cassette.save_cassette'):
# Baseline: non-iterable before_record_response should work
mock_filter = mock.Mock()
vcr = VCR(before_record_response=mock_filter)
with vcr.use_cassette('test') as cassette:
assert mock_filter.call_count == 0
cassette.append(request, response)
assert mock_filter.call_count == 1
# Regression test: iterable before_record_response should work too
mock_filter = mock.Mock()
vcr = VCR(before_record_response=(mock_filter,))
with vcr.use_cassette('test') as cassette:
assert mock_filter.call_count == 0
cassette.append(request, response)
assert mock_filter.call_count == 1
@pytest.fixture @pytest.fixture
def random_fixture(): def random_fixture():
return 1 return 1

View File

@@ -164,7 +164,7 @@ class Cassette(object):
return CassetteContextDecorator.from_args(cls, **kwargs) return CassetteContextDecorator.from_args(cls, **kwargs)
def __init__(self, path, serializer=yamlserializer, record_mode='once', def __init__(self, path, serializer=yamlserializer, record_mode='once',
match_on=(uri, method), before_record_request=None, match_on=(uri, method), before_record_request=None,
before_record_response=None, custom_patches=(), before_record_response=None, custom_patches=(),
inject=False): inject=False):
@@ -210,8 +210,7 @@ class Cassette(object):
request = self._before_record_request(request) request = self._before_record_request(request)
if not request: if not request:
return return
if self._before_record_response: response = self._before_record_response(response)
response = self._before_record_response(response)
self.data.append((request, response)) self.data.append((request, response))
self.dirty = True self.dirty = True

View File

@@ -67,10 +67,11 @@ class VCR(object):
try: try:
serializer = self.serializers[serializer_name] serializer = self.serializers[serializer_name]
except KeyError: except KeyError:
print("Serializer {0} doesn't exist or isn't registered".format( raise KeyError(
serializer_name "Serializer {0} doesn't exist or isn't registered".format(
)) serializer_name
raise KeyError )
)
return serializer return serializer
def _get_matchers(self, matcher_names): def _get_matchers(self, matcher_names):
@@ -157,12 +158,10 @@ class VCR(object):
'before_record_response', self.before_record_response 'before_record_response', self.before_record_response
) )
filter_functions = [] filter_functions = []
if before_record_response and not isinstance(before_record_response, if before_record_response:
collections.Iterable): if not isinstance(before_record_response, collections.Iterable):
before_record_response = (before_record_response,) before_record_response = (before_record_response,)
for function in before_record_response: filter_functions.extend(before_record_response)
filter_functions.append(function)
def before_record_response(response): def before_record_response(response):
for function in filter_functions: for function in filter_functions:
if response is None: if response is None:
@@ -212,20 +211,16 @@ class VCR(object):
) )
) )
hosts_to_ignore = list(ignore_hosts) hosts_to_ignore = set(ignore_hosts)
if ignore_localhost: if ignore_localhost:
hosts_to_ignore.extend(('localhost', '0.0.0.0', '127.0.0.1')) hosts_to_ignore.update(('localhost', '0.0.0.0', '127.0.0.1'))
if hosts_to_ignore: if hosts_to_ignore:
hosts_to_ignore = set(hosts_to_ignore)
filter_functions.append(self._build_ignore_hosts(hosts_to_ignore)) filter_functions.append(self._build_ignore_hosts(hosts_to_ignore))
if before_record_request: if before_record_request:
if not isinstance(before_record_request, collections.Iterable): if not isinstance(before_record_request, collections.Iterable):
before_record_request = (before_record_request,) before_record_request = (before_record_request,)
for function in before_record_request: filter_functions.extend(before_record_request)
filter_functions.append(function)
def before_record_request(request): def before_record_request(request):
request = copy.copy(request) request = copy.copy(request)
for function in filter_functions: for function in filter_functions:
@@ -233,7 +228,6 @@ class VCR(object):
break break
request = function(request) request = function(request)
return request return request
return before_record_request return before_record_request
@staticmethod @staticmethod

View File

@@ -3,8 +3,5 @@ class CannotOverwriteExistingCassetteException(Exception):
class UnhandledHTTPRequestError(KeyError): class UnhandledHTTPRequestError(KeyError):
''' """Raised when a cassette does not contain the request we want."""
Raised when a cassette does not c
ontain the request we want
'''
pass pass