From 0871c3b87c82a3a4167b68fac8fcf6fde82e8409 Mon Sep 17 00:00:00 2001 From: Ivan Malison Date: Mon, 22 Sep 2014 17:57:22 -0700 Subject: [PATCH] Remove instance variables for filter_headers, filter_query_params, ignore_localhost and ignore_hosts. These still exist on the VCR object, but they are automatically translated into a filter function when passed to the cassette. --- tests/integration/test_filter.py | 11 ++- tests/unit/test_filters.py | 12 +-- tests/unit/test_vcr.py | 45 ++++++++-- tox.ini | 1 - vcr/cassette.py | 42 +++------- vcr/config.py | 138 +++++++++++++++++++++---------- vcr/filters.py | 23 +----- vcr/stubs/__init__.py | 2 +- 8 files changed, 161 insertions(+), 113 deletions(-) diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py index 11dea94..c3e6582 100644 --- a/tests/integration/test_filter.py +++ b/tests/integration/test_filter.py @@ -54,15 +54,20 @@ def test_filter_querystring(tmpdir): urlopen(url) assert 'foo' not in cass.requests[0].url + def test_filter_callback(tmpdir): url = 'http://httpbin.org/get' cass_file = str(tmpdir.join('basic_auth_filter.yaml')) def before_record_cb(request): if request.path != '/get': return request - my_vcr = vcr.VCR( - before_record = before_record_cb, - ) + # Test the legacy keyword. 
+ my_vcr = vcr.VCR(before_record=before_record_cb) + with my_vcr.use_cassette(cass_file, filter_headers=['authorization']) as cass: + urlopen(url) + assert len(cass) == 0 + + my_vcr = vcr.VCR(before_record_request=before_record_cb) with my_vcr.use_cassette(cass_file, filter_headers=['authorization']) as cass: urlopen(url) assert len(cass) == 0 diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index ae1f7a4..2629c5e 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -1,37 +1,37 @@ -from vcr.filters import _remove_headers, _remove_query_parameters +from vcr.filters import remove_headers, remove_query_parameters from vcr.request import Request def test_remove_headers(): headers = {'hello': ['goodbye'], 'secret': ['header']} request = Request('GET', 'http://google.com', '', headers) - _remove_headers(request, ['secret']) + remove_headers(request, ['secret']) assert request.headers == {'hello': 'goodbye'} def test_remove_headers_empty(): headers = {'hello': 'goodbye', 'secret': 'header'} request = Request('GET', 'http://google.com', '', headers) - _remove_headers(request, []) + remove_headers(request, []) assert request.headers == headers def test_remove_query_parameters(): uri = 'http://g.com/?q=cowboys&w=1' request = Request('GET', uri, '', {}) - _remove_query_parameters(request, ['w']) + remove_query_parameters(request, ['w']) assert request.uri == 'http://g.com/?q=cowboys' def test_remove_all_query_parameters(): uri = 'http://g.com/?q=cowboys&w=1' request = Request('GET', uri, '', {}) - _remove_query_parameters(request, ['w', 'q']) + remove_query_parameters(request, ['w', 'q']) assert request.uri == 'http://g.com/' def test_remove_nonexistent_query_parameters(): uri = 'http://g.com/' request = Request('GET', uri, '', {}) - _remove_query_parameters(request, ['w', 'q']) + remove_query_parameters(request, ['w', 'q']) assert request.uri == 'http://g.com/' diff --git a/tests/unit/test_vcr.py b/tests/unit/test_vcr.py index 
a2143aa..33e17a0 100644 --- a/tests/unit/test_vcr.py +++ b/tests/unit/test_vcr.py @@ -2,31 +2,60 @@ import mock import pytest from vcr import VCR, use_cassette +from vcr.request import Request def test_vcr_use_cassette(): - filter_headers = mock.Mock() - test_vcr = VCR(filter_headers=filter_headers) + record_mode = mock.Mock() + test_vcr = VCR(record_mode=record_mode) with mock.patch('vcr.cassette.Cassette.load') as mock_cassette_load: @test_vcr.use_cassette('test') def function(): pass assert mock_cassette_load.call_count == 0 function() - assert mock_cassette_load.call_args[1]['filter_headers'] is filter_headers + assert mock_cassette_load.call_args[1]['record_mode'] is record_mode # Make sure that calls to function now use cassettes with the # new filter_header_settings - test_vcr.filter_headers = ('a',) + test_vcr.record_mode = mock.Mock() function() - assert mock_cassette_load.call_args[1]['filter_headers'] == test_vcr.filter_headers + assert mock_cassette_load.call_args[1]['record_mode'] == test_vcr.record_mode # Ensure that explicitly provided arguments still supercede # those on the vcr. 
- new_filter_headers = mock.Mock() + new_record_mode = mock.Mock() - with test_vcr.use_cassette('test', filter_headers=new_filter_headers) as cassette: - assert cassette._filter_headers == new_filter_headers + with test_vcr.use_cassette('test', record_mode=new_record_mode) as cassette: + assert cassette.record_mode == new_record_mode + + +def test_vcr_before_record_request_params(): + base_path = 'http://httpbin.org/' + def before_record_cb(request): + if request.path != '/get': + return request + test_vcr = VCR(filter_headers=('cookie',), before_record_request=before_record_cb, + ignore_hosts=('www.test.com',), ignore_localhost=True, + filter_query_parameters=('foo',)) + + with test_vcr.use_cassette('test') as cassette: + assert cassette.filter_request(Request('GET', base_path + 'get', '', {})) is None + assert cassette.filter_request(Request('GET', base_path + 'get2', '', {})) is not None + + assert cassette.filter_request(Request('GET', base_path + '?foo=bar', '', {})).query == [] + assert cassette.filter_request( + Request('GET', base_path + '?foo=bar', '', + {'cookie': 'test', 'other': 'fun'})).headers == {'other': 'fun'} + assert cassette.filter_request(Request('GET', base_path + '?foo=bar', '', + {'cookie': 'test', 'other': 'fun'})).headers == {'other': 'fun'} + + assert cassette.filter_request(Request('GET', 'http://www.test.com' + '?foo=bar', '', + {'cookie': 'test', 'other': 'fun'})) is None + + with test_vcr.use_cassette('test', before_record_request=None) as cassette: + # Test that before_record can be overwritten with + assert cassette.filter_request(Request('GET', base_path + 'get', '', {})) is not None @pytest.fixture diff --git a/tox.ini b/tox.ini index 5797c48..bf95a3c 100644 --- a/tox.ini +++ b/tox.ini @@ -96,7 +96,6 @@ deps = {[testenv]deps} requests==2.4.0 - [testenv:py26requests23] basepython = python2.6 deps = diff --git a/vcr/cassette.py b/vcr/cassette.py index bc52340..e2b1d91 100644 --- a/vcr/cassette.py +++ b/vcr/cassette.py @@ -11,7 +11,6 
@@ except ImportError: # Internal imports from .patch import CassettePatcherBuilder from .persist import load_cassette, save_cassette -from .filters import filter_request from .serializers import yamlserializer from .matchers import requests_match, uri, method from .errors import UnhandledHTTPRequestError @@ -25,7 +24,7 @@ class CassetteContextDecorator(object): removing cassettes. This class defers the creation of a new cassette instance until the point at - which it is installed by context manager or decorator. The fact that a new + which it is installed by context manager or decorator. The fact that a new cassette is used with each application prevents the state of any cassette from interfering with another. """ @@ -50,7 +49,7 @@ class CassetteContextDecorator(object): cassette._save() def __enter__(self): - assert self.__finish is None + assert self.__finish is None, "Cassette already open." path, kwargs = self._args_getter() self.__finish = self._patch_generator(self.cls.load(path, **kwargs)) return next(self.__finish) @@ -70,7 +69,7 @@ class Cassette(object): @classmethod def load(cls, path, **kwargs): - '''Load in the cassette stored at the provided path''' + '''Instantiate and load the cassette stored at the specified path.''' new_cassette = cls(path, **kwargs) new_cassette._load() return new_cassette @@ -85,20 +84,13 @@ class Cassette(object): def __init__(self, path, serializer=yamlserializer, record_mode='once', match_on=(uri, method), filter_headers=(), - filter_query_parameters=(), before_record=None, before_record_response=None, - ignore_hosts=(), ignore_localhost=()): + filter_query_parameters=(), before_record_request=None, + before_record_response=None, ignore_hosts=(), ignore_localhost=()): self._path = path self._serializer = serializer self._match_on = match_on - self._filter_headers = filter_headers - self._filter_query_parameters = filter_query_parameters - self._before_record = before_record - self._before_record_response = 
before_record_response - self._ignore_hosts = ignore_hosts - if ignore_localhost: - self._ignore_hosts = list(set( - list(self._ignore_hosts) + ['localhost', '0.0.0.0', '127.0.0.1'] - )) + self._before_record_request = before_record_request or (lambda x: x) + self._before_record_response = before_record_response or (lambda x: x) # self.data is the list of (req, resp) tuples self.data = [] @@ -131,18 +123,9 @@ class Cassette(object): return self.rewound and self.record_mode == 'once' or \ self.record_mode == 'none' - def _filter_request(self, request): - return filter_request( - request=request, - filter_headers=self._filter_headers, - filter_query_parameters=self._filter_query_parameters, - before_record=self._before_record, - ignore_hosts=self._ignore_hosts - ) - def append(self, request, response): '''Add a request, response pair to this cassette''' - request = self._filter_request(request) + request = self._before_record_request(request) if not request: return if self._before_record_response: @@ -150,20 +133,21 @@ class Cassette(object): self.data.append((request, response)) self.dirty = True + def filter_request(self, request): + return self._before_record_request(request) + def _responses(self, request): """ internal API, returns an iterator with all responses matching the request. 
""" - request = self._filter_request(request) - if not request: - return + request = self._before_record_request(request) for index, (stored_request, response) in enumerate(self.data): if requests_match(request, stored_request, self._match_on): yield index, response def can_play_response_for(self, request): - request = self._filter_request(request) + request = self._before_record_request(request) return request and request in self and \ self.record_mode != 'all' and \ self.rewound diff --git a/vcr/config.py b/vcr/config.py index 8308b48..e8ff7a8 100644 --- a/vcr/config.py +++ b/vcr/config.py @@ -1,30 +1,22 @@ +import collections +import copy import functools import os + from .cassette import Cassette from .serializers import yamlserializer, jsonserializer from . import matchers +from . import filters class VCR(object): - def __init__(self, - serializer='yaml', - cassette_library_dir=None, - record_mode="once", - filter_headers=(), - filter_query_parameters=(), - before_record=None, - before_record_response=None, - match_on=( - 'method', - 'scheme', - 'host', - 'port', - 'path', - 'query', - ), - ignore_hosts=(), - ignore_localhost=False, - ): + + def __init__(self, serializer='yaml', cassette_library_dir=None, + record_mode="once", filter_headers=(), + filter_query_parameters=(), before_record_request=None, + before_record_response=None, ignore_hosts=(), + match_on=('method', 'scheme', 'host', 'port', 'path', 'query',), + ignore_localhost=False, before_record=None): self.serializer = serializer self.match_on = match_on self.cassette_library_dir = cassette_library_dir @@ -47,7 +39,7 @@ class VCR(object): self.record_mode = record_mode self.filter_headers = filter_headers self.filter_query_parameters = filter_query_parameters - self.before_record = before_record + self.before_record_request = before_record_request or before_record self.before_record_response = before_record_response self.ignore_hosts = ignore_hosts self.ignore_localhost = ignore_localhost @@ -69,12 
+61,13 @@ class VCR(object): matchers.append(self.matchers[m]) except KeyError: raise KeyError( - "Matcher {0} doesn't exist or isn't registered".format( - m) + "Matcher {0} doesn't exist or isn't registered".format(m) ) return matchers - def use_cassette(self, path, **kwargs): + def use_cassette(self, path, with_current_defaults=False, **kwargs): + if with_current_defaults: + return Cassette.use(path, self.get_path_and_merged_config(path, **kwargs)) args_getter = functools.partial(self.get_path_and_merged_config, path, **kwargs) return Cassette.use_arg_getter(args_getter) @@ -89,30 +82,87 @@ class VCR(object): path = os.path.join(cassette_library_dir, path) merged_config = { - "serializer": self._get_serializer(serializer_name), - "match_on": self._get_matchers(matcher_names), - "record_mode": kwargs.get('record_mode', self.record_mode), - "filter_headers": kwargs.get( - 'filter_headers', self.filter_headers - ), - "filter_query_parameters": kwargs.get( - 'filter_query_parameters', self.filter_query_parameters - ), - "before_record": kwargs.get( - "before_record", self.before_record - ), - "before_record_response": kwargs.get( - "before_record_response", self.before_record_response - ), - "ignore_hosts": kwargs.get( - 'ignore_hosts', self.ignore_hosts - ), - "ignore_localhost": kwargs.get( - 'ignore_localhost', self.ignore_localhost - ), + 'serializer': self._get_serializer(serializer_name), + 'match_on': self._get_matchers(matcher_names), + 'record_mode': kwargs.get('record_mode', self.record_mode), + 'before_record_request': self._build_before_record_request(kwargs), + 'before_record_response': self._build_before_record_response(kwargs) } return path, merged_config + def _build_before_record_response(self, options): + before_record_response = options.get( + 'before_record_response', self.before_record_response + ) + filter_functions = [] + if before_record_response and not isinstance(before_record_response, + collections.Iterable): + before_record_response = 
(before_record_response,) + for function in before_record_response: + filter_functions.append(function) + def before_record_response(response): + for function in filter_functions: + if response is None: + break + response = function(response) + return response + return before_record_response + + def _build_before_record_request(self, options): + filter_functions = [] + filter_headers = options.get( + 'filter_headers', self.filter_headers + ) + filter_query_parameters = options.get( + 'filter_query_parameters', self.filter_query_parameters + ) + before_record_request = options.get( + "before_record_request", options.get("before_record", self.before_record_request) + ) + ignore_hosts = options.get( + 'ignore_hosts', self.ignore_hosts + ) + ignore_localhost = options.get( + 'ignore_localhost', self.ignore_localhost + ) + if filter_headers: + filter_functions.append(functools.partial(filters.remove_headers, + headers_to_remove=filter_headers)) + if filter_query_parameters: + filter_functions.append(functools.partial(filters.remove_query_parameters, + query_parameters_to_remove=filter_query_parameters)) + + hosts_to_ignore = list(ignore_hosts) + if ignore_localhost: + hosts_to_ignore.extend(('localhost', '0.0.0.0', '127.0.0.1')) + + if hosts_to_ignore: + hosts_to_ignore = set(hosts_to_ignore) + filter_functions.append(self._build_ignore_hosts(hosts_to_ignore)) + + if before_record_request: + if not isinstance(before_record_request, collections.Iterable): + before_record_request = (before_record_request,) + for function in before_record_request: + filter_functions.append(function) + def before_record_request(request): + request = copy.copy(request) + for function in filter_functions: + if request is None: + break + request = function(request) + return request + + return before_record_request + + @staticmethod + def _build_ignore_hosts(hosts_to_ignore): + def filter_ignored_hosts(request): + if hasattr(request, 'host') and request.host in hosts_to_ignore: + return + 
return request + return filter_ignored_hosts + def register_serializer(self, name, serializer): self.serializers[name] = serializer diff --git a/vcr/filters.py b/vcr/filters.py index 562f840..7e79451 100644 --- a/vcr/filters.py +++ b/vcr/filters.py @@ -2,7 +2,7 @@ from six.moves.urllib.parse import urlparse, urlencode, urlunparse import copy -def _remove_headers(request, headers_to_remove): +def remove_headers(request, headers_to_remove): headers = copy.copy(request.headers) headers_to_remove = [h.lower() for h in headers_to_remove] keys = [k for k in headers if k.lower() in headers_to_remove] @@ -13,7 +13,7 @@ def _remove_headers(request, headers_to_remove): return request -def _remove_query_parameters(request, query_parameters_to_remove): +def remove_query_parameters(request, query_parameters_to_remove): query = request.query new_query = [(k, v) for (k, v) in query if k not in query_parameters_to_remove] @@ -22,22 +22,3 @@ def _remove_query_parameters(request, query_parameters_to_remove): uri_parts[4] = urlencode(new_query) request.uri = urlunparse(uri_parts) return request - - -def filter_request( - request, - filter_headers, - filter_query_parameters, - before_record, - ignore_hosts - ): - request = copy.copy(request) # don't mutate request object - if hasattr(request, 'headers') and filter_headers: - request = _remove_headers(request, filter_headers) - if hasattr(request, 'host') and request.host in ignore_hosts: - return None - if filter_query_parameters: - request = _remove_query_parameters(request, filter_query_parameters) - if before_record: - request = before_record(request) - return request diff --git a/vcr/stubs/__init__.py b/vcr/stubs/__init__.py index df67b1d..6ffe4b4 100644 --- a/vcr/stubs/__init__.py +++ b/vcr/stubs/__init__.py @@ -217,7 +217,7 @@ class VCRConnection(object): response = self.cassette.play_response(self._vcr_request) return VCRHTTPResponse(response) else: - if self.cassette.write_protected and 
self.cassette._filter_request(self._vcr_request): + if self.cassette.write_protected and self.cassette.filter_request(self._vcr_request): raise CannotOverwriteExistingCassetteException( "Can't overwrite existing cassette (%r) in " "your current record mode (%r)."