1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-09 01:03:24 +00:00

Automatically generate cassette names from function names. Add

`path_transformer` and `func_path_generator`. Closes #151.
This commit is contained in:
Ivan Malison
2015-05-10 03:22:43 -07:00
parent 0bbbc694b0
commit 2323b9da5f
6 changed files with 279 additions and 48 deletions

View File

@@ -389,6 +389,43 @@ my_vcr = config.VCR(custom_patches=((where_the_custom_https_connection_lives, 'C
@my_vcr.use_cassette(...)
```
## Automatic Cassette Naming
VCR.py now allows the omission of the path argument to the
use_cassette function. Both of the following are now legal/should work
``` python
@my_vcr.use_cassette
def my_test_function():
...
```
``` python
@my_vcr.use_cassette()
def my_test_function():
...
```
In both cases, VCR.py will use a path that is generated from the
provided test function's name. If no `cassette_library_dir` has been
set, the cassette will be in a file with the name of the test function
in directory of the file in which the test function is declared. If a
`cassette_library_dir` is set, has been set, the cassette will appear
in that directory in a file with the name of the decorated function.
It is possible to control the path produced by the automatic naming
machinery by customizing the `path_transformer` and
`func_path_generator` vcr variables. To add an extension to all
cassette names, use `VCR.ensure_suffix` as follows:
``` python
my_vcr = VCR(path_transformer=VCR.ensure_suffix('.yaml'))
@my_vcr.use_cassette
def my_test_function():
```
## Installation
VCR.py is a package on PyPI, so you can `pip install vcrpy` (first you may need

View File

@@ -1,4 +1,6 @@
import copy
import inspect
import os
from six.moves import http_client as httplib
import contextlib2
@@ -12,14 +14,13 @@ from vcr.patch import force_reset
from vcr.stubs import VCRHTTPSConnection
def test_cassette_load(tmpdir):
a_file = tmpdir.join('test_cassette.yml')
a_file.write(yaml.dump({'interactions': [
{'request': {'body': '', 'uri': 'foo', 'method': 'GET', 'headers': {}},
'response': 'bar'}
]}))
a_cassette = Cassette.load(str(a_file))
a_cassette = Cassette.load(path=str(a_file))
assert len(a_cassette) == 1
@@ -87,33 +88,35 @@ def make_get_request():
@mock.patch('vcr.cassette.Cassette.can_play_response_for', return_value=True)
@mock.patch('vcr.stubs.VCRHTTPResponse')
def test_function_decorated_with_use_cassette_can_be_invoked_multiple_times(*args):
decorated_function = Cassette.use('test')(make_get_request)
for i in range(2):
decorated_function = Cassette.use(path='test')(make_get_request)
for i in range(4):
decorated_function()
def test_arg_getter_functionality():
arg_getter = mock.Mock(return_value=('test', {}))
arg_getter = mock.Mock(return_value={'path': 'test'})
context_decorator = Cassette.use_arg_getter(arg_getter)
with context_decorator as cassette:
assert cassette._path == 'test'
arg_getter.return_value = ('other', {})
arg_getter.return_value = {'path': 'other'}
with context_decorator as cassette:
assert cassette._path == 'other'
arg_getter.return_value = ('', {'filter_headers': ('header_name',)})
arg_getter.return_value = {'path': 'other', 'filter_headers': ('header_name',)}
@context_decorator
def function():
pass
with mock.patch.object(Cassette, 'load', return_value=mock.MagicMock(inject=False)) as cassette_load:
with mock.patch.object(
Cassette, 'load',
return_value=mock.MagicMock(inject=False)
) as cassette_load:
function()
cassette_load.assert_called_once_with(arg_getter.return_value[0],
**arg_getter.return_value[1])
cassette_load.assert_called_once_with(**arg_getter.return_value)
def test_cassette_not_all_played():
@@ -156,13 +159,13 @@ def test_nesting_cassette_context_managers(*args):
second_response['body']['string'] = b'second_response'
with contextlib2.ExitStack() as exit_stack:
first_cassette = exit_stack.enter_context(Cassette.use('test'))
first_cassette = exit_stack.enter_context(Cassette.use(path='test'))
exit_stack.enter_context(mock.patch.object(first_cassette, 'play_response',
return_value=first_response))
assert_get_response_body_is('first_response')
# Make sure a second cassette can supercede the first
with Cassette.use('test') as second_cassette:
with Cassette.use(path='test') as second_cassette:
with mock.patch.object(second_cassette, 'play_response', return_value=second_response):
assert_get_response_body_is('second_response')
@@ -172,12 +175,12 @@ def test_nesting_cassette_context_managers(*args):
def test_nesting_context_managers_by_checking_references_of_http_connection():
original = httplib.HTTPConnection
with Cassette.use('test'):
with Cassette.use(path='test'):
first_cassette_HTTPConnection = httplib.HTTPConnection
with Cassette.use('test'):
with Cassette.use(path='test'):
second_cassette_HTTPConnection = httplib.HTTPConnection
assert second_cassette_HTTPConnection is not first_cassette_HTTPConnection
with Cassette.use('test'):
with Cassette.use(path='test'):
assert httplib.HTTPConnection is not second_cassette_HTTPConnection
with force_reset():
assert httplib.HTTPConnection is original
@@ -188,12 +191,14 @@ def test_nesting_context_managers_by_checking_references_of_http_connection():
def test_custom_patchers():
class Test(object):
attribute = None
with Cassette.use('custom_patches', custom_patches=((Test, 'attribute', VCRHTTPSConnection),)):
with Cassette.use(path='custom_patches',
custom_patches=((Test, 'attribute', VCRHTTPSConnection),)):
assert issubclass(Test.attribute, VCRHTTPSConnection)
assert VCRHTTPSConnection is not Test.attribute
old_attribute = Test.attribute
with Cassette.use('custom_patches', custom_patches=((Test, 'attribute', VCRHTTPSConnection),)):
with Cassette.use(path='custom_patches',
custom_patches=((Test, 'attribute', VCRHTTPSConnection),)):
assert issubclass(Test.attribute, VCRHTTPSConnection)
assert VCRHTTPSConnection is not Test.attribute
assert Test.attribute is not old_attribute
@@ -203,10 +208,10 @@ def test_custom_patchers():
assert Test.attribute is old_attribute
def test_use_cassette_decorated_functions_are_reentrant():
def test_decorated_functions_are_reentrant():
info = {"second": False}
original_conn = httplib.HTTPConnection
@Cassette.use('whatever', inject=True)
@Cassette.use(path='whatever', inject=True)
def test_function(cassette):
if info['second']:
assert httplib.HTTPConnection is not info['first_conn']
@@ -217,3 +222,35 @@ def test_use_cassette_decorated_functions_are_reentrant():
assert httplib.HTTPConnection is info['first_conn']
test_function()
assert httplib.HTTPConnection is original_conn
def test_cassette_use_called_without_path_uses_function_to_generate_path():
@Cassette.use(inject=True)
def function_name(cassette):
assert cassette._path == 'function_name'
function_name()
def test_path_transformer_with_function_path():
path_transformer = lambda path: os.path.join('a', path)
@Cassette.use(inject=True, path_transformer=path_transformer)
def function_name(cassette):
assert cassette._path == os.path.join('a', 'function_name')
function_name()
def test_path_transformer_with_context_manager():
with Cassette.use(
path='b', path_transformer=lambda *args: 'a'
) as cassette:
assert cassette._path == 'a'
def test_func_path_generator():
def generator(function):
return os.path.join(os.path.dirname(inspect.getfile(function)),
function.__name__)
@Cassette.use(inject=True, func_path_generator=generator)
def function_name(cassette):
assert cassette._path == os.path.join(os.path.dirname(__file__), 'function_name')
function_name()

View File

@@ -1,3 +1,5 @@
import os
import mock
import pytest
@@ -9,7 +11,10 @@ from vcr.stubs import VCRHTTPSConnection
def test_vcr_use_cassette():
record_mode = mock.Mock()
test_vcr = VCR(record_mode=record_mode)
with mock.patch('vcr.cassette.Cassette.load', return_value=mock.MagicMock(inject=False)) as mock_cassette_load:
with mock.patch(
'vcr.cassette.Cassette.load',
return_value=mock.MagicMock(inject=False)
) as mock_cassette_load:
@test_vcr.use_cassette('test')
def function():
pass
@@ -87,7 +92,10 @@ def test_custom_patchers():
assert issubclass(Test.attribute, VCRHTTPSConnection)
assert VCRHTTPSConnection is not Test.attribute
with test_vcr.use_cassette('custom_patches', custom_patches=((Test, 'attribute2', VCRHTTPSConnection),)):
with test_vcr.use_cassette(
'custom_patches',
custom_patches=((Test, 'attribute2', VCRHTTPSConnection),)
):
assert issubclass(Test.attribute, VCRHTTPSConnection)
assert VCRHTTPSConnection is not Test.attribute
assert Test.attribute is Test.attribute2
@@ -128,3 +136,57 @@ def test_with_current_defaults():
vcr.record_mode = 'all'
changing_defaults(assert_record_mode_all)
current_defaults(assert_record_mode_once)
def test_cassette_library_dir_with_decoration_and_no_explicit_path():
library_dir = '/libary_dir'
vcr = VCR(inject_cassette=True, cassette_library_dir=library_dir)
@vcr.use_cassette()
def function_name(cassette):
assert cassette._path == os.path.join(library_dir, 'function_name')
function_name()
def test_cassette_library_dir_with_path_transformer():
library_dir = '/libary_dir'
vcr = VCR(inject_cassette=True, cassette_library_dir=library_dir,
path_transformer=lambda path: path + '.json')
@vcr.use_cassette()
def function_name(cassette):
assert cassette._path == os.path.join(library_dir, 'function_name.json')
function_name()
def test_use_cassette_with_no_extra_invocation():
vcr = VCR(inject_cassette=True, cassette_library_dir='/')
@vcr.use_cassette
def function_name(cassette):
assert cassette._path == os.path.join('/', 'function_name')
function_name()
def test_path_transformer():
vcr = VCR(inject_cassette=True, cassette_library_dir='/',
path_transformer=lambda x: x + '_test')
@vcr.use_cassette
def function_name(cassette):
assert cassette._path == os.path.join('/', 'function_name_test')
function_name()
def test_cassette_name_generator_defaults_to_using_module_function_defined_in():
vcr = VCR(inject_cassette=True)
@vcr.use_cassette
def function_name(cassette):
assert cassette._path == os.path.join(os.path.dirname(__file__),
'function_name')
function_name()
def test_ensure_suffix():
vcr = VCR(inject_cassette=True, path_transformer=VCR.ensure_suffix('.yaml'))
@vcr.use_cassette
def function_name(cassette):
assert cassette._path == os.path.join(os.path.dirname(__file__),
'function_name.yaml')
function_name()

View File

@@ -1,4 +1,5 @@
"""The container for recorded requests and responses"""
import functools
import logging
import contextlib2
@@ -9,11 +10,12 @@ except ImportError:
from backport_collections import Counter
# Internal imports
from .errors import UnhandledHTTPRequestError
from .matchers import requests_match, uri, method
from .patch import CassettePatcherBuilder
from .persist import load_cassette, save_cassette
from .serializers import yamlserializer
from .matchers import requests_match, uri, method
from .errors import UnhandledHTTPRequestError
from .util import partition_dict
log = logging.getLogger(__name__)
@@ -29,9 +31,11 @@ class CassetteContextDecorator(object):
from interfering with another.
"""
_non_cassette_arguments = ('path_transformer', 'func_path_generator')
@classmethod
def from_args(cls, cassette_class, path, **kwargs):
return cls(cassette_class, lambda: (path, kwargs))
def from_args(cls, cassette_class, **kwargs):
return cls(cassette_class, lambda: dict(kwargs))
def __init__(self, cls, args_getter):
self.cls = cls
@@ -49,6 +53,14 @@ class CassetteContextDecorator(object):
# somewhere else.
cassette._save()
@classmethod
def key_predicate(cls, key, value):
return key in cls._non_cassette_arguments
@classmethod
def _split_keys(cls, kwargs):
return partition_dict(cls.key_predicate, kwargs)
def __enter__(self):
# This assertion is here to prevent the dangerous behavior
# that would result from forgetting about a __finish before
@@ -59,8 +71,11 @@ class CassetteContextDecorator(object):
# with context_decorator:
# pass
assert self.__finish is None, "Cassette already open."
path, kwargs = self._args_getter()
self.__finish = self._patch_generator(self.cls.load(path, **kwargs))
other_kwargs, cassette_kwargs = self._split_keys(self._args_getter())
if 'path_transformer' in other_kwargs:
transformer = other_kwargs['path_transformer']
cassette_kwargs['path'] = transformer(cassette_kwargs['path'])
self.__finish = self._patch_generator(self.cls.load(**cassette_kwargs))
return next(self.__finish)
def __exit__(self, *args):
@@ -70,23 +85,42 @@ class CassetteContextDecorator(object):
@wrapt.decorator
def __call__(self, function, instance, args, kwargs):
# This awkward cloning thing is done to ensure that decorated
# functions are reentrant. Reentrancy is required for thread
# functions are reentrant. This is required for thread
# safety and the correct operation of recursive functions.
clone = type(self)(self.cls, self._args_getter)
args_getter = self._build_args_getter_for_decorator(
function, self._args_getter
)
clone = type(self)(self.cls, args_getter)
with clone as cassette:
if cassette.inject:
return function(cassette, *args, **kwargs)
else:
return function(*args, **kwargs)
@staticmethod
def get_function_name(function):
return function.__name__
@classmethod
def _build_args_getter_for_decorator(cls, function, args_getter):
def new_args_getter():
kwargs = args_getter()
if 'path' not in kwargs:
name_generator = (kwargs.get('func_path_generator') or
cls.get_function_name)
path = name_generator(function)
kwargs['path'] = path
return kwargs
return new_args_getter
class Cassette(object):
"""A container for recorded requests and responses"""
@classmethod
def load(cls, path, **kwargs):
def load(cls, **kwargs):
"""Instantiate and load the cassette stored at the specified path."""
new_cassette = cls(path, **kwargs)
new_cassette = cls(**kwargs)
new_cassette._load()
return new_cassette
@@ -95,8 +129,8 @@ class Cassette(object):
return CassetteContextDecorator(cls, arg_getter)
@classmethod
def use(cls, *args, **kwargs):
return CassetteContextDecorator.from_args(cls, *args, **kwargs)
def use(cls, **kwargs):
return CassetteContextDecorator.from_args(cls, **kwargs)
def __init__(self, path, serializer=yamlserializer, record_mode='once',
match_on=(uri, method), before_record_request=None,

View File

@@ -1,23 +1,35 @@
import collections
import copy
import functools
import inspect
import os
import six
from .cassette import Cassette
from .serializers import yamlserializer, jsonserializer
from .util import compose
from . import matchers
from . import filters
class VCR(object):
def __init__(self, serializer='yaml', cassette_library_dir=None,
record_mode="once", filter_headers=(), ignore_localhost=False,
custom_patches=(), filter_query_parameters=(),
filter_post_data_parameters=(), before_record_request=None,
before_record_response=None, ignore_hosts=(),
@staticmethod
def ensure_suffix(suffix):
def ensure(path):
if not path.endswith(suffix):
return path + suffix
return path
return ensure
def __init__(self, path_transformer=lambda x: x, before_record_request=None,
custom_patches=(), filter_query_parameters=(), ignore_hosts=(),
record_mode="once", ignore_localhost=False, filter_headers=(),
before_record_response=None, filter_post_data_parameters=(),
match_on=('method', 'scheme', 'host', 'port', 'path', 'query'),
before_record=None, inject_cassette=False):
before_record=None, inject_cassette=False, serializer='yaml',
cassette_library_dir=None, func_path_generator=None):
self.serializer = serializer
self.match_on = match_on
self.cassette_library_dir = cassette_library_dir
@@ -46,6 +58,8 @@ class VCR(object):
self.ignore_hosts = ignore_hosts
self.ignore_localhost = ignore_localhost
self.inject_cassette = inject_cassette
self.path_transformer = path_transformer
self.func_path_generator = func_path_generator
self._custom_patches = tuple(custom_patches)
def _get_serializer(self, serializer_name):
@@ -69,27 +83,48 @@ class VCR(object):
)
return matchers
def use_cassette(self, path, with_current_defaults=False, **kwargs):
def use_cassette(self, path=None, **kwargs):
if path is not None and not isinstance(path, six.string_types):
function = path
# Assume this is an attempt to decorate a function
return self._use_cassette(**kwargs)(function)
return self._use_cassette(path=path, **kwargs)
def _use_cassette(self, with_current_defaults=False, **kwargs):
if with_current_defaults:
path, config = self.get_path_and_merged_config(path, **kwargs)
return Cassette.use(path, **config)
config = self.get_merged_config(**kwargs)
return Cassette.use(**config)
# This is made a function that evaluates every time a cassette
# is made so that changes that are made to this VCR instance
# that occur AFTER the `use_cassette` decorator is applied
# still affect subsequent calls to the decorated function.
args_getter = functools.partial(self.get_path_and_merged_config,
path, **kwargs)
args_getter = functools.partial(self.get_merged_config, **kwargs)
return Cassette.use_arg_getter(args_getter)
def get_path_and_merged_config(self, path, **kwargs):
def get_merged_config(self, **kwargs):
serializer_name = kwargs.get('serializer', self.serializer)
matcher_names = kwargs.get('match_on', self.match_on)
path_transformer = kwargs.get(
'path_transformer',
self.path_transformer
)
func_path_generator = kwargs.get(
'func_path_generator',
self.func_path_generator
)
cassette_library_dir = kwargs.get(
'cassette_library_dir',
self.cassette_library_dir
)
if cassette_library_dir:
path = os.path.join(cassette_library_dir, path)
def add_cassette_library_dir(path):
if not path.startswith(cassette_library_dir):
return os.path.join(cassette_library_dir, path)
path_transformer = compose(add_cassette_library_dir, path_transformer)
elif not func_path_generator:
# If we don't have a library dir, use the functions
# location to build a full path for cassettes.
func_path_generator = self._build_path_from_func_using_module
merged_config = {
'serializer': self._get_serializer(serializer_name),
@@ -102,9 +137,14 @@ class VCR(object):
'custom_patches': self._custom_patches + kwargs.get(
'custom_patches', ()
),
'inject': kwargs.get('inject_cassette', self.inject_cassette)
'inject': kwargs.get('inject_cassette', self.inject_cassette),
'path_transformer': path_transformer,
'func_path_generator': func_path_generator
}
return path, merged_config
path = kwargs.get('path')
if path:
merged_config['path'] = path
return merged_config
def _build_before_record_response(self, options):
before_record_response = options.get(
@@ -185,6 +225,11 @@ class VCR(object):
return request
return filter_ignored_hosts
@staticmethod
def _build_path_from_func_using_module(function):
return os.path.join(os.path.dirname(inspect.getfile(function)),
function.__name__)
def register_serializer(self, name, serializer):
self.serializers[name] = serializer

16
vcr/util.py Normal file
View File

@@ -0,0 +1,16 @@
def partition_dict(predicate, dictionary):
true_dict = {}
false_dict = {}
for key, value in dictionary.items():
this_dict = true_dict if predicate(key, value) else false_dict
this_dict[key] = value
return true_dict, false_dict
def compose(*functions):
def composed(incoming):
res = incoming
for function in functions[::-1]:
res = function(res)
return res
return composed