1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-11 18:06:10 +00:00

Compare commits

..

14 Commits

Author SHA1 Message Date
Ivan Malison
6e049ba7a1 version bump to v1.1.2 2014-10-08 12:11:53 -07:00
Ivan Malison
916e7839e5 Actually use pytest.raises in test. 2014-10-07 13:45:09 -07:00
Ivan Malison
99692a92d2 Handle unicode error in json serialize properly. 2014-10-07 13:21:47 -07:00
Ivan Malison
a9a68ba44b Random tweaks. 2014-10-05 18:37:01 -07:00
Ivan Malison
e9f35db405 Remove .travis.yml changes. 2014-10-05 16:42:46 -07:00
Ivan Malison
7193407a07 Remove ipdb because it causes python below 2.6 to blow up. 2014-10-03 01:40:02 -07:00
Ivan Malison
c3427ae3a2 Fix pip install of tox in travis. 2014-10-02 15:48:29 -07:00
Ivan Malison
3a46a6f210 travis through tox. 2014-10-02 15:26:22 -07:00
Ivan Malison
163181844b Refactor tox.ini using new 1.8 features. 2014-10-02 14:57:53 -07:00
Ivan Malison
2c6f072d11 better logging when matches aren't working. 2014-09-25 04:49:00 -07:00
Ivan Malison
361ed82a10 Bump version to 1.1.1 2014-09-22 19:22:52 -07:00
Ivan Malison
0871c3b87c Remove instance variables for filter_headers, filter_query_params, ignore_localhost and ignore_hosts. These still exist on the VCR object, but they are automatically translated into a filter function when passed to the cassette. 2014-09-22 17:57:22 -07:00
Ivan 'Goat' Malison
d484dee50f Merge pull request #110 from IvanMalison/use_cassette_decorator_pytest_compatibility
Fix use_cassette decorator in python 2 by using wrapt.decorator
2014-09-22 17:02:23 -07:00
Ivan Malison
b046ee4bb1 Fix use_cassette decorator in python 2 by using wrapt.decorator. Add wrapt as dependency. 2014-09-22 16:40:09 -07:00
15 changed files with 261 additions and 321 deletions

View File

@@ -457,6 +457,12 @@ API in version 1.0.x
## Changelog ## Changelog
* 1.1.2 Add urllib==1.7.1 support. Make json serialize error handling correct
Improve logging of match failures.
* 1.1.1 Use function signature preserving `wrapt.decorator` to write the
decorator version of use_cassette in order to ensure compatibility with
py.test fixtures and python 2. Move all request filtering into the
`before_record_callable`.
* 1.1.0 Add `before_record_response`. Fix several bugs related to the context * 1.1.0 Add `before_record_response`. Fix several bugs related to the context
management of cassettes. management of cassettes.
* 1.0.3: Fix an issue with requests 2.4 and make sure case sensitivity is * 1.0.3: Fix an issue with requests 2.4 and make sure case sensitivity is

View File

@@ -20,7 +20,7 @@ class PyTest(TestCommand):
setup( setup(
name='vcrpy', name='vcrpy',
version='1.1.0', version='1.1.2',
description=( description=(
"Automatically mock your HTTP interactions to simplify and " "Automatically mock your HTTP interactions to simplify and "
"speed up testing" "speed up testing"
@@ -41,7 +41,7 @@ setup(
'vcr.compat': 'vcr/compat', 'vcr.compat': 'vcr/compat',
'vcr.persisters': 'vcr/persisters', 'vcr.persisters': 'vcr/persisters',
}, },
install_requires=['PyYAML', 'mock', 'six', 'contextlib2'], install_requires=['PyYAML', 'mock', 'six', 'contextlib2', 'wrapt'],
license='MIT', license='MIT',
tests_require=['pytest', 'mock', 'pytest-localserver'], tests_require=['pytest', 'mock', 'pytest-localserver'],
cmdclass={'test': PyTest}, cmdclass={'test': PyTest},

View File

@@ -1,4 +1,4 @@
'''Basic tests about cassettes''' '''Basic tests for cassettes'''
# coding=utf-8 # coding=utf-8
# External imports # External imports

View File

@@ -54,15 +54,20 @@ def test_filter_querystring(tmpdir):
urlopen(url) urlopen(url)
assert 'foo' not in cass.requests[0].url assert 'foo' not in cass.requests[0].url
def test_filter_callback(tmpdir): def test_filter_callback(tmpdir):
url = 'http://httpbin.org/get' url = 'http://httpbin.org/get'
cass_file = str(tmpdir.join('basic_auth_filter.yaml')) cass_file = str(tmpdir.join('basic_auth_filter.yaml'))
def before_record_cb(request): def before_record_cb(request):
if request.path != '/get': if request.path != '/get':
return request return request
my_vcr = vcr.VCR( # Test the legacy keyword.
before_record = before_record_cb, my_vcr = vcr.VCR(before_record=before_record_cb)
) with my_vcr.use_cassette(cass_file, filter_headers=['authorization']) as cass:
urlopen(url)
assert len(cass) == 0
my_vcr = vcr.VCR(before_record_request=before_record_cb)
with my_vcr.use_cassette(cass_file, filter_headers=['authorization']) as cass: with my_vcr.use_cassette(cass_file, filter_headers=['authorization']) as cass:
urlopen(url) urlopen(url)
assert len(cass) == 0 assert len(cass) == 0

View File

@@ -1,37 +1,37 @@
from vcr.filters import _remove_headers, _remove_query_parameters from vcr.filters import remove_headers, remove_query_parameters
from vcr.request import Request from vcr.request import Request
def test_remove_headers(): def test_remove_headers():
headers = {'hello': ['goodbye'], 'secret': ['header']} headers = {'hello': ['goodbye'], 'secret': ['header']}
request = Request('GET', 'http://google.com', '', headers) request = Request('GET', 'http://google.com', '', headers)
_remove_headers(request, ['secret']) remove_headers(request, ['secret'])
assert request.headers == {'hello': 'goodbye'} assert request.headers == {'hello': 'goodbye'}
def test_remove_headers_empty(): def test_remove_headers_empty():
headers = {'hello': 'goodbye', 'secret': 'header'} headers = {'hello': 'goodbye', 'secret': 'header'}
request = Request('GET', 'http://google.com', '', headers) request = Request('GET', 'http://google.com', '', headers)
_remove_headers(request, []) remove_headers(request, [])
assert request.headers == headers assert request.headers == headers
def test_remove_query_parameters(): def test_remove_query_parameters():
uri = 'http://g.com/?q=cowboys&w=1' uri = 'http://g.com/?q=cowboys&w=1'
request = Request('GET', uri, '', {}) request = Request('GET', uri, '', {})
_remove_query_parameters(request, ['w']) remove_query_parameters(request, ['w'])
assert request.uri == 'http://g.com/?q=cowboys' assert request.uri == 'http://g.com/?q=cowboys'
def test_remove_all_query_parameters(): def test_remove_all_query_parameters():
uri = 'http://g.com/?q=cowboys&w=1' uri = 'http://g.com/?q=cowboys&w=1'
request = Request('GET', uri, '', {}) request = Request('GET', uri, '', {})
_remove_query_parameters(request, ['w', 'q']) remove_query_parameters(request, ['w', 'q'])
assert request.uri == 'http://g.com/' assert request.uri == 'http://g.com/'
def test_remove_nonexistent_query_parameters(): def test_remove_nonexistent_query_parameters():
uri = 'http://g.com/' uri = 'http://g.com/'
request = Request('GET', uri, '', {}) request = Request('GET', uri, '', {})
_remove_query_parameters(request, ['w', 'q']) remove_query_parameters(request, ['w', 'q'])
assert request.uri == 'http://g.com/' assert request.uri == 'http://g.com/'

View File

@@ -1,21 +1,35 @@
import mock
import pytest import pytest
from vcr.serialize import deserialize from vcr.serialize import deserialize
from vcr.serializers import yamlserializer, jsonserializer from vcr.serializers import yamlserializer, jsonserializer
def test_deserialize_old_yaml_cassette(): def test_deserialize_old_yaml_cassette():
with open('tests/fixtures/migration/old_cassette.yaml', 'r') as f: with open('tests/fixtures/migration/old_cassette.yaml', 'r') as f:
with pytest.raises(ValueError): with pytest.raises(ValueError):
deserialize(f.read(), yamlserializer) deserialize(f.read(), yamlserializer)
def test_deserialize_old_json_cassette(): def test_deserialize_old_json_cassette():
with open('tests/fixtures/migration/old_cassette.json', 'r') as f: with open('tests/fixtures/migration/old_cassette.json', 'r') as f:
with pytest.raises(ValueError): with pytest.raises(ValueError):
deserialize(f.read(), jsonserializer) deserialize(f.read(), jsonserializer)
def test_deserialize_new_yaml_cassette(): def test_deserialize_new_yaml_cassette():
with open('tests/fixtures/migration/new_cassette.yaml', 'r') as f: with open('tests/fixtures/migration/new_cassette.yaml', 'r') as f:
deserialize(f.read(), yamlserializer) deserialize(f.read(), yamlserializer)
def test_deserialize_new_json_cassette(): def test_deserialize_new_json_cassette():
with open('tests/fixtures/migration/new_cassette.json', 'r') as f: with open('tests/fixtures/migration/new_cassette.json', 'r') as f:
deserialize(f.read(), jsonserializer) deserialize(f.read(), jsonserializer)
@mock.patch.object(jsonserializer.json, 'dumps',
side_effect=UnicodeDecodeError('utf-8', b'unicode error in serialization',
0, 10, 'blew up'))
def test_serialize_constructs_UnicodeDecodeError(mock_dumps):
with pytest.raises(UnicodeDecodeError):
jsonserializer.serialize({})

View File

@@ -1,28 +1,76 @@
import mock import mock
import pytest
from vcr import VCR from vcr import VCR, use_cassette
from vcr.request import Request
def test_vcr_use_cassette(): def test_vcr_use_cassette():
filter_headers = mock.Mock() record_mode = mock.Mock()
test_vcr = VCR(filter_headers=filter_headers) test_vcr = VCR(record_mode=record_mode)
with mock.patch('vcr.cassette.Cassette.load') as mock_cassette_load: with mock.patch('vcr.cassette.Cassette.load') as mock_cassette_load:
@test_vcr.use_cassette('test') @test_vcr.use_cassette('test')
def function(): def function():
pass pass
assert mock_cassette_load.call_count == 0 assert mock_cassette_load.call_count == 0
function() function()
assert mock_cassette_load.call_args[1]['filter_headers'] is filter_headers assert mock_cassette_load.call_args[1]['record_mode'] is record_mode
# Make sure that calls to function now use cassettes with the # Make sure that calls to function now use cassettes with the
# new filter_header_settings # new filter_header_settings
test_vcr.filter_headers = ('a',) test_vcr.record_mode = mock.Mock()
function() function()
assert mock_cassette_load.call_args[1]['filter_headers'] == test_vcr.filter_headers assert mock_cassette_load.call_args[1]['record_mode'] == test_vcr.record_mode
# Ensure that explicitly provided arguments still supercede # Ensure that explicitly provided arguments still supercede
# those on the vcr. # those on the vcr.
new_filter_headers = mock.Mock() new_record_mode = mock.Mock()
with test_vcr.use_cassette('test', filter_headers=new_filter_headers) as cassette: with test_vcr.use_cassette('test', record_mode=new_record_mode) as cassette:
assert cassette._filter_headers == new_filter_headers assert cassette.record_mode == new_record_mode
def test_vcr_before_record_request_params():
base_path = 'http://httpbin.org/'
def before_record_cb(request):
if request.path != '/get':
return request
test_vcr = VCR(filter_headers=('cookie',), before_record_request=before_record_cb,
ignore_hosts=('www.test.com',), ignore_localhost=True,
filter_query_parameters=('foo',))
with test_vcr.use_cassette('test') as cassette:
assert cassette.filter_request(Request('GET', base_path + 'get', '', {})) is None
assert cassette.filter_request(Request('GET', base_path + 'get2', '', {})) is not None
assert cassette.filter_request(Request('GET', base_path + '?foo=bar', '', {})).query == []
assert cassette.filter_request(
Request('GET', base_path + '?foo=bar', '',
{'cookie': 'test', 'other': 'fun'})).headers == {'other': 'fun'}
assert cassette.filter_request(Request('GET', base_path + '?foo=bar', '',
{'cookie': 'test', 'other': 'fun'})).headers == {'other': 'fun'}
assert cassette.filter_request(Request('GET', 'http://www.test.com' + '?foo=bar', '',
{'cookie': 'test', 'other': 'fun'})) is None
with test_vcr.use_cassette('test', before_record_request=None) as cassette:
# Test that before_record can be overwritten with
assert cassette.filter_request(Request('GET', base_path + 'get', '', {})) is not None
@pytest.fixture
def random_fixture():
return 1
@use_cassette('test')
def test_fixtures_with_use_cassette(random_fixture):
# Applying a decorator to a test function that requests features can cause
# problems if the decorator does not preserve the signature of the original
# test function.
# This test ensures that use_cassette preserves the signature of the original
# test function, and thus that use_cassette is compatible with py.test
# fixtures. It is admittedly a bit strange because the test would never even
# run if the relevant feature were broken.
pass

193
tox.ini
View File

@@ -1,189 +1,24 @@
# Tox (http://tox.testrun.org/) is a tool for running tests
# in multiple virtualenvs. This configuration file will run the
# test suite on all supported python versions. To use it, "pip install tox"
# and then run "tox" from this directory.
[tox] [tox]
envlist = envlist = {py26,py27,py33,py34,pypy}-{requests24,requests23,requests22,requests1,httplib2,urllib3,boto}
py26,
py27,
py33,
py34,
pypy,
py26requests24,
py27requests24,
py34requests24,
pypyrequests24,
py26requests23,
py27requests23,
py34requests23,
pypyrequests23,
py26requests22,
py27requests22,
py34requests22,
pypyrequests22,
py26requests1,
py27requests1,
py33requests1,
pypyrequests1,
py26httplib2,
py27httplib2,
py33httplib2,
py34httplib2,
pypyhttplib2,
[testenv] [testenv]
commands = commands =
py.test {posargs} py.test {posargs}
basepython =
py26: python2.6
py27: python2.7
py33: python3.3
py34: python3.4
pypy: pypy
deps = deps =
mock mock
pytest pytest
pytest-localserver pytest-localserver
PyYAML PyYAML
ipdb requests1: requests==1.2.3
requests24: requests==2.4.0
[testenv:py26requests1] requests23: requests==2.3.0
basepython = python2.6 requests22: requests==2.2.1
deps = httplib2: httplib2
{[testenv]deps} urllib3: urllib3==1.7.1
requests==1.2.3 boto: boto
[testenv:py27requests1]
basepython = python2.7
deps =
{[testenv]deps}
requests==1.2.3
[testenv:py33requests1]
basepython = python3.3
deps =
{[testenv]deps}
requests==1.2.3
[testenv:pypyrequests1]
basepython = pypy
deps =
{[testenv]deps}
requests==1.2.3
[testenv:py26requests24]
basepython = python2.6
deps =
{[testenv]deps}
requests==2.4.0
[testenv:py27requests24]
basepython = python2.7
deps =
{[testenv]deps}
requests==2.4.0
[testenv:py33requests24]
basepython = python3.4
deps =
{[testenv]deps}
requests==2.4.0
[testenv:py34requests24]
basepython = python3.4
deps =
{[testenv]deps}
requests==2.4.0
[testenv:pypyrequests24]
basepython = pypy
deps =
{[testenv]deps}
requests==2.4.0
[testenv:py26requests23]
basepython = python2.6
deps =
{[testenv]deps}
requests==2.3.0
[testenv:py27requests23]
basepython = python2.7
deps =
{[testenv]deps}
requests==2.3.0
[testenv:py33requests23]
basepython = python3.4
deps =
{[testenv]deps}
requests==2.3.0
[testenv:py34requests23]
basepython = python3.4
deps =
{[testenv]deps}
requests==2.3.0
[testenv:pypyrequests23]
basepython = pypy
deps =
{[testenv]deps}
requests==2.3.0
[testenv:py26requests22]
basepython = python2.6
deps =
{[testenv]deps}
requests==2.2.1
[testenv:py27requests22]
basepython = python2.7
deps =
{[testenv]deps}
requests==2.2.1
[testenv:py33requests22]
basepython = python3.4
deps =
{[testenv]deps}
requests==2.2.1
[testenv:py34requests22]
basepython = python3.4
deps =
{[testenv]deps}
requests==2.2.1
[testenv:pypyrequests22]
basepython = pypy
deps =
{[testenv]deps}
requests==2.2.1
[testenv:py26httplib2]
basepython = python2.6
deps =
{[testenv]deps}
httplib2
[testenv:py27httplib2]
basepython = python2.7
deps =
{[testenv]deps}
httplib2
[testenv:py33httplib2]
basepython = python3.4
deps =
{[testenv]deps}
httplib2
[testenv:py34httplib2]
basepython = python3.4
deps =
{[testenv]deps}
httplib2
[testenv:pypyhttplib2]
basepython = pypy
deps =
{[testenv]deps}
httplib2

View File

@@ -1,7 +1,8 @@
'''The container for recorded requests and responses''' """The container for recorded requests and responses"""
import logging import logging
import contextlib2 import contextlib2
import wrapt
try: try:
from collections import Counter from collections import Counter
except ImportError: except ImportError:
@@ -10,7 +11,6 @@ except ImportError:
# Internal imports # Internal imports
from .patch import CassettePatcherBuilder from .patch import CassettePatcherBuilder
from .persist import load_cassette, save_cassette from .persist import load_cassette, save_cassette
from .filters import filter_request
from .serializers import yamlserializer from .serializers import yamlserializer
from .matchers import requests_match, uri, method from .matchers import requests_match, uri, method
from .errors import UnhandledHTTPRequestError from .errors import UnhandledHTTPRequestError
@@ -19,7 +19,7 @@ from .errors import UnhandledHTTPRequestError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class CassetteContextDecorator(contextlib2.ContextDecorator): class CassetteContextDecorator(object):
"""Context manager/decorator that handles installing the cassette and """Context manager/decorator that handles installing the cassette and
removing cassettes. removing cassettes.
@@ -45,11 +45,12 @@ class CassetteContextDecorator(contextlib2.ContextDecorator):
log.debug('Entered context for cassette at {0}.'.format(cassette._path)) log.debug('Entered context for cassette at {0}.'.format(cassette._path))
yield cassette yield cassette
log.debug('Exiting context for cassette at {0}.'.format(cassette._path)) log.debug('Exiting context for cassette at {0}.'.format(cassette._path))
# TODO(@IvanMalison): Hmmm. it kind of feels like this should be somewhere else. # TODO(@IvanMalison): Hmmm. it kind of feels like this should be
# somewhere else.
cassette._save() cassette._save()
def __enter__(self): def __enter__(self):
assert self.__finish is None assert self.__finish is None, "Cassette already open."
path, kwargs = self._args_getter() path, kwargs = self._args_getter()
self.__finish = self._patch_generator(self.cls.load(path, **kwargs)) self.__finish = self._patch_generator(self.cls.load(path, **kwargs))
return next(self.__finish) return next(self.__finish)
@@ -58,13 +59,18 @@ class CassetteContextDecorator(contextlib2.ContextDecorator):
next(self.__finish, None) next(self.__finish, None)
self.__finish = None self.__finish = None
@wrapt.decorator
def __call__(self, function, instance, args, kwargs):
with self:
return function(*args, **kwargs)
class Cassette(object): class Cassette(object):
'''A container for recorded requests and responses''' """A container for recorded requests and responses"""
@classmethod @classmethod
def load(cls, path, **kwargs): def load(cls, path, **kwargs):
'''Load in the cassette stored at the provided path''' """Instantiate and load the cassette stored at the specified path."""
new_cassette = cls(path, **kwargs) new_cassette = cls(path, **kwargs)
new_cassette._load() new_cassette._load()
return new_cassette return new_cassette
@@ -79,20 +85,14 @@ class Cassette(object):
def __init__(self, path, serializer=yamlserializer, record_mode='once', def __init__(self, path, serializer=yamlserializer, record_mode='once',
match_on=(uri, method), filter_headers=(), match_on=(uri, method), filter_headers=(),
filter_query_parameters=(), before_record=None, before_record_response=None, filter_query_parameters=(), before_record_request=None,
ignore_hosts=(), ignore_localhost=()): before_record_response=None, ignore_hosts=(),
ignore_localhost=()):
self._path = path self._path = path
self._serializer = serializer self._serializer = serializer
self._match_on = match_on self._match_on = match_on
self._filter_headers = filter_headers self._before_record_request = before_record_request or (lambda x: x)
self._filter_query_parameters = filter_query_parameters self._before_record_response = before_record_response or (lambda x: x)
self._before_record = before_record
self._before_record_response = before_record_response
self._ignore_hosts = ignore_hosts
if ignore_localhost:
self._ignore_hosts = list(set(
list(self._ignore_hosts) + ['localhost', '0.0.0.0', '127.0.0.1']
))
# self.data is the list of (req, resp) tuples # self.data is the list of (req, resp) tuples
self.data = [] self.data = []
@@ -107,9 +107,7 @@ class Cassette(object):
@property @property
def all_played(self): def all_played(self):
""" """Returns True if all responses have been played, False otherwise."""
Returns True if all responses have been played, False otherwise.
"""
return self.play_count == len(self) return self.play_count == len(self)
@property @property
@@ -125,18 +123,9 @@ class Cassette(object):
return self.rewound and self.record_mode == 'once' or \ return self.rewound and self.record_mode == 'once' or \
self.record_mode == 'none' self.record_mode == 'none'
def _filter_request(self, request):
return filter_request(
request=request,
filter_headers=self._filter_headers,
filter_query_parameters=self._filter_query_parameters,
before_record=self._before_record,
ignore_hosts=self._ignore_hosts
)
def append(self, request, response): def append(self, request, response):
'''Add a request, response pair to this cassette''' """Add a request, response pair to this cassette"""
request = self._filter_request(request) request = self._before_record_request(request)
if not request: if not request:
return return
if self._before_record_response: if self._before_record_response:
@@ -144,29 +133,30 @@ class Cassette(object):
self.data.append((request, response)) self.data.append((request, response))
self.dirty = True self.dirty = True
def filter_request(self, request):
return self._before_record_request(request)
def _responses(self, request): def _responses(self, request):
""" """
internal API, returns an iterator with all responses matching internal API, returns an iterator with all responses matching
the request. the request.
""" """
request = self._filter_request(request) request = self._before_record_request(request)
if not request:
return
for index, (stored_request, response) in enumerate(self.data): for index, (stored_request, response) in enumerate(self.data):
if requests_match(request, stored_request, self._match_on): if requests_match(request, stored_request, self._match_on):
yield index, response yield index, response
def can_play_response_for(self, request): def can_play_response_for(self, request):
request = self._filter_request(request) request = self._before_record_request(request)
return request and request in self and \ return request and request in self and \
self.record_mode != 'all' and \ self.record_mode != 'all' and \
self.rewound self.rewound
def play_response(self, request): def play_response(self, request):
''' """
Get the response corresponding to a request, but only if it Get the response corresponding to a request, but only if it
hasn't been played back before, and mark it as played hasn't been played back before, and mark it as played
''' """
for index, response in self._responses(request): for index, response in self._responses(request):
if self.play_counts[index] == 0: if self.play_counts[index] == 0:
self.play_counts[index] += 1 self.play_counts[index] += 1
@@ -178,11 +168,11 @@ class Cassette(object):
) )
def responses_of(self, request): def responses_of(self, request):
''' """
Find the responses corresponding to a request. Find the responses corresponding to a request.
This function isn't actually used by VCR internally, but is This function isn't actually used by VCR internally, but is
provided as an external API. provided as an external API.
''' """
responses = [response for index, response in self._responses(request)] responses = [response for index, response in self._responses(request)]
if responses: if responses:
@@ -224,11 +214,11 @@ class Cassette(object):
) )
def __len__(self): def __len__(self):
'''Return the number of request,response pairs stored in here''' """Return the number of request,response pairs stored in here"""
return len(self.data) return len(self.data)
def __contains__(self, request): def __contains__(self, request):
'''Return whether or not a request has been stored''' """Return whether or not a request has been stored"""
for response in self._responses(request): for response in self._responses(request):
return True return True
return False return False

View File

@@ -1,30 +1,22 @@
import collections
import copy
import functools import functools
import os import os
from .cassette import Cassette from .cassette import Cassette
from .serializers import yamlserializer, jsonserializer from .serializers import yamlserializer, jsonserializer
from . import matchers from . import matchers
from . import filters
class VCR(object): class VCR(object):
def __init__(self,
serializer='yaml', def __init__(self, serializer='yaml', cassette_library_dir=None,
cassette_library_dir=None, record_mode="once", filter_headers=(),
record_mode="once", filter_query_parameters=(), before_record_request=None,
filter_headers=(), before_record_response=None, ignore_hosts=(),
filter_query_parameters=(), match_on=('method', 'scheme', 'host', 'port', 'path', 'query',),
before_record=None, ignore_localhost=False, before_record=None):
before_record_response=None,
match_on=(
'method',
'scheme',
'host',
'port',
'path',
'query',
),
ignore_hosts=(),
ignore_localhost=False,
):
self.serializer = serializer self.serializer = serializer
self.match_on = match_on self.match_on = match_on
self.cassette_library_dir = cassette_library_dir self.cassette_library_dir = cassette_library_dir
@@ -47,7 +39,7 @@ class VCR(object):
self.record_mode = record_mode self.record_mode = record_mode
self.filter_headers = filter_headers self.filter_headers = filter_headers
self.filter_query_parameters = filter_query_parameters self.filter_query_parameters = filter_query_parameters
self.before_record = before_record self.before_record_request = before_record_request or before_record
self.before_record_response = before_record_response self.before_record_response = before_record_response
self.ignore_hosts = ignore_hosts self.ignore_hosts = ignore_hosts
self.ignore_localhost = ignore_localhost self.ignore_localhost = ignore_localhost
@@ -69,12 +61,13 @@ class VCR(object):
matchers.append(self.matchers[m]) matchers.append(self.matchers[m])
except KeyError: except KeyError:
raise KeyError( raise KeyError(
"Matcher {0} doesn't exist or isn't registered".format( "Matcher {0} doesn't exist or isn't registered".format(m)
m)
) )
return matchers return matchers
def use_cassette(self, path, **kwargs): def use_cassette(self, path, with_current_defaults=False, **kwargs):
if with_current_defaults:
return Cassette.use(path, self.get_path_and_merged_config(path, **kwargs))
args_getter = functools.partial(self.get_path_and_merged_config, path, **kwargs) args_getter = functools.partial(self.get_path_and_merged_config, path, **kwargs)
return Cassette.use_arg_getter(args_getter) return Cassette.use_arg_getter(args_getter)
@@ -89,30 +82,87 @@ class VCR(object):
path = os.path.join(cassette_library_dir, path) path = os.path.join(cassette_library_dir, path)
merged_config = { merged_config = {
"serializer": self._get_serializer(serializer_name), 'serializer': self._get_serializer(serializer_name),
"match_on": self._get_matchers(matcher_names), 'match_on': self._get_matchers(matcher_names),
"record_mode": kwargs.get('record_mode', self.record_mode), 'record_mode': kwargs.get('record_mode', self.record_mode),
"filter_headers": kwargs.get( 'before_record_request': self._build_before_record_request(kwargs),
'filter_headers', self.filter_headers 'before_record_response': self._build_before_record_response(kwargs)
),
"filter_query_parameters": kwargs.get(
'filter_query_parameters', self.filter_query_parameters
),
"before_record": kwargs.get(
"before_record", self.before_record
),
"before_record_response": kwargs.get(
"before_record_response", self.before_record_response
),
"ignore_hosts": kwargs.get(
'ignore_hosts', self.ignore_hosts
),
"ignore_localhost": kwargs.get(
'ignore_localhost', self.ignore_localhost
),
} }
return path, merged_config return path, merged_config
def _build_before_record_response(self, options):
before_record_response = options.get(
'before_record_response', self.before_record_response
)
filter_functions = []
if before_record_response and not isinstance(before_record_response,
collections.Iterable):
before_record_response = (before_record_response,)
for function in before_record_response:
filter_functions.append(function)
def before_record_response(response):
for function in filter_functions:
if response is None:
break
response = function(response)
return response
return before_record_response
def _build_before_record_request(self, options):
filter_functions = []
filter_headers = options.get(
'filter_headers', self.filter_headers
)
filter_query_parameters = options.get(
'filter_query_parameters', self.filter_query_parameters
)
before_record_request = options.get(
"before_record_request", options.get("before_record", self.before_record_request)
)
ignore_hosts = options.get(
'ignore_hosts', self.ignore_hosts
)
ignore_localhost = options.get(
'ignore_localhost', self.ignore_localhost
)
if filter_headers:
filter_functions.append(functools.partial(filters.remove_headers,
headers_to_remove=filter_headers))
if filter_query_parameters:
filter_functions.append(functools.partial(filters.remove_query_parameters,
query_parameters_to_remove=filter_query_parameters))
hosts_to_ignore = list(ignore_hosts)
if ignore_localhost:
hosts_to_ignore.extend(('localhost', '0.0.0.0', '127.0.0.1'))
if hosts_to_ignore:
hosts_to_ignore = set(hosts_to_ignore)
filter_functions.append(self._build_ignore_hosts(hosts_to_ignore))
if before_record_request:
if not isinstance(before_record_request, collections.Iterable):
before_record_request = (before_record_request,)
for function in before_record_request:
filter_functions.append(function)
def before_record_request(request):
request = copy.copy(request)
for function in filter_functions:
if request is None:
break
request = function(request)
return request
return before_record_request
@staticmethod
def _build_ignore_hosts(hosts_to_ignore):
def filter_ignored_hosts(request):
if hasattr(request, 'host') and request.host in hosts_to_ignore:
return
return request
return filter_ignored_hosts
def register_serializer(self, name, serializer): def register_serializer(self, name, serializer):
self.serializers[name] = serializer self.serializers[name] = serializer

View File

@@ -2,7 +2,7 @@ from six.moves.urllib.parse import urlparse, urlencode, urlunparse
import copy import copy
def _remove_headers(request, headers_to_remove): def remove_headers(request, headers_to_remove):
headers = copy.copy(request.headers) headers = copy.copy(request.headers)
headers_to_remove = [h.lower() for h in headers_to_remove] headers_to_remove = [h.lower() for h in headers_to_remove]
keys = [k for k in headers if k.lower() in headers_to_remove] keys = [k for k in headers if k.lower() in headers_to_remove]
@@ -13,7 +13,7 @@ def _remove_headers(request, headers_to_remove):
return request return request
def _remove_query_parameters(request, query_parameters_to_remove): def remove_query_parameters(request, query_parameters_to_remove):
query = request.query query = request.query
new_query = [(k, v) for (k, v) in query new_query = [(k, v) for (k, v) in query
if k not in query_parameters_to_remove] if k not in query_parameters_to_remove]
@@ -22,22 +22,3 @@ def _remove_query_parameters(request, query_parameters_to_remove):
uri_parts[4] = urlencode(new_query) uri_parts[4] = urlencode(new_query)
request.uri = urlunparse(uri_parts) request.uri = urlunparse(uri_parts)
return request return request
def filter_request(
request,
filter_headers,
filter_query_parameters,
before_record,
ignore_hosts
):
request = copy.copy(request) # don't mutate request object
if hasattr(request, 'headers') and filter_headers:
request = _remove_headers(request, filter_headers)
if hasattr(request, 'host') and request.host in ignore_hosts:
return None
if filter_query_parameters:
request = _remove_query_parameters(request, filter_query_parameters)
if before_record:
request = before_record(request)
return request

View File

@@ -38,16 +38,16 @@ def headers(r1, r2):
return r1.headers == r2.headers return r1.headers == r2.headers
def _log_matches(matches): def _log_matches(r1, r2, matches):
differences = [m for m in matches if not m[0]] differences = [m for m in matches if not m[0]]
if differences: if differences:
log.debug( log.debug(
'Requests differ according to the following matchers: ' + "Requests {0} and {1} differ according to "
str(differences) "the following matchers: {2}".format(r1, r2, differences)
) )
def requests_match(r1, r2, matchers): def requests_match(r1, r2, matchers):
matches = [(m(r1, r2), m) for m in matchers] matches = [(m(r1, r2), m) for m in matchers]
_log_matches(matches) _log_matches(r1, r2, matches)
return all([m[0] for m in matches]) return all([m[0] for m in matches])

View File

@@ -59,7 +59,9 @@ class CassettePatcherBuilder(object):
def _build_patchers_from_mock_triples_decorator(function): def _build_patchers_from_mock_triples_decorator(function):
@functools.wraps(function) @functools.wraps(function)
def wrapped(self, *args, **kwargs): def wrapped(self, *args, **kwargs):
return self._build_patchers_from_mock_triples(function(self, *args, **kwargs)) return self._build_patchers_from_mock_triples(
function(self, *args, **kwargs)
)
return wrapped return wrapped
def __init__(self, cassette): def __init__(self, cassette):
@@ -273,8 +275,9 @@ def reset_patchers():
yield mock.patch.object(cpool, 'VerifiedHTTPSConnection', _VerifiedHTTPSConnection) yield mock.patch.object(cpool, 'VerifiedHTTPSConnection', _VerifiedHTTPSConnection)
yield mock.patch.object(cpool, 'HTTPConnection', _HTTPConnection) yield mock.patch.object(cpool, 'HTTPConnection', _HTTPConnection)
yield mock.patch.object(cpool, 'HTTPSConnection', _HTTPSConnection) yield mock.patch.object(cpool, 'HTTPSConnection', _HTTPSConnection)
yield mock.patch.object(cpool.HTTPConnectionPool, 'ConnectionCls', _HTTPConnection) if hasattr(cpool.HTTPConnectionPool, 'ConnectionCls'):
yield mock.patch.object(cpool.HTTPSConnectionPool, 'ConnectionCls', _HTTPSConnection) yield mock.patch.object(cpool.HTTPConnectionPool, 'ConnectionCls', _HTTPConnection)
yield mock.patch.object(cpool.HTTPSConnectionPool, 'ConnectionCls', _HTTPSConnection)
try: try:
import httplib2 as cpool import httplib2 as cpool

View File

@@ -11,10 +11,14 @@ def deserialize(cassette_string):
def serialize(cassette_dict): def serialize(cassette_dict):
try: try:
return json.dumps(cassette_dict, indent=4) return json.dumps(cassette_dict, indent=4)
except UnicodeDecodeError: except UnicodeDecodeError as original:
raise UnicodeDecodeError( raise UnicodeDecodeError(
"Error serializing cassette to JSON. ", original.encoding,
"Does this HTTP interaction contain binary data? ", b"Error serializing cassette to JSON",
"If so, use a different serializer (like the yaml serializer) ", original.start,
"for this request" original.end,
original.args[-1] +
("Does this HTTP interaction contain binary data? "
"If so, use a different serializer (like the yaml serializer) "
"for this request?")
) )

View File

@@ -217,11 +217,15 @@ class VCRConnection(object):
response = self.cassette.play_response(self._vcr_request) response = self.cassette.play_response(self._vcr_request)
return VCRHTTPResponse(response) return VCRHTTPResponse(response)
else: else:
if self.cassette.write_protected and self.cassette._filter_request(self._vcr_request): if self.cassette.write_protected and self.cassette.filter_request(
self._vcr_request
):
raise CannotOverwriteExistingCassetteException( raise CannotOverwriteExistingCassetteException(
"No match for the request (%r) was found. "
"Can't overwrite existing cassette (%r) in " "Can't overwrite existing cassette (%r) in "
"your current record mode (%r)." "your current record mode (%r)."
% (self.cassette._path, self.cassette.record_mode) % (self._vcr_request, self.cassette._path,
self.cassette.record_mode)
) )
# Otherwise, we should send the request, then get the response # Otherwise, we should send the request, then get the response