1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-09 09:13:23 +00:00

Compare commits

..

8 Commits

Author SHA1 Message Date
Ivan Malison
e324a9677d version 1.5.0 2015-05-14 14:05:50 -07:00
Ivan Malison
28640beb7d README updates. 2015-05-14 14:03:49 -07:00
Ivan 'Goat' Malison
c338d5d32c Merge pull request #154 from marco-santamaria/master
Filter parameters from 'application/json' content-type POST requests
2015-05-14 14:03:14 -07:00
marco.santamaria
59aa351ca8 Added support for json post data in filter_post_data_parameters. 2015-05-14 14:13:14 +02:00
Ivan Malison
2323b9da5f Automatically generate cassette names from function names. Add
`path_transformer` and `func_path_generator`. Closes #151.
2015-05-10 03:22:43 -07:00
Ivan Malison
0bbbc694b0 Make CassetteContextDecorator decorator produce reentrant functions.
Closes #150.
2015-05-09 23:14:00 -07:00
Kevin McCarthy
d293020617 Merge pull request #153 from addgene/mw/specify-six-version
Fix version of `six` dependency.
2015-05-07 20:34:38 -10:00
Morgan Wahl
daac863f0b Fixed version of six dependency.
`from six.moves.http_client import HTTPConnection` fails before version 1.5.0 of six. (on Python 2.7, at least.)
2015-05-07 15:35:26 -04:00
10 changed files with 398 additions and 67 deletions

View File

@@ -7,16 +7,28 @@ This is a Python version of [Ruby's VCR library](https://github.com/vcr/vcr).
[![Build Status](https://secure.travis-ci.org/kevin1024/vcrpy.png?branch=master)](http://travis-ci.org/kevin1024/vcrpy) [![Build Status](https://secure.travis-ci.org/kevin1024/vcrpy.png?branch=master)](http://travis-ci.org/kevin1024/vcrpy)
[![Stories in Ready](https://badge.waffle.io/kevin1024/vcrpy.png?label=ready&title=Ready)](https://waffle.io/kevin1024/vcrpy) [![Stories in Ready](https://badge.waffle.io/kevin1024/vcrpy.png?label=ready&title=Ready)](https://waffle.io/kevin1024/vcrpy)
## What it does ## What it does VCR.py simplifies and speeds up tests that make HTTP
Simplify and speed up testing HTTP by recording all HTTP interactions and requests. The first time you run code that is inside a VCR.py context
saving them to "cassette" files, which are yaml files containing the contents manager or decorated function, VCR.py records all HTTP interactions
of your requests and responses. Then when you run your tests again, they all that take place through the libraries it supports and serializes and
just hit the text files instead of the internet. This speeds up your tests and writes them to a flat file (in yaml format by default). This flat file
lets you work offline. is called a cassette. When the relevant peice of code is executed
again, VCR.py will read the serialized requests and responses from the
aforementioned cassette file, and intercept any HTTP requests that it
recognizes from the original test run and return responses that
corresponded to those requests. This means that the requests will not
actually result in HTTP traffic, which confers several benefits
including:
If the server you are testing against ever changes its API, all you need to do - The ability to work offline
is delete your existing cassette files, and run your tests again. All of the - Completely deterministic tests
mocked responses will be updated with the new API. - Increased test execution speed
If the server you are testing against ever changes its API, all you
need to do is delete your existing cassette files, and run your tests
again. VCR.py will detect the absence of a cassette file and once
again record all HTTP interactions, which will update them to
correspond to the new API.
## Compatibility Notes ## Compatibility Notes
VCR.py supports Python 2.6 and 2.7, 3.3, 3.4, and [pypy](http://pypy.org). VCR.py supports Python 2.6 and 2.7, 3.3, 3.4, and [pypy](http://pypy.org).
@@ -58,8 +70,17 @@ def test_iana():
assert 'Example domains' in response assert 'Example domains' in response
``` ```
All of the parameters and configuration works the same for the decorator When using the decorator version of `use_cassette`, it is possible to
version. omit the path to the cassette file.
```python
@vcr.use_cassette()
def test_iana():
response = urllib2.urlopen('http://www.iana.org/domains/reserved').read()
assert 'Example domains' in response
```
In this case, the cassette file will be given the same name as the test function, and it will be placed in the same directory as the file in which the test is defined. See the Automatic Test Naming section below for more details.
## Configuration ## Configuration
@@ -389,6 +410,43 @@ my_vcr = config.VCR(custom_patches=((where_the_custom_https_connection_lives, 'C
@my_vcr.use_cassette(...) @my_vcr.use_cassette(...)
``` ```
## Automatic Cassette Naming
VCR.py now allows the omission of the path argument to the
use_cassette function. Both of the following are now legal/should work
``` python
@my_vcr.use_cassette
def my_test_function():
...
```
``` python
@my_vcr.use_cassette()
def my_test_function():
...
```
In both cases, VCR.py will use a path that is generated from the
provided test function's name. If no `cassette_library_dir` has been
set, the cassette will be in a file with the name of the test function
in directory of the file in which the test function is declared. If a
`cassette_library_dir` has been set, the cassette will appear
in that directory in a file with the name of the decorated function.
It is possible to control the path produced by the automatic naming
machinery by customizing the `path_transformer` and
`func_path_generator` vcr variables. To add an extension to all
cassette names, use `VCR.ensure_suffix` as follows:
``` python
my_vcr = VCR(path_transformer=VCR.ensure_suffix('.yaml'))
@my_vcr.use_cassette
def my_test_function():
```
## Installation ## Installation
VCR.py is a package on PyPI, so you can `pip install vcrpy` (first you may need VCR.py is a package on PyPI, so you can `pip install vcrpy` (first you may need
@@ -484,6 +542,8 @@ API in version 1.0.x
## Changelog ## Changelog
* 1.5.0 Automatic cassette naming and 'application/json' post data
filtering (thanks @marco-santamaria).
* 1.4.2 Fix a bug caused by requests 2.7 and chunked transfer encoding * 1.4.2 Fix a bug caused by requests 2.7 and chunked transfer encoding
* 1.4.1 Include README, tests, LICENSE in package. Thanks @ralphbean. * 1.4.1 Include README, tests, LICENSE in package. Thanks @ralphbean.
* 1.4.0 Filter post data parameters (thanks @eadmundo), support for * 1.4.0 Filter post data parameters (thanks @eadmundo), support for

View File

@@ -20,7 +20,7 @@ class PyTest(TestCommand):
setup( setup(
name='vcrpy', name='vcrpy',
version='1.4.2', version='1.5.0',
description=( description=(
"Automatically mock your HTTP interactions to simplify and " "Automatically mock your HTTP interactions to simplify and "
"speed up testing" "speed up testing"
@@ -29,7 +29,7 @@ setup(
author_email='me@kevinmccarthy.org', author_email='me@kevinmccarthy.org',
url='https://github.com/kevin1024/vcrpy', url='https://github.com/kevin1024/vcrpy',
packages=find_packages(exclude=("tests*",)), packages=find_packages(exclude=("tests*",)),
install_requires=['PyYAML', 'mock', 'six', 'contextlib2', install_requires=['PyYAML', 'mock', 'six>=1.5', 'contextlib2',
'wrapt', 'backport_collections'], 'wrapt', 'backport_collections'],
license='MIT', license='MIT',
tests_require=['pytest', 'mock', 'pytest-localserver'], tests_require=['pytest', 'mock', 'pytest-localserver'],

View File

@@ -4,6 +4,7 @@ from six.moves.urllib.request import urlopen, Request
from six.moves.urllib.parse import urlencode from six.moves.urllib.parse import urlencode
from six.moves.urllib.error import HTTPError from six.moves.urllib.error import HTTPError
import vcr import vcr
import json
def _request_with_auth(url, username, password): def _request_with_auth(url, username, password):
@@ -66,6 +67,18 @@ def test_filter_post_data(tmpdir):
assert b'id=secret' not in cass.requests[0].body assert b'id=secret' not in cass.requests[0].body
def test_filter_json_post_data(tmpdir):
data = json.dumps({'id': 'secret', 'foo': 'bar'}).encode('utf-8')
request = Request('http://httpbin.org/post', data=data)
request.add_header('Content-Type', 'application/json')
cass_file = str(tmpdir.join('filter_jpd.yaml'))
with vcr.use_cassette(cass_file, filter_post_data_parameters=['id']):
urlopen(request)
with vcr.use_cassette(cass_file, filter_post_data_parameters=['id']) as cass:
assert b'"id": "secret"' not in cass.requests[0].body
def test_filter_callback(tmpdir): def test_filter_callback(tmpdir):
url = 'http://httpbin.org/get' url = 'http://httpbin.org/get'
cass_file = str(tmpdir.join('basic_auth_filter.yaml')) cass_file = str(tmpdir.join('basic_auth_filter.yaml'))

View File

@@ -1,4 +1,6 @@
import copy import copy
import inspect
import os
from six.moves import http_client as httplib from six.moves import http_client as httplib
import contextlib2 import contextlib2
@@ -12,14 +14,13 @@ from vcr.patch import force_reset
from vcr.stubs import VCRHTTPSConnection from vcr.stubs import VCRHTTPSConnection
def test_cassette_load(tmpdir): def test_cassette_load(tmpdir):
a_file = tmpdir.join('test_cassette.yml') a_file = tmpdir.join('test_cassette.yml')
a_file.write(yaml.dump({'interactions': [ a_file.write(yaml.dump({'interactions': [
{'request': {'body': '', 'uri': 'foo', 'method': 'GET', 'headers': {}}, {'request': {'body': '', 'uri': 'foo', 'method': 'GET', 'headers': {}},
'response': 'bar'} 'response': 'bar'}
]})) ]}))
a_cassette = Cassette.load(str(a_file)) a_cassette = Cassette.load(path=str(a_file))
assert len(a_cassette) == 1 assert len(a_cassette) == 1
@@ -87,33 +88,35 @@ def make_get_request():
@mock.patch('vcr.cassette.Cassette.can_play_response_for', return_value=True) @mock.patch('vcr.cassette.Cassette.can_play_response_for', return_value=True)
@mock.patch('vcr.stubs.VCRHTTPResponse') @mock.patch('vcr.stubs.VCRHTTPResponse')
def test_function_decorated_with_use_cassette_can_be_invoked_multiple_times(*args): def test_function_decorated_with_use_cassette_can_be_invoked_multiple_times(*args):
decorated_function = Cassette.use('test')(make_get_request) decorated_function = Cassette.use(path='test')(make_get_request)
for i in range(2): for i in range(4):
decorated_function() decorated_function()
def test_arg_getter_functionality(): def test_arg_getter_functionality():
arg_getter = mock.Mock(return_value=('test', {})) arg_getter = mock.Mock(return_value={'path': 'test'})
context_decorator = Cassette.use_arg_getter(arg_getter) context_decorator = Cassette.use_arg_getter(arg_getter)
with context_decorator as cassette: with context_decorator as cassette:
assert cassette._path == 'test' assert cassette._path == 'test'
arg_getter.return_value = ('other', {}) arg_getter.return_value = {'path': 'other'}
with context_decorator as cassette: with context_decorator as cassette:
assert cassette._path == 'other' assert cassette._path == 'other'
arg_getter.return_value = ('', {'filter_headers': ('header_name',)}) arg_getter.return_value = {'path': 'other', 'filter_headers': ('header_name',)}
@context_decorator @context_decorator
def function(): def function():
pass pass
with mock.patch.object(Cassette, 'load', return_value=mock.MagicMock(inject=False)) as cassette_load: with mock.patch.object(
Cassette, 'load',
return_value=mock.MagicMock(inject=False)
) as cassette_load:
function() function()
cassette_load.assert_called_once_with(arg_getter.return_value[0], cassette_load.assert_called_once_with(**arg_getter.return_value)
**arg_getter.return_value[1])
def test_cassette_not_all_played(): def test_cassette_not_all_played():
@@ -156,13 +159,13 @@ def test_nesting_cassette_context_managers(*args):
second_response['body']['string'] = b'second_response' second_response['body']['string'] = b'second_response'
with contextlib2.ExitStack() as exit_stack: with contextlib2.ExitStack() as exit_stack:
first_cassette = exit_stack.enter_context(Cassette.use('test')) first_cassette = exit_stack.enter_context(Cassette.use(path='test'))
exit_stack.enter_context(mock.patch.object(first_cassette, 'play_response', exit_stack.enter_context(mock.patch.object(first_cassette, 'play_response',
return_value=first_response)) return_value=first_response))
assert_get_response_body_is('first_response') assert_get_response_body_is('first_response')
# Make sure a second cassette can supercede the first # Make sure a second cassette can supercede the first
with Cassette.use('test') as second_cassette: with Cassette.use(path='test') as second_cassette:
with mock.patch.object(second_cassette, 'play_response', return_value=second_response): with mock.patch.object(second_cassette, 'play_response', return_value=second_response):
assert_get_response_body_is('second_response') assert_get_response_body_is('second_response')
@@ -172,12 +175,12 @@ def test_nesting_cassette_context_managers(*args):
def test_nesting_context_managers_by_checking_references_of_http_connection(): def test_nesting_context_managers_by_checking_references_of_http_connection():
original = httplib.HTTPConnection original = httplib.HTTPConnection
with Cassette.use('test'): with Cassette.use(path='test'):
first_cassette_HTTPConnection = httplib.HTTPConnection first_cassette_HTTPConnection = httplib.HTTPConnection
with Cassette.use('test'): with Cassette.use(path='test'):
second_cassette_HTTPConnection = httplib.HTTPConnection second_cassette_HTTPConnection = httplib.HTTPConnection
assert second_cassette_HTTPConnection is not first_cassette_HTTPConnection assert second_cassette_HTTPConnection is not first_cassette_HTTPConnection
with Cassette.use('test'): with Cassette.use(path='test'):
assert httplib.HTTPConnection is not second_cassette_HTTPConnection assert httplib.HTTPConnection is not second_cassette_HTTPConnection
with force_reset(): with force_reset():
assert httplib.HTTPConnection is original assert httplib.HTTPConnection is original
@@ -188,12 +191,14 @@ def test_nesting_context_managers_by_checking_references_of_http_connection():
def test_custom_patchers(): def test_custom_patchers():
class Test(object): class Test(object):
attribute = None attribute = None
with Cassette.use('custom_patches', custom_patches=((Test, 'attribute', VCRHTTPSConnection),)): with Cassette.use(path='custom_patches',
custom_patches=((Test, 'attribute', VCRHTTPSConnection),)):
assert issubclass(Test.attribute, VCRHTTPSConnection) assert issubclass(Test.attribute, VCRHTTPSConnection)
assert VCRHTTPSConnection is not Test.attribute assert VCRHTTPSConnection is not Test.attribute
old_attribute = Test.attribute old_attribute = Test.attribute
with Cassette.use('custom_patches', custom_patches=((Test, 'attribute', VCRHTTPSConnection),)): with Cassette.use(path='custom_patches',
custom_patches=((Test, 'attribute', VCRHTTPSConnection),)):
assert issubclass(Test.attribute, VCRHTTPSConnection) assert issubclass(Test.attribute, VCRHTTPSConnection)
assert VCRHTTPSConnection is not Test.attribute assert VCRHTTPSConnection is not Test.attribute
assert Test.attribute is not old_attribute assert Test.attribute is not old_attribute
@@ -201,3 +206,51 @@ def test_custom_patchers():
assert issubclass(Test.attribute, VCRHTTPSConnection) assert issubclass(Test.attribute, VCRHTTPSConnection)
assert VCRHTTPSConnection is not Test.attribute assert VCRHTTPSConnection is not Test.attribute
assert Test.attribute is old_attribute assert Test.attribute is old_attribute
def test_decorated_functions_are_reentrant():
info = {"second": False}
original_conn = httplib.HTTPConnection
@Cassette.use(path='whatever', inject=True)
def test_function(cassette):
if info['second']:
assert httplib.HTTPConnection is not info['first_conn']
else:
info['first_conn'] = httplib.HTTPConnection
info['second'] = True
test_function()
assert httplib.HTTPConnection is info['first_conn']
test_function()
assert httplib.HTTPConnection is original_conn
def test_cassette_use_called_without_path_uses_function_to_generate_path():
@Cassette.use(inject=True)
def function_name(cassette):
assert cassette._path == 'function_name'
function_name()
def test_path_transformer_with_function_path():
path_transformer = lambda path: os.path.join('a', path)
@Cassette.use(inject=True, path_transformer=path_transformer)
def function_name(cassette):
assert cassette._path == os.path.join('a', 'function_name')
function_name()
def test_path_transformer_with_context_manager():
with Cassette.use(
path='b', path_transformer=lambda *args: 'a'
) as cassette:
assert cassette._path == 'a'
def test_func_path_generator():
def generator(function):
return os.path.join(os.path.dirname(inspect.getfile(function)),
function.__name__)
@Cassette.use(inject=True, func_path_generator=generator)
def function_name(cassette):
assert cassette._path == os.path.join(os.path.dirname(__file__), 'function_name')
function_name()

View File

@@ -4,6 +4,7 @@ from vcr.filters import (
remove_post_data_parameters remove_post_data_parameters
) )
from vcr.request import Request from vcr.request import Request
import json
def test_remove_headers(): def test_remove_headers():
@@ -67,3 +68,29 @@ def test_remove_nonexistent_post_data_parameters():
request = Request('POST', 'http://google.com', body, {}) request = Request('POST', 'http://google.com', body, {})
remove_post_data_parameters(request, ['id']) remove_post_data_parameters(request, ['id'])
assert request.body == b'' assert request.body == b''
def test_remove_json_post_data_parameters():
body = b'{"id": "secret", "foo": "bar", "baz": "qux"}'
request = Request('POST', 'http://google.com', body, {})
request.add_header('Content-Type', 'application/json')
remove_post_data_parameters(request, ['id'])
request_body_json = json.loads(request.body.decode('utf-8'))
expected_json = json.loads(b'{"foo": "bar", "baz": "qux"}'.decode('utf-8'))
assert request_body_json == expected_json
def test_remove_all_json_post_data_parameters():
body = b'{"id": "secret", "foo": "bar"}'
request = Request('POST', 'http://google.com', body, {})
request.add_header('Content-Type', 'application/json')
remove_post_data_parameters(request, ['id', 'foo'])
assert request.body == b'{}'
def test_remove_nonexistent_json_post_data_parameters():
body = b'{}'
request = Request('POST', 'http://google.com', body, {})
request.add_header('Content-Type', 'application/json')
remove_post_data_parameters(request, ['id'])
assert request.body == b'{}'

View File

@@ -1,3 +1,5 @@
import os
import mock import mock
import pytest import pytest
@@ -9,7 +11,10 @@ from vcr.stubs import VCRHTTPSConnection
def test_vcr_use_cassette(): def test_vcr_use_cassette():
record_mode = mock.Mock() record_mode = mock.Mock()
test_vcr = VCR(record_mode=record_mode) test_vcr = VCR(record_mode=record_mode)
with mock.patch('vcr.cassette.Cassette.load', return_value=mock.MagicMock(inject=False)) as mock_cassette_load: with mock.patch(
'vcr.cassette.Cassette.load',
return_value=mock.MagicMock(inject=False)
) as mock_cassette_load:
@test_vcr.use_cassette('test') @test_vcr.use_cassette('test')
def function(): def function():
pass pass
@@ -87,7 +92,10 @@ def test_custom_patchers():
assert issubclass(Test.attribute, VCRHTTPSConnection) assert issubclass(Test.attribute, VCRHTTPSConnection)
assert VCRHTTPSConnection is not Test.attribute assert VCRHTTPSConnection is not Test.attribute
with test_vcr.use_cassette('custom_patches', custom_patches=((Test, 'attribute2', VCRHTTPSConnection),)): with test_vcr.use_cassette(
'custom_patches',
custom_patches=((Test, 'attribute2', VCRHTTPSConnection),)
):
assert issubclass(Test.attribute, VCRHTTPSConnection) assert issubclass(Test.attribute, VCRHTTPSConnection)
assert VCRHTTPSConnection is not Test.attribute assert VCRHTTPSConnection is not Test.attribute
assert Test.attribute is Test.attribute2 assert Test.attribute is Test.attribute2
@@ -128,3 +136,57 @@ def test_with_current_defaults():
vcr.record_mode = 'all' vcr.record_mode = 'all'
changing_defaults(assert_record_mode_all) changing_defaults(assert_record_mode_all)
current_defaults(assert_record_mode_once) current_defaults(assert_record_mode_once)
def test_cassette_library_dir_with_decoration_and_no_explicit_path():
library_dir = '/libary_dir'
vcr = VCR(inject_cassette=True, cassette_library_dir=library_dir)
@vcr.use_cassette()
def function_name(cassette):
assert cassette._path == os.path.join(library_dir, 'function_name')
function_name()
def test_cassette_library_dir_with_path_transformer():
library_dir = '/libary_dir'
vcr = VCR(inject_cassette=True, cassette_library_dir=library_dir,
path_transformer=lambda path: path + '.json')
@vcr.use_cassette()
def function_name(cassette):
assert cassette._path == os.path.join(library_dir, 'function_name.json')
function_name()
def test_use_cassette_with_no_extra_invocation():
vcr = VCR(inject_cassette=True, cassette_library_dir='/')
@vcr.use_cassette
def function_name(cassette):
assert cassette._path == os.path.join('/', 'function_name')
function_name()
def test_path_transformer():
vcr = VCR(inject_cassette=True, cassette_library_dir='/',
path_transformer=lambda x: x + '_test')
@vcr.use_cassette
def function_name(cassette):
assert cassette._path == os.path.join('/', 'function_name_test')
function_name()
def test_cassette_name_generator_defaults_to_using_module_function_defined_in():
vcr = VCR(inject_cassette=True)
@vcr.use_cassette
def function_name(cassette):
assert cassette._path == os.path.join(os.path.dirname(__file__),
'function_name')
function_name()
def test_ensure_suffix():
vcr = VCR(inject_cassette=True, path_transformer=VCR.ensure_suffix('.yaml'))
@vcr.use_cassette
def function_name(cassette):
assert cassette._path == os.path.join(os.path.dirname(__file__),
'function_name.yaml')
function_name()

View File

@@ -1,4 +1,5 @@
"""The container for recorded requests and responses""" """The container for recorded requests and responses"""
import functools
import logging import logging
import contextlib2 import contextlib2
@@ -9,11 +10,12 @@ except ImportError:
from backport_collections import Counter from backport_collections import Counter
# Internal imports # Internal imports
from .errors import UnhandledHTTPRequestError
from .matchers import requests_match, uri, method
from .patch import CassettePatcherBuilder from .patch import CassettePatcherBuilder
from .persist import load_cassette, save_cassette from .persist import load_cassette, save_cassette
from .serializers import yamlserializer from .serializers import yamlserializer
from .matchers import requests_match, uri, method from .util import partition_dict
from .errors import UnhandledHTTPRequestError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -29,9 +31,11 @@ class CassetteContextDecorator(object):
from interfering with another. from interfering with another.
""" """
_non_cassette_arguments = ('path_transformer', 'func_path_generator')
@classmethod @classmethod
def from_args(cls, cassette_class, path, **kwargs): def from_args(cls, cassette_class, **kwargs):
return cls(cassette_class, lambda: (path, kwargs)) return cls(cassette_class, lambda: dict(kwargs))
def __init__(self, cls, args_getter): def __init__(self, cls, args_getter):
self.cls = cls self.cls = cls
@@ -49,10 +53,29 @@ class CassetteContextDecorator(object):
# somewhere else. # somewhere else.
cassette._save() cassette._save()
@classmethod
def key_predicate(cls, key, value):
return key in cls._non_cassette_arguments
@classmethod
def _split_keys(cls, kwargs):
return partition_dict(cls.key_predicate, kwargs)
def __enter__(self): def __enter__(self):
# This assertion is here to prevent the dangerous behavior
# that would result from forgetting about a __finish before
# completing it.
# How might this condition be met? Here is an example:
# context_decorator = Cassette.use('whatever')
# with context_decorator:
# with context_decorator:
# pass
assert self.__finish is None, "Cassette already open." assert self.__finish is None, "Cassette already open."
path, kwargs = self._args_getter() other_kwargs, cassette_kwargs = self._split_keys(self._args_getter())
self.__finish = self._patch_generator(self.cls.load(path, **kwargs)) if 'path_transformer' in other_kwargs:
transformer = other_kwargs['path_transformer']
cassette_kwargs['path'] = transformer(cassette_kwargs['path'])
self.__finish = self._patch_generator(self.cls.load(**cassette_kwargs))
return next(self.__finish) return next(self.__finish)
def __exit__(self, *args): def __exit__(self, *args):
@@ -61,20 +84,43 @@ class CassetteContextDecorator(object):
@wrapt.decorator @wrapt.decorator
def __call__(self, function, instance, args, kwargs): def __call__(self, function, instance, args, kwargs):
with self as cassette: # This awkward cloning thing is done to ensure that decorated
# functions are reentrant. This is required for thread
# safety and the correct operation of recursive functions.
args_getter = self._build_args_getter_for_decorator(
function, self._args_getter
)
clone = type(self)(self.cls, args_getter)
with clone as cassette:
if cassette.inject: if cassette.inject:
return function(cassette, *args, **kwargs) return function(cassette, *args, **kwargs)
else: else:
return function(*args, **kwargs) return function(*args, **kwargs)
@staticmethod
def get_function_name(function):
return function.__name__
@classmethod
def _build_args_getter_for_decorator(cls, function, args_getter):
def new_args_getter():
kwargs = args_getter()
if 'path' not in kwargs:
name_generator = (kwargs.get('func_path_generator') or
cls.get_function_name)
path = name_generator(function)
kwargs['path'] = path
return kwargs
return new_args_getter
class Cassette(object): class Cassette(object):
"""A container for recorded requests and responses""" """A container for recorded requests and responses"""
@classmethod @classmethod
def load(cls, path, **kwargs): def load(cls, **kwargs):
"""Instantiate and load the cassette stored at the specified path.""" """Instantiate and load the cassette stored at the specified path."""
new_cassette = cls(path, **kwargs) new_cassette = cls(**kwargs)
new_cassette._load() new_cassette._load()
return new_cassette return new_cassette
@@ -83,8 +129,8 @@ class Cassette(object):
return CassetteContextDecorator(cls, arg_getter) return CassetteContextDecorator(cls, arg_getter)
@classmethod @classmethod
def use(cls, *args, **kwargs): def use(cls, **kwargs):
return CassetteContextDecorator.from_args(cls, *args, **kwargs) return CassetteContextDecorator.from_args(cls, **kwargs)
def __init__(self, path, serializer=yamlserializer, record_mode='once', def __init__(self, path, serializer=yamlserializer, record_mode='once',
match_on=(uri, method), before_record_request=None, match_on=(uri, method), before_record_request=None,

View File

@@ -1,23 +1,35 @@
import collections import collections
import copy import copy
import functools import functools
import inspect
import os import os
import six
from .cassette import Cassette from .cassette import Cassette
from .serializers import yamlserializer, jsonserializer from .serializers import yamlserializer, jsonserializer
from .util import compose
from . import matchers from . import matchers
from . import filters from . import filters
class VCR(object): class VCR(object):
def __init__(self, serializer='yaml', cassette_library_dir=None, @staticmethod
record_mode="once", filter_headers=(), ignore_localhost=False, def ensure_suffix(suffix):
custom_patches=(), filter_query_parameters=(), def ensure(path):
filter_post_data_parameters=(), before_record_request=None, if not path.endswith(suffix):
before_record_response=None, ignore_hosts=(), return path + suffix
return path
return ensure
def __init__(self, path_transformer=lambda x: x, before_record_request=None,
custom_patches=(), filter_query_parameters=(), ignore_hosts=(),
record_mode="once", ignore_localhost=False, filter_headers=(),
before_record_response=None, filter_post_data_parameters=(),
match_on=('method', 'scheme', 'host', 'port', 'path', 'query'), match_on=('method', 'scheme', 'host', 'port', 'path', 'query'),
before_record=None, inject_cassette=False): before_record=None, inject_cassette=False, serializer='yaml',
cassette_library_dir=None, func_path_generator=None):
self.serializer = serializer self.serializer = serializer
self.match_on = match_on self.match_on = match_on
self.cassette_library_dir = cassette_library_dir self.cassette_library_dir = cassette_library_dir
@@ -46,6 +58,8 @@ class VCR(object):
self.ignore_hosts = ignore_hosts self.ignore_hosts = ignore_hosts
self.ignore_localhost = ignore_localhost self.ignore_localhost = ignore_localhost
self.inject_cassette = inject_cassette self.inject_cassette = inject_cassette
self.path_transformer = path_transformer
self.func_path_generator = func_path_generator
self._custom_patches = tuple(custom_patches) self._custom_patches = tuple(custom_patches)
def _get_serializer(self, serializer_name): def _get_serializer(self, serializer_name):
@@ -69,27 +83,48 @@ class VCR(object):
) )
return matchers return matchers
def use_cassette(self, path, with_current_defaults=False, **kwargs): def use_cassette(self, path=None, **kwargs):
if path is not None and not isinstance(path, six.string_types):
function = path
# Assume this is an attempt to decorate a function
return self._use_cassette(**kwargs)(function)
return self._use_cassette(path=path, **kwargs)
def _use_cassette(self, with_current_defaults=False, **kwargs):
if with_current_defaults: if with_current_defaults:
path, config = self.get_path_and_merged_config(path, **kwargs) config = self.get_merged_config(**kwargs)
return Cassette.use(path, **config) return Cassette.use(**config)
# This is made a function that evaluates every time a cassette # This is made a function that evaluates every time a cassette
# is made so that changes that are made to this VCR instance # is made so that changes that are made to this VCR instance
# that occur AFTER the `use_cassette` decorator is applied # that occur AFTER the `use_cassette` decorator is applied
# still affect subsequent calls to the decorated function. # still affect subsequent calls to the decorated function.
args_getter = functools.partial(self.get_path_and_merged_config, args_getter = functools.partial(self.get_merged_config, **kwargs)
path, **kwargs)
return Cassette.use_arg_getter(args_getter) return Cassette.use_arg_getter(args_getter)
def get_path_and_merged_config(self, path, **kwargs): def get_merged_config(self, **kwargs):
serializer_name = kwargs.get('serializer', self.serializer) serializer_name = kwargs.get('serializer', self.serializer)
matcher_names = kwargs.get('match_on', self.match_on) matcher_names = kwargs.get('match_on', self.match_on)
path_transformer = kwargs.get(
'path_transformer',
self.path_transformer
)
func_path_generator = kwargs.get(
'func_path_generator',
self.func_path_generator
)
cassette_library_dir = kwargs.get( cassette_library_dir = kwargs.get(
'cassette_library_dir', 'cassette_library_dir',
self.cassette_library_dir self.cassette_library_dir
) )
if cassette_library_dir: if cassette_library_dir:
path = os.path.join(cassette_library_dir, path) def add_cassette_library_dir(path):
if not path.startswith(cassette_library_dir):
return os.path.join(cassette_library_dir, path)
path_transformer = compose(add_cassette_library_dir, path_transformer)
elif not func_path_generator:
# If we don't have a library dir, use the functions
# location to build a full path for cassettes.
func_path_generator = self._build_path_from_func_using_module
merged_config = { merged_config = {
'serializer': self._get_serializer(serializer_name), 'serializer': self._get_serializer(serializer_name),
@@ -102,9 +137,14 @@ class VCR(object):
'custom_patches': self._custom_patches + kwargs.get( 'custom_patches': self._custom_patches + kwargs.get(
'custom_patches', () 'custom_patches', ()
), ),
'inject': kwargs.get('inject_cassette', self.inject_cassette) 'inject': kwargs.get('inject_cassette', self.inject_cassette),
'path_transformer': path_transformer,
'func_path_generator': func_path_generator
} }
return path, merged_config path = kwargs.get('path')
if path:
merged_config['path'] = path
return merged_config
def _build_before_record_response(self, options): def _build_before_record_response(self, options):
before_record_response = options.get( before_record_response = options.get(
@@ -185,6 +225,11 @@ class VCR(object):
return request return request
return filter_ignored_hosts return filter_ignored_hosts
@staticmethod
def _build_path_from_func_using_module(function):
return os.path.join(os.path.dirname(inspect.getfile(function)),
function.__name__)
def register_serializer(self, name, serializer): def register_serializer(self, name, serializer):
self.serializers[name] = serializer self.serializers[name] = serializer

View File

@@ -5,6 +5,7 @@ try:
except ImportError: except ImportError:
from backport_collections import OrderedDict from backport_collections import OrderedDict
import copy import copy
import json
def remove_headers(request, headers_to_remove): def remove_headers(request, headers_to_remove):
@@ -31,13 +32,21 @@ def remove_query_parameters(request, query_parameters_to_remove):
def remove_post_data_parameters(request, post_data_parameters_to_remove): def remove_post_data_parameters(request, post_data_parameters_to_remove):
if request.method == 'POST' and not isinstance(request.body, BytesIO): if request.method == 'POST' and not isinstance(request.body, BytesIO):
post_data = OrderedDict() if ('Content-Type' in request.headers and
for k, sep, v in [p.partition(b'=') for p in request.body.split(b'&')]: request.headers['Content-Type'] == 'application/json'):
if k in post_data: json_data = json.loads(request.body.decode('utf-8'))
post_data[k].append(v) for k in list(json_data.keys()):
elif len(k) > 0 and k.decode('utf-8') not in post_data_parameters_to_remove: if k in post_data_parameters_to_remove:
post_data[k] = [v] del json_data[k]
request.body = b'&'.join( request.body = json.dumps(json_data).encode('utf-8')
b'='.join([k, v]) else:
for k, vals in post_data.items() for v in vals) post_data = OrderedDict()
for k, sep, v in [p.partition(b'=') for p in request.body.split(b'&')]:
if k in post_data:
post_data[k].append(v)
elif len(k) > 0 and k.decode('utf-8') not in post_data_parameters_to_remove:
post_data[k] = [v]
request.body = b'&'.join(
b'='.join([k, v])
for k, vals in post_data.items() for v in vals)
return request return request

16
vcr/util.py Normal file
View File

@@ -0,0 +1,16 @@
def partition_dict(predicate, dictionary):
true_dict = {}
false_dict = {}
for key, value in dictionary.items():
this_dict = true_dict if predicate(key, value) else false_dict
this_dict[key] = value
return true_dict, false_dict
def compose(*functions):
def composed(incoming):
res = incoming
for function in functions[::-1]:
res = function(res)
return res
return composed