1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-09 01:03:24 +00:00

Automatically generate cassette names from function names. Add

`path_transformer` and `func_path_generator`. Closes #151.
This commit is contained in:
Ivan Malison
2015-05-10 03:22:43 -07:00
parent 0bbbc694b0
commit 2323b9da5f
6 changed files with 279 additions and 48 deletions

View File

@@ -389,6 +389,43 @@ my_vcr = config.VCR(custom_patches=((where_the_custom_https_connection_lives, 'C
@my_vcr.use_cassette(...) @my_vcr.use_cassette(...)
``` ```
## Automatic Cassette Naming
VCR.py now allows the omission of the path argument to the
use_cassette function. Both of the following are now legal/should work
``` python
@my_vcr.use_cassette
def my_test_function():
...
```
``` python
@my_vcr.use_cassette()
def my_test_function():
...
```
In both cases, VCR.py will use a path that is generated from the
provided test function's name. If no `cassette_library_dir` has been
set, the cassette will be in a file with the name of the test function
in directory of the file in which the test function is declared. If a
`cassette_library_dir` is set, has been set, the cassette will appear
in that directory in a file with the name of the decorated function.
It is possible to control the path produced by the automatic naming
machinery by customizing the `path_transformer` and
`func_path_generator` vcr variables. To add an extension to all
cassette names, use `VCR.ensure_suffix` as follows:
``` python
my_vcr = VCR(path_transformer=VCR.ensure_suffix('.yaml'))
@my_vcr.use_cassette
def my_test_function():
```
## Installation ## Installation
VCR.py is a package on PyPI, so you can `pip install vcrpy` (first you may need VCR.py is a package on PyPI, so you can `pip install vcrpy` (first you may need

View File

@@ -1,4 +1,6 @@
import copy import copy
import inspect
import os
from six.moves import http_client as httplib from six.moves import http_client as httplib
import contextlib2 import contextlib2
@@ -12,14 +14,13 @@ from vcr.patch import force_reset
from vcr.stubs import VCRHTTPSConnection from vcr.stubs import VCRHTTPSConnection
def test_cassette_load(tmpdir): def test_cassette_load(tmpdir):
a_file = tmpdir.join('test_cassette.yml') a_file = tmpdir.join('test_cassette.yml')
a_file.write(yaml.dump({'interactions': [ a_file.write(yaml.dump({'interactions': [
{'request': {'body': '', 'uri': 'foo', 'method': 'GET', 'headers': {}}, {'request': {'body': '', 'uri': 'foo', 'method': 'GET', 'headers': {}},
'response': 'bar'} 'response': 'bar'}
]})) ]}))
a_cassette = Cassette.load(str(a_file)) a_cassette = Cassette.load(path=str(a_file))
assert len(a_cassette) == 1 assert len(a_cassette) == 1
@@ -87,33 +88,35 @@ def make_get_request():
@mock.patch('vcr.cassette.Cassette.can_play_response_for', return_value=True) @mock.patch('vcr.cassette.Cassette.can_play_response_for', return_value=True)
@mock.patch('vcr.stubs.VCRHTTPResponse') @mock.patch('vcr.stubs.VCRHTTPResponse')
def test_function_decorated_with_use_cassette_can_be_invoked_multiple_times(*args): def test_function_decorated_with_use_cassette_can_be_invoked_multiple_times(*args):
decorated_function = Cassette.use('test')(make_get_request) decorated_function = Cassette.use(path='test')(make_get_request)
for i in range(2): for i in range(4):
decorated_function() decorated_function()
def test_arg_getter_functionality(): def test_arg_getter_functionality():
arg_getter = mock.Mock(return_value=('test', {})) arg_getter = mock.Mock(return_value={'path': 'test'})
context_decorator = Cassette.use_arg_getter(arg_getter) context_decorator = Cassette.use_arg_getter(arg_getter)
with context_decorator as cassette: with context_decorator as cassette:
assert cassette._path == 'test' assert cassette._path == 'test'
arg_getter.return_value = ('other', {}) arg_getter.return_value = {'path': 'other'}
with context_decorator as cassette: with context_decorator as cassette:
assert cassette._path == 'other' assert cassette._path == 'other'
arg_getter.return_value = ('', {'filter_headers': ('header_name',)}) arg_getter.return_value = {'path': 'other', 'filter_headers': ('header_name',)}
@context_decorator @context_decorator
def function(): def function():
pass pass
with mock.patch.object(Cassette, 'load', return_value=mock.MagicMock(inject=False)) as cassette_load: with mock.patch.object(
Cassette, 'load',
return_value=mock.MagicMock(inject=False)
) as cassette_load:
function() function()
cassette_load.assert_called_once_with(arg_getter.return_value[0], cassette_load.assert_called_once_with(**arg_getter.return_value)
**arg_getter.return_value[1])
def test_cassette_not_all_played(): def test_cassette_not_all_played():
@@ -156,13 +159,13 @@ def test_nesting_cassette_context_managers(*args):
second_response['body']['string'] = b'second_response' second_response['body']['string'] = b'second_response'
with contextlib2.ExitStack() as exit_stack: with contextlib2.ExitStack() as exit_stack:
first_cassette = exit_stack.enter_context(Cassette.use('test')) first_cassette = exit_stack.enter_context(Cassette.use(path='test'))
exit_stack.enter_context(mock.patch.object(first_cassette, 'play_response', exit_stack.enter_context(mock.patch.object(first_cassette, 'play_response',
return_value=first_response)) return_value=first_response))
assert_get_response_body_is('first_response') assert_get_response_body_is('first_response')
# Make sure a second cassette can supercede the first # Make sure a second cassette can supercede the first
with Cassette.use('test') as second_cassette: with Cassette.use(path='test') as second_cassette:
with mock.patch.object(second_cassette, 'play_response', return_value=second_response): with mock.patch.object(second_cassette, 'play_response', return_value=second_response):
assert_get_response_body_is('second_response') assert_get_response_body_is('second_response')
@@ -172,12 +175,12 @@ def test_nesting_cassette_context_managers(*args):
def test_nesting_context_managers_by_checking_references_of_http_connection(): def test_nesting_context_managers_by_checking_references_of_http_connection():
original = httplib.HTTPConnection original = httplib.HTTPConnection
with Cassette.use('test'): with Cassette.use(path='test'):
first_cassette_HTTPConnection = httplib.HTTPConnection first_cassette_HTTPConnection = httplib.HTTPConnection
with Cassette.use('test'): with Cassette.use(path='test'):
second_cassette_HTTPConnection = httplib.HTTPConnection second_cassette_HTTPConnection = httplib.HTTPConnection
assert second_cassette_HTTPConnection is not first_cassette_HTTPConnection assert second_cassette_HTTPConnection is not first_cassette_HTTPConnection
with Cassette.use('test'): with Cassette.use(path='test'):
assert httplib.HTTPConnection is not second_cassette_HTTPConnection assert httplib.HTTPConnection is not second_cassette_HTTPConnection
with force_reset(): with force_reset():
assert httplib.HTTPConnection is original assert httplib.HTTPConnection is original
@@ -188,12 +191,14 @@ def test_nesting_context_managers_by_checking_references_of_http_connection():
def test_custom_patchers(): def test_custom_patchers():
class Test(object): class Test(object):
attribute = None attribute = None
with Cassette.use('custom_patches', custom_patches=((Test, 'attribute', VCRHTTPSConnection),)): with Cassette.use(path='custom_patches',
custom_patches=((Test, 'attribute', VCRHTTPSConnection),)):
assert issubclass(Test.attribute, VCRHTTPSConnection) assert issubclass(Test.attribute, VCRHTTPSConnection)
assert VCRHTTPSConnection is not Test.attribute assert VCRHTTPSConnection is not Test.attribute
old_attribute = Test.attribute old_attribute = Test.attribute
with Cassette.use('custom_patches', custom_patches=((Test, 'attribute', VCRHTTPSConnection),)): with Cassette.use(path='custom_patches',
custom_patches=((Test, 'attribute', VCRHTTPSConnection),)):
assert issubclass(Test.attribute, VCRHTTPSConnection) assert issubclass(Test.attribute, VCRHTTPSConnection)
assert VCRHTTPSConnection is not Test.attribute assert VCRHTTPSConnection is not Test.attribute
assert Test.attribute is not old_attribute assert Test.attribute is not old_attribute
@@ -203,10 +208,10 @@ def test_custom_patchers():
assert Test.attribute is old_attribute assert Test.attribute is old_attribute
def test_use_cassette_decorated_functions_are_reentrant(): def test_decorated_functions_are_reentrant():
info = {"second": False} info = {"second": False}
original_conn = httplib.HTTPConnection original_conn = httplib.HTTPConnection
@Cassette.use('whatever', inject=True) @Cassette.use(path='whatever', inject=True)
def test_function(cassette): def test_function(cassette):
if info['second']: if info['second']:
assert httplib.HTTPConnection is not info['first_conn'] assert httplib.HTTPConnection is not info['first_conn']
@@ -217,3 +222,35 @@ def test_use_cassette_decorated_functions_are_reentrant():
assert httplib.HTTPConnection is info['first_conn'] assert httplib.HTTPConnection is info['first_conn']
test_function() test_function()
assert httplib.HTTPConnection is original_conn assert httplib.HTTPConnection is original_conn
def test_cassette_use_called_without_path_uses_function_to_generate_path():
@Cassette.use(inject=True)
def function_name(cassette):
assert cassette._path == 'function_name'
function_name()
def test_path_transformer_with_function_path():
path_transformer = lambda path: os.path.join('a', path)
@Cassette.use(inject=True, path_transformer=path_transformer)
def function_name(cassette):
assert cassette._path == os.path.join('a', 'function_name')
function_name()
def test_path_transformer_with_context_manager():
with Cassette.use(
path='b', path_transformer=lambda *args: 'a'
) as cassette:
assert cassette._path == 'a'
def test_func_path_generator():
def generator(function):
return os.path.join(os.path.dirname(inspect.getfile(function)),
function.__name__)
@Cassette.use(inject=True, func_path_generator=generator)
def function_name(cassette):
assert cassette._path == os.path.join(os.path.dirname(__file__), 'function_name')
function_name()

View File

@@ -1,3 +1,5 @@
import os
import mock import mock
import pytest import pytest
@@ -9,7 +11,10 @@ from vcr.stubs import VCRHTTPSConnection
def test_vcr_use_cassette(): def test_vcr_use_cassette():
record_mode = mock.Mock() record_mode = mock.Mock()
test_vcr = VCR(record_mode=record_mode) test_vcr = VCR(record_mode=record_mode)
with mock.patch('vcr.cassette.Cassette.load', return_value=mock.MagicMock(inject=False)) as mock_cassette_load: with mock.patch(
'vcr.cassette.Cassette.load',
return_value=mock.MagicMock(inject=False)
) as mock_cassette_load:
@test_vcr.use_cassette('test') @test_vcr.use_cassette('test')
def function(): def function():
pass pass
@@ -87,7 +92,10 @@ def test_custom_patchers():
assert issubclass(Test.attribute, VCRHTTPSConnection) assert issubclass(Test.attribute, VCRHTTPSConnection)
assert VCRHTTPSConnection is not Test.attribute assert VCRHTTPSConnection is not Test.attribute
with test_vcr.use_cassette('custom_patches', custom_patches=((Test, 'attribute2', VCRHTTPSConnection),)): with test_vcr.use_cassette(
'custom_patches',
custom_patches=((Test, 'attribute2', VCRHTTPSConnection),)
):
assert issubclass(Test.attribute, VCRHTTPSConnection) assert issubclass(Test.attribute, VCRHTTPSConnection)
assert VCRHTTPSConnection is not Test.attribute assert VCRHTTPSConnection is not Test.attribute
assert Test.attribute is Test.attribute2 assert Test.attribute is Test.attribute2
@@ -128,3 +136,57 @@ def test_with_current_defaults():
vcr.record_mode = 'all' vcr.record_mode = 'all'
changing_defaults(assert_record_mode_all) changing_defaults(assert_record_mode_all)
current_defaults(assert_record_mode_once) current_defaults(assert_record_mode_once)
def test_cassette_library_dir_with_decoration_and_no_explicit_path():
library_dir = '/libary_dir'
vcr = VCR(inject_cassette=True, cassette_library_dir=library_dir)
@vcr.use_cassette()
def function_name(cassette):
assert cassette._path == os.path.join(library_dir, 'function_name')
function_name()
def test_cassette_library_dir_with_path_transformer():
library_dir = '/libary_dir'
vcr = VCR(inject_cassette=True, cassette_library_dir=library_dir,
path_transformer=lambda path: path + '.json')
@vcr.use_cassette()
def function_name(cassette):
assert cassette._path == os.path.join(library_dir, 'function_name.json')
function_name()
def test_use_cassette_with_no_extra_invocation():
vcr = VCR(inject_cassette=True, cassette_library_dir='/')
@vcr.use_cassette
def function_name(cassette):
assert cassette._path == os.path.join('/', 'function_name')
function_name()
def test_path_transformer():
vcr = VCR(inject_cassette=True, cassette_library_dir='/',
path_transformer=lambda x: x + '_test')
@vcr.use_cassette
def function_name(cassette):
assert cassette._path == os.path.join('/', 'function_name_test')
function_name()
def test_cassette_name_generator_defaults_to_using_module_function_defined_in():
vcr = VCR(inject_cassette=True)
@vcr.use_cassette
def function_name(cassette):
assert cassette._path == os.path.join(os.path.dirname(__file__),
'function_name')
function_name()
def test_ensure_suffix():
vcr = VCR(inject_cassette=True, path_transformer=VCR.ensure_suffix('.yaml'))
@vcr.use_cassette
def function_name(cassette):
assert cassette._path == os.path.join(os.path.dirname(__file__),
'function_name.yaml')
function_name()

View File

@@ -1,4 +1,5 @@
"""The container for recorded requests and responses""" """The container for recorded requests and responses"""
import functools
import logging import logging
import contextlib2 import contextlib2
@@ -9,11 +10,12 @@ except ImportError:
from backport_collections import Counter from backport_collections import Counter
# Internal imports # Internal imports
from .errors import UnhandledHTTPRequestError
from .matchers import requests_match, uri, method
from .patch import CassettePatcherBuilder from .patch import CassettePatcherBuilder
from .persist import load_cassette, save_cassette from .persist import load_cassette, save_cassette
from .serializers import yamlserializer from .serializers import yamlserializer
from .matchers import requests_match, uri, method from .util import partition_dict
from .errors import UnhandledHTTPRequestError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -29,9 +31,11 @@ class CassetteContextDecorator(object):
from interfering with another. from interfering with another.
""" """
_non_cassette_arguments = ('path_transformer', 'func_path_generator')
@classmethod @classmethod
def from_args(cls, cassette_class, path, **kwargs): def from_args(cls, cassette_class, **kwargs):
return cls(cassette_class, lambda: (path, kwargs)) return cls(cassette_class, lambda: dict(kwargs))
def __init__(self, cls, args_getter): def __init__(self, cls, args_getter):
self.cls = cls self.cls = cls
@@ -49,6 +53,14 @@ class CassetteContextDecorator(object):
# somewhere else. # somewhere else.
cassette._save() cassette._save()
@classmethod
def key_predicate(cls, key, value):
return key in cls._non_cassette_arguments
@classmethod
def _split_keys(cls, kwargs):
return partition_dict(cls.key_predicate, kwargs)
def __enter__(self): def __enter__(self):
# This assertion is here to prevent the dangerous behavior # This assertion is here to prevent the dangerous behavior
# that would result from forgetting about a __finish before # that would result from forgetting about a __finish before
@@ -59,8 +71,11 @@ class CassetteContextDecorator(object):
# with context_decorator: # with context_decorator:
# pass # pass
assert self.__finish is None, "Cassette already open." assert self.__finish is None, "Cassette already open."
path, kwargs = self._args_getter() other_kwargs, cassette_kwargs = self._split_keys(self._args_getter())
self.__finish = self._patch_generator(self.cls.load(path, **kwargs)) if 'path_transformer' in other_kwargs:
transformer = other_kwargs['path_transformer']
cassette_kwargs['path'] = transformer(cassette_kwargs['path'])
self.__finish = self._patch_generator(self.cls.load(**cassette_kwargs))
return next(self.__finish) return next(self.__finish)
def __exit__(self, *args): def __exit__(self, *args):
@@ -70,23 +85,42 @@ class CassetteContextDecorator(object):
@wrapt.decorator @wrapt.decorator
def __call__(self, function, instance, args, kwargs): def __call__(self, function, instance, args, kwargs):
# This awkward cloning thing is done to ensure that decorated # This awkward cloning thing is done to ensure that decorated
# functions are reentrant. Reentrancy is required for thread # functions are reentrant. This is required for thread
# safety and the correct operation of recursive functions. # safety and the correct operation of recursive functions.
clone = type(self)(self.cls, self._args_getter) args_getter = self._build_args_getter_for_decorator(
function, self._args_getter
)
clone = type(self)(self.cls, args_getter)
with clone as cassette: with clone as cassette:
if cassette.inject: if cassette.inject:
return function(cassette, *args, **kwargs) return function(cassette, *args, **kwargs)
else: else:
return function(*args, **kwargs) return function(*args, **kwargs)
@staticmethod
def get_function_name(function):
return function.__name__
@classmethod
def _build_args_getter_for_decorator(cls, function, args_getter):
def new_args_getter():
kwargs = args_getter()
if 'path' not in kwargs:
name_generator = (kwargs.get('func_path_generator') or
cls.get_function_name)
path = name_generator(function)
kwargs['path'] = path
return kwargs
return new_args_getter
class Cassette(object): class Cassette(object):
"""A container for recorded requests and responses""" """A container for recorded requests and responses"""
@classmethod @classmethod
def load(cls, path, **kwargs): def load(cls, **kwargs):
"""Instantiate and load the cassette stored at the specified path.""" """Instantiate and load the cassette stored at the specified path."""
new_cassette = cls(path, **kwargs) new_cassette = cls(**kwargs)
new_cassette._load() new_cassette._load()
return new_cassette return new_cassette
@@ -95,8 +129,8 @@ class Cassette(object):
return CassetteContextDecorator(cls, arg_getter) return CassetteContextDecorator(cls, arg_getter)
@classmethod @classmethod
def use(cls, *args, **kwargs): def use(cls, **kwargs):
return CassetteContextDecorator.from_args(cls, *args, **kwargs) return CassetteContextDecorator.from_args(cls, **kwargs)
def __init__(self, path, serializer=yamlserializer, record_mode='once', def __init__(self, path, serializer=yamlserializer, record_mode='once',
match_on=(uri, method), before_record_request=None, match_on=(uri, method), before_record_request=None,

View File

@@ -1,23 +1,35 @@
import collections import collections
import copy import copy
import functools import functools
import inspect
import os import os
import six
from .cassette import Cassette from .cassette import Cassette
from .serializers import yamlserializer, jsonserializer from .serializers import yamlserializer, jsonserializer
from .util import compose
from . import matchers from . import matchers
from . import filters from . import filters
class VCR(object): class VCR(object):
def __init__(self, serializer='yaml', cassette_library_dir=None, @staticmethod
record_mode="once", filter_headers=(), ignore_localhost=False, def ensure_suffix(suffix):
custom_patches=(), filter_query_parameters=(), def ensure(path):
filter_post_data_parameters=(), before_record_request=None, if not path.endswith(suffix):
before_record_response=None, ignore_hosts=(), return path + suffix
return path
return ensure
def __init__(self, path_transformer=lambda x: x, before_record_request=None,
custom_patches=(), filter_query_parameters=(), ignore_hosts=(),
record_mode="once", ignore_localhost=False, filter_headers=(),
before_record_response=None, filter_post_data_parameters=(),
match_on=('method', 'scheme', 'host', 'port', 'path', 'query'), match_on=('method', 'scheme', 'host', 'port', 'path', 'query'),
before_record=None, inject_cassette=False): before_record=None, inject_cassette=False, serializer='yaml',
cassette_library_dir=None, func_path_generator=None):
self.serializer = serializer self.serializer = serializer
self.match_on = match_on self.match_on = match_on
self.cassette_library_dir = cassette_library_dir self.cassette_library_dir = cassette_library_dir
@@ -46,6 +58,8 @@ class VCR(object):
self.ignore_hosts = ignore_hosts self.ignore_hosts = ignore_hosts
self.ignore_localhost = ignore_localhost self.ignore_localhost = ignore_localhost
self.inject_cassette = inject_cassette self.inject_cassette = inject_cassette
self.path_transformer = path_transformer
self.func_path_generator = func_path_generator
self._custom_patches = tuple(custom_patches) self._custom_patches = tuple(custom_patches)
def _get_serializer(self, serializer_name): def _get_serializer(self, serializer_name):
@@ -69,27 +83,48 @@ class VCR(object):
) )
return matchers return matchers
def use_cassette(self, path, with_current_defaults=False, **kwargs): def use_cassette(self, path=None, **kwargs):
if path is not None and not isinstance(path, six.string_types):
function = path
# Assume this is an attempt to decorate a function
return self._use_cassette(**kwargs)(function)
return self._use_cassette(path=path, **kwargs)
def _use_cassette(self, with_current_defaults=False, **kwargs):
if with_current_defaults: if with_current_defaults:
path, config = self.get_path_and_merged_config(path, **kwargs) config = self.get_merged_config(**kwargs)
return Cassette.use(path, **config) return Cassette.use(**config)
# This is made a function that evaluates every time a cassette # This is made a function that evaluates every time a cassette
# is made so that changes that are made to this VCR instance # is made so that changes that are made to this VCR instance
# that occur AFTER the `use_cassette` decorator is applied # that occur AFTER the `use_cassette` decorator is applied
# still affect subsequent calls to the decorated function. # still affect subsequent calls to the decorated function.
args_getter = functools.partial(self.get_path_and_merged_config, args_getter = functools.partial(self.get_merged_config, **kwargs)
path, **kwargs)
return Cassette.use_arg_getter(args_getter) return Cassette.use_arg_getter(args_getter)
def get_path_and_merged_config(self, path, **kwargs): def get_merged_config(self, **kwargs):
serializer_name = kwargs.get('serializer', self.serializer) serializer_name = kwargs.get('serializer', self.serializer)
matcher_names = kwargs.get('match_on', self.match_on) matcher_names = kwargs.get('match_on', self.match_on)
path_transformer = kwargs.get(
'path_transformer',
self.path_transformer
)
func_path_generator = kwargs.get(
'func_path_generator',
self.func_path_generator
)
cassette_library_dir = kwargs.get( cassette_library_dir = kwargs.get(
'cassette_library_dir', 'cassette_library_dir',
self.cassette_library_dir self.cassette_library_dir
) )
if cassette_library_dir: if cassette_library_dir:
path = os.path.join(cassette_library_dir, path) def add_cassette_library_dir(path):
if not path.startswith(cassette_library_dir):
return os.path.join(cassette_library_dir, path)
path_transformer = compose(add_cassette_library_dir, path_transformer)
elif not func_path_generator:
# If we don't have a library dir, use the functions
# location to build a full path for cassettes.
func_path_generator = self._build_path_from_func_using_module
merged_config = { merged_config = {
'serializer': self._get_serializer(serializer_name), 'serializer': self._get_serializer(serializer_name),
@@ -102,9 +137,14 @@ class VCR(object):
'custom_patches': self._custom_patches + kwargs.get( 'custom_patches': self._custom_patches + kwargs.get(
'custom_patches', () 'custom_patches', ()
), ),
'inject': kwargs.get('inject_cassette', self.inject_cassette) 'inject': kwargs.get('inject_cassette', self.inject_cassette),
'path_transformer': path_transformer,
'func_path_generator': func_path_generator
} }
return path, merged_config path = kwargs.get('path')
if path:
merged_config['path'] = path
return merged_config
def _build_before_record_response(self, options): def _build_before_record_response(self, options):
before_record_response = options.get( before_record_response = options.get(
@@ -185,6 +225,11 @@ class VCR(object):
return request return request
return filter_ignored_hosts return filter_ignored_hosts
@staticmethod
def _build_path_from_func_using_module(function):
return os.path.join(os.path.dirname(inspect.getfile(function)),
function.__name__)
def register_serializer(self, name, serializer): def register_serializer(self, name, serializer):
self.serializers[name] = serializer self.serializers[name] = serializer

16
vcr/util.py Normal file
View File

@@ -0,0 +1,16 @@
def partition_dict(predicate, dictionary):
true_dict = {}
false_dict = {}
for key, value in dictionary.items():
this_dict = true_dict if predicate(key, value) else false_dict
this_dict[key] = value
return true_dict, false_dict
def compose(*functions):
def composed(incoming):
res = incoming
for function in functions[::-1]:
res = function(res)
return res
return composed