1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-08 16:53:23 +00:00

Remove instance variables for filter_headers, filter_query_params, ignore_localhost and ignore_hosts. These still exist on the VCR object, but they are automatically translated into a filter function when passed to the cassette.

This commit is contained in:
Ivan Malison
2014-09-22 17:57:22 -07:00
parent d484dee50f
commit 0871c3b87c
8 changed files with 161 additions and 113 deletions

View File

@@ -54,15 +54,20 @@ def test_filter_querystring(tmpdir):
urlopen(url) urlopen(url)
assert 'foo' not in cass.requests[0].url assert 'foo' not in cass.requests[0].url
def test_filter_callback(tmpdir): def test_filter_callback(tmpdir):
url = 'http://httpbin.org/get' url = 'http://httpbin.org/get'
cass_file = str(tmpdir.join('basic_auth_filter.yaml')) cass_file = str(tmpdir.join('basic_auth_filter.yaml'))
def before_record_cb(request): def before_record_cb(request):
if request.path != '/get': if request.path != '/get':
return request return request
my_vcr = vcr.VCR( # Test the legacy keyword.
before_record = before_record_cb, my_vcr = vcr.VCR(before_record=before_record_cb)
) with my_vcr.use_cassette(cass_file, filter_headers=['authorization']) as cass:
urlopen(url)
assert len(cass) == 0
my_vcr = vcr.VCR(before_record_request=before_record_cb)
with my_vcr.use_cassette(cass_file, filter_headers=['authorization']) as cass: with my_vcr.use_cassette(cass_file, filter_headers=['authorization']) as cass:
urlopen(url) urlopen(url)
assert len(cass) == 0 assert len(cass) == 0

View File

@@ -1,37 +1,37 @@
from vcr.filters import _remove_headers, _remove_query_parameters from vcr.filters import remove_headers, remove_query_parameters
from vcr.request import Request from vcr.request import Request
def test_remove_headers(): def test_remove_headers():
headers = {'hello': ['goodbye'], 'secret': ['header']} headers = {'hello': ['goodbye'], 'secret': ['header']}
request = Request('GET', 'http://google.com', '', headers) request = Request('GET', 'http://google.com', '', headers)
_remove_headers(request, ['secret']) remove_headers(request, ['secret'])
assert request.headers == {'hello': 'goodbye'} assert request.headers == {'hello': 'goodbye'}
def test_remove_headers_empty(): def test_remove_headers_empty():
headers = {'hello': 'goodbye', 'secret': 'header'} headers = {'hello': 'goodbye', 'secret': 'header'}
request = Request('GET', 'http://google.com', '', headers) request = Request('GET', 'http://google.com', '', headers)
_remove_headers(request, []) remove_headers(request, [])
assert request.headers == headers assert request.headers == headers
def test_remove_query_parameters(): def test_remove_query_parameters():
uri = 'http://g.com/?q=cowboys&w=1' uri = 'http://g.com/?q=cowboys&w=1'
request = Request('GET', uri, '', {}) request = Request('GET', uri, '', {})
_remove_query_parameters(request, ['w']) remove_query_parameters(request, ['w'])
assert request.uri == 'http://g.com/?q=cowboys' assert request.uri == 'http://g.com/?q=cowboys'
def test_remove_all_query_parameters(): def test_remove_all_query_parameters():
uri = 'http://g.com/?q=cowboys&w=1' uri = 'http://g.com/?q=cowboys&w=1'
request = Request('GET', uri, '', {}) request = Request('GET', uri, '', {})
_remove_query_parameters(request, ['w', 'q']) remove_query_parameters(request, ['w', 'q'])
assert request.uri == 'http://g.com/' assert request.uri == 'http://g.com/'
def test_remove_nonexistent_query_parameters(): def test_remove_nonexistent_query_parameters():
uri = 'http://g.com/' uri = 'http://g.com/'
request = Request('GET', uri, '', {}) request = Request('GET', uri, '', {})
_remove_query_parameters(request, ['w', 'q']) remove_query_parameters(request, ['w', 'q'])
assert request.uri == 'http://g.com/' assert request.uri == 'http://g.com/'

View File

@@ -2,31 +2,60 @@ import mock
import pytest import pytest
from vcr import VCR, use_cassette from vcr import VCR, use_cassette
from vcr.request import Request
def test_vcr_use_cassette(): def test_vcr_use_cassette():
filter_headers = mock.Mock() record_mode = mock.Mock()
test_vcr = VCR(filter_headers=filter_headers) test_vcr = VCR(record_mode=record_mode)
with mock.patch('vcr.cassette.Cassette.load') as mock_cassette_load: with mock.patch('vcr.cassette.Cassette.load') as mock_cassette_load:
@test_vcr.use_cassette('test') @test_vcr.use_cassette('test')
def function(): def function():
pass pass
assert mock_cassette_load.call_count == 0 assert mock_cassette_load.call_count == 0
function() function()
assert mock_cassette_load.call_args[1]['filter_headers'] is filter_headers assert mock_cassette_load.call_args[1]['record_mode'] is record_mode
# Make sure that calls to function now use cassettes with the # Make sure that calls to function now use cassettes with the
# new filter_header_settings # new filter_header_settings
test_vcr.filter_headers = ('a',) test_vcr.record_mode = mock.Mock()
function() function()
assert mock_cassette_load.call_args[1]['filter_headers'] == test_vcr.filter_headers assert mock_cassette_load.call_args[1]['record_mode'] == test_vcr.record_mode
# Ensure that explicitly provided arguments still supercede # Ensure that explicitly provided arguments still supercede
# those on the vcr. # those on the vcr.
new_filter_headers = mock.Mock() new_record_mode = mock.Mock()
with test_vcr.use_cassette('test', filter_headers=new_filter_headers) as cassette: with test_vcr.use_cassette('test', record_mode=new_record_mode) as cassette:
assert cassette._filter_headers == new_filter_headers assert cassette.record_mode == new_record_mode
def test_vcr_before_record_request_params():
base_path = 'http://httpbin.org/'
def before_record_cb(request):
if request.path != '/get':
return request
test_vcr = VCR(filter_headers=('cookie',), before_record_request=before_record_cb,
ignore_hosts=('www.test.com',), ignore_localhost=True,
filter_query_parameters=('foo',))
with test_vcr.use_cassette('test') as cassette:
assert cassette.filter_request(Request('GET', base_path + 'get', '', {})) is None
assert cassette.filter_request(Request('GET', base_path + 'get2', '', {})) is not None
assert cassette.filter_request(Request('GET', base_path + '?foo=bar', '', {})).query == []
assert cassette.filter_request(
Request('GET', base_path + '?foo=bar', '',
{'cookie': 'test', 'other': 'fun'})).headers == {'other': 'fun'}
assert cassette.filter_request(Request('GET', base_path + '?foo=bar', '',
{'cookie': 'test', 'other': 'fun'})).headers == {'other': 'fun'}
assert cassette.filter_request(Request('GET', 'http://www.test.com' + '?foo=bar', '',
{'cookie': 'test', 'other': 'fun'})) is None
with test_vcr.use_cassette('test', before_record_request=None) as cassette:
# Test that before_record can be overwritten with
assert cassette.filter_request(Request('GET', base_path + 'get', '', {})) is not None
@pytest.fixture @pytest.fixture

View File

@@ -96,7 +96,6 @@ deps =
{[testenv]deps} {[testenv]deps}
requests==2.4.0 requests==2.4.0
[testenv:py26requests23] [testenv:py26requests23]
basepython = python2.6 basepython = python2.6
deps = deps =

View File

@@ -11,7 +11,6 @@ except ImportError:
# Internal imports # Internal imports
from .patch import CassettePatcherBuilder from .patch import CassettePatcherBuilder
from .persist import load_cassette, save_cassette from .persist import load_cassette, save_cassette
from .filters import filter_request
from .serializers import yamlserializer from .serializers import yamlserializer
from .matchers import requests_match, uri, method from .matchers import requests_match, uri, method
from .errors import UnhandledHTTPRequestError from .errors import UnhandledHTTPRequestError
@@ -25,7 +24,7 @@ class CassetteContextDecorator(object):
removing cassettes. removing cassettes.
This class defers the creation of a new cassette instance until the point at This class defers the creation of a new cassette instance until the point at
which it is installed by context manager or decorator. The fact that a new which it is installned by context manager or decorator. The fact that a new
cassette is used with each application prevents the state of any cassette cassette is used with each application prevents the state of any cassette
from interfering with another. from interfering with another.
""" """
@@ -50,7 +49,7 @@ class CassetteContextDecorator(object):
cassette._save() cassette._save()
def __enter__(self): def __enter__(self):
assert self.__finish is None assert self.__finish is None, "Cassette already open."
path, kwargs = self._args_getter() path, kwargs = self._args_getter()
self.__finish = self._patch_generator(self.cls.load(path, **kwargs)) self.__finish = self._patch_generator(self.cls.load(path, **kwargs))
return next(self.__finish) return next(self.__finish)
@@ -70,7 +69,7 @@ class Cassette(object):
@classmethod @classmethod
def load(cls, path, **kwargs): def load(cls, path, **kwargs):
'''Load in the cassette stored at the provided path''' '''Instantiate and load the cassette stored at the specified path.'''
new_cassette = cls(path, **kwargs) new_cassette = cls(path, **kwargs)
new_cassette._load() new_cassette._load()
return new_cassette return new_cassette
@@ -85,20 +84,13 @@ class Cassette(object):
def __init__(self, path, serializer=yamlserializer, record_mode='once', def __init__(self, path, serializer=yamlserializer, record_mode='once',
match_on=(uri, method), filter_headers=(), match_on=(uri, method), filter_headers=(),
filter_query_parameters=(), before_record=None, before_record_response=None, filter_query_parameters=(), before_record_request=None,
ignore_hosts=(), ignore_localhost=()): before_record_response=None, ignore_hosts=(), ignore_localhost=()):
self._path = path self._path = path
self._serializer = serializer self._serializer = serializer
self._match_on = match_on self._match_on = match_on
self._filter_headers = filter_headers self._before_record_request = before_record_request or (lambda x: x)
self._filter_query_parameters = filter_query_parameters self._before_record_response = before_record_response or (lambda x: x)
self._before_record = before_record
self._before_record_response = before_record_response
self._ignore_hosts = ignore_hosts
if ignore_localhost:
self._ignore_hosts = list(set(
list(self._ignore_hosts) + ['localhost', '0.0.0.0', '127.0.0.1']
))
# self.data is the list of (req, resp) tuples # self.data is the list of (req, resp) tuples
self.data = [] self.data = []
@@ -131,18 +123,9 @@ class Cassette(object):
return self.rewound and self.record_mode == 'once' or \ return self.rewound and self.record_mode == 'once' or \
self.record_mode == 'none' self.record_mode == 'none'
def _filter_request(self, request):
return filter_request(
request=request,
filter_headers=self._filter_headers,
filter_query_parameters=self._filter_query_parameters,
before_record=self._before_record,
ignore_hosts=self._ignore_hosts
)
def append(self, request, response): def append(self, request, response):
'''Add a request, response pair to this cassette''' '''Add a request, response pair to this cassette'''
request = self._filter_request(request) request = self._before_record_request(request)
if not request: if not request:
return return
if self._before_record_response: if self._before_record_response:
@@ -150,20 +133,21 @@ class Cassette(object):
self.data.append((request, response)) self.data.append((request, response))
self.dirty = True self.dirty = True
def filter_request(self, request):
return self._before_record_request(request)
def _responses(self, request): def _responses(self, request):
""" """
internal API, returns an iterator with all responses matching internal API, returns an iterator with all responses matching
the request. the request.
""" """
request = self._filter_request(request) request = self._before_record_request(request)
if not request:
return
for index, (stored_request, response) in enumerate(self.data): for index, (stored_request, response) in enumerate(self.data):
if requests_match(request, stored_request, self._match_on): if requests_match(request, stored_request, self._match_on):
yield index, response yield index, response
def can_play_response_for(self, request): def can_play_response_for(self, request):
request = self._filter_request(request) request = self._before_record_request(request)
return request and request in self and \ return request and request in self and \
self.record_mode != 'all' and \ self.record_mode != 'all' and \
self.rewound self.rewound

View File

@@ -1,30 +1,22 @@
import collections
import copy
import functools import functools
import os import os
from .cassette import Cassette from .cassette import Cassette
from .serializers import yamlserializer, jsonserializer from .serializers import yamlserializer, jsonserializer
from . import matchers from . import matchers
from . import filters
class VCR(object): class VCR(object):
def __init__(self,
serializer='yaml', def __init__(self, serializer='yaml', cassette_library_dir=None,
cassette_library_dir=None, record_mode="once", filter_headers=(),
record_mode="once", filter_query_parameters=(), before_record_request=None,
filter_headers=(), before_record_response=None, ignore_hosts=(),
filter_query_parameters=(), match_on=('method', 'scheme', 'host', 'port', 'path', 'query',),
before_record=None, ignore_localhost=False, before_record=None):
before_record_response=None,
match_on=(
'method',
'scheme',
'host',
'port',
'path',
'query',
),
ignore_hosts=(),
ignore_localhost=False,
):
self.serializer = serializer self.serializer = serializer
self.match_on = match_on self.match_on = match_on
self.cassette_library_dir = cassette_library_dir self.cassette_library_dir = cassette_library_dir
@@ -47,7 +39,7 @@ class VCR(object):
self.record_mode = record_mode self.record_mode = record_mode
self.filter_headers = filter_headers self.filter_headers = filter_headers
self.filter_query_parameters = filter_query_parameters self.filter_query_parameters = filter_query_parameters
self.before_record = before_record self.before_record_request = before_record_request or before_record
self.before_record_response = before_record_response self.before_record_response = before_record_response
self.ignore_hosts = ignore_hosts self.ignore_hosts = ignore_hosts
self.ignore_localhost = ignore_localhost self.ignore_localhost = ignore_localhost
@@ -69,12 +61,13 @@ class VCR(object):
matchers.append(self.matchers[m]) matchers.append(self.matchers[m])
except KeyError: except KeyError:
raise KeyError( raise KeyError(
"Matcher {0} doesn't exist or isn't registered".format( "Matcher {0} doesn't exist or isn't registered".format(m)
m)
) )
return matchers return matchers
def use_cassette(self, path, **kwargs): def use_cassette(self, path, with_current_defaults=False, **kwargs):
if with_current_defaults:
return Cassette.use(path, self.get_path_and_merged_config(path, **kwargs))
args_getter = functools.partial(self.get_path_and_merged_config, path, **kwargs) args_getter = functools.partial(self.get_path_and_merged_config, path, **kwargs)
return Cassette.use_arg_getter(args_getter) return Cassette.use_arg_getter(args_getter)
@@ -89,30 +82,87 @@ class VCR(object):
path = os.path.join(cassette_library_dir, path) path = os.path.join(cassette_library_dir, path)
merged_config = { merged_config = {
"serializer": self._get_serializer(serializer_name), 'serializer': self._get_serializer(serializer_name),
"match_on": self._get_matchers(matcher_names), 'match_on': self._get_matchers(matcher_names),
"record_mode": kwargs.get('record_mode', self.record_mode), 'record_mode': kwargs.get('record_mode', self.record_mode),
"filter_headers": kwargs.get( 'before_record_request': self._build_before_record_request(kwargs),
'filter_headers', self.filter_headers 'before_record_response': self._build_before_record_response(kwargs)
),
"filter_query_parameters": kwargs.get(
'filter_query_parameters', self.filter_query_parameters
),
"before_record": kwargs.get(
"before_record", self.before_record
),
"before_record_response": kwargs.get(
"before_record_response", self.before_record_response
),
"ignore_hosts": kwargs.get(
'ignore_hosts', self.ignore_hosts
),
"ignore_localhost": kwargs.get(
'ignore_localhost', self.ignore_localhost
),
} }
return path, merged_config return path, merged_config
def _build_before_record_response(self, options):
before_record_response = options.get(
'before_record_response', self.before_record_response
)
filter_functions = []
if before_record_response and not isinstance(before_record_response,
collections.Iterable):
before_record_response = (before_record_response,)
for function in before_record_response:
filter_functions.append(function)
def before_record_response(response):
for function in filter_functions:
if response is None:
break
response = function(response)
return response
return before_record_response
def _build_before_record_request(self, options):
filter_functions = []
filter_headers = options.get(
'filter_headers', self.filter_headers
)
filter_query_parameters = options.get(
'filter_query_parameters', self.filter_query_parameters
)
before_record_request = options.get(
"before_record_request", options.get("before_record", self.before_record_request)
)
ignore_hosts = options.get(
'ignore_hosts', self.ignore_hosts
)
ignore_localhost = options.get(
'ignore_localhost', self.ignore_localhost
)
if filter_headers:
filter_functions.append(functools.partial(filters.remove_headers,
headers_to_remove=filter_headers))
if filter_query_parameters:
filter_functions.append(functools.partial(filters.remove_query_parameters,
query_parameters_to_remove=filter_query_parameters))
hosts_to_ignore = list(ignore_hosts)
if ignore_localhost:
hosts_to_ignore.extend(('localhost', '0.0.0.0', '127.0.0.1'))
if hosts_to_ignore:
hosts_to_ignore = set(hosts_to_ignore)
filter_functions.append(self._build_ignore_hosts(hosts_to_ignore))
if before_record_request:
if not isinstance(before_record_request, collections.Iterable):
before_record_request = (before_record_request,)
for function in before_record_request:
filter_functions.append(function)
def before_record_request(request):
request = copy.copy(request)
for function in filter_functions:
if request is None:
break
request = function(request)
return request
return before_record_request
@staticmethod
def _build_ignore_hosts(hosts_to_ignore):
def filter_ignored_hosts(request):
if hasattr(request, 'host') and request.host in hosts_to_ignore:
return
return request
return filter_ignored_hosts
def register_serializer(self, name, serializer): def register_serializer(self, name, serializer):
self.serializers[name] = serializer self.serializers[name] = serializer

View File

@@ -2,7 +2,7 @@ from six.moves.urllib.parse import urlparse, urlencode, urlunparse
import copy import copy
def _remove_headers(request, headers_to_remove): def remove_headers(request, headers_to_remove):
headers = copy.copy(request.headers) headers = copy.copy(request.headers)
headers_to_remove = [h.lower() for h in headers_to_remove] headers_to_remove = [h.lower() for h in headers_to_remove]
keys = [k for k in headers if k.lower() in headers_to_remove] keys = [k for k in headers if k.lower() in headers_to_remove]
@@ -13,7 +13,7 @@ def _remove_headers(request, headers_to_remove):
return request return request
def _remove_query_parameters(request, query_parameters_to_remove): def remove_query_parameters(request, query_parameters_to_remove):
query = request.query query = request.query
new_query = [(k, v) for (k, v) in query new_query = [(k, v) for (k, v) in query
if k not in query_parameters_to_remove] if k not in query_parameters_to_remove]
@@ -22,22 +22,3 @@ def _remove_query_parameters(request, query_parameters_to_remove):
uri_parts[4] = urlencode(new_query) uri_parts[4] = urlencode(new_query)
request.uri = urlunparse(uri_parts) request.uri = urlunparse(uri_parts)
return request return request
def filter_request(
request,
filter_headers,
filter_query_parameters,
before_record,
ignore_hosts
):
request = copy.copy(request) # don't mutate request object
if hasattr(request, 'headers') and filter_headers:
request = _remove_headers(request, filter_headers)
if hasattr(request, 'host') and request.host in ignore_hosts:
return None
if filter_query_parameters:
request = _remove_query_parameters(request, filter_query_parameters)
if before_record:
request = before_record(request)
return request

View File

@@ -217,7 +217,7 @@ class VCRConnection(object):
response = self.cassette.play_response(self._vcr_request) response = self.cassette.play_response(self._vcr_request)
return VCRHTTPResponse(response) return VCRHTTPResponse(response)
else: else:
if self.cassette.write_protected and self.cassette._filter_request(self._vcr_request): if self.cassette.write_protected and self.cassette.filter_request(self._vcr_request):
raise CannotOverwriteExistingCassetteException( raise CannotOverwriteExistingCassetteException(
"Can't overwrite existing cassette (%r) in " "Can't overwrite existing cassette (%r) in "
"your current record mode (%r)." "your current record mode (%r)."