diff --git a/README.md b/README.md index 06723be..2c8255e 100644 --- a/README.md +++ b/README.md @@ -389,6 +389,43 @@ my_vcr = config.VCR(custom_patches=((where_the_custom_https_connection_lives, 'C @my_vcr.use_cassette(...) ``` +## Automatic Cassette Naming + +VCR.py now allows the omission of the path argument to the +use_cassette function. Both of the following are now legal/should work + +``` python +@my_vcr.use_cassette +def my_test_function(): + ... +``` + +``` python +@my_vcr.use_cassette() +def my_test_function(): + ... +``` + +In both cases, VCR.py will use a path that is generated from the +provided test function's name. If no `cassette_library_dir` has been +set, the cassette will be in a file with the name of the test function +in directory of the file in which the test function is declared. If a +`cassette_library_dir` is set, has been set, the cassette will appear +in that directory in a file with the name of the decorated function. + +It is possible to control the path produced by the automatic naming +machinery by customizing the `path_transformer` and +`func_path_generator` vcr variables. To add an extension to all +cassette names, use `VCR.ensure_suffix` as follows: + +``` python +my_vcr = VCR(path_transformer=VCR.ensure_suffix('.yaml')) + +@my_vcr.use_cassette +def my_test_function(): + +``` + ## Installation VCR.py is a package on PyPI, so you can `pip install vcrpy` (first you may need diff --git a/tests/unit/test_cassettes.py b/tests/unit/test_cassettes.py index 9b573bf..947ffa4 100644 --- a/tests/unit/test_cassettes.py +++ b/tests/unit/test_cassettes.py @@ -1,4 +1,6 @@ import copy +import inspect +import os from six.moves import http_client as httplib import contextlib2 @@ -12,14 +14,13 @@ from vcr.patch import force_reset from vcr.stubs import VCRHTTPSConnection - def test_cassette_load(tmpdir): a_file = tmpdir.join('test_cassette.yml') a_file.write(yaml.dump({'interactions': [ {'request': {'body': '', 'uri': 'foo', 'method': 'GET', 'headers': {}}, 'response': 'bar'} ]})) - a_cassette = Cassette.load(str(a_file)) + a_cassette = Cassette.load(path=str(a_file)) assert len(a_cassette) == 1 @@ -87,33 +88,35 @@ def make_get_request(): @mock.patch('vcr.cassette.Cassette.can_play_response_for', return_value=True) @mock.patch('vcr.stubs.VCRHTTPResponse') def test_function_decorated_with_use_cassette_can_be_invoked_multiple_times(*args): - decorated_function = Cassette.use('test')(make_get_request) - for i in range(2): + decorated_function = Cassette.use(path='test')(make_get_request) + for i in range(4): decorated_function() def test_arg_getter_functionality(): - arg_getter = mock.Mock(return_value=('test', {})) + arg_getter = mock.Mock(return_value={'path': 'test'}) context_decorator = Cassette.use_arg_getter(arg_getter) with context_decorator as cassette: assert cassette._path == 'test' - arg_getter.return_value = ('other', {}) + arg_getter.return_value = {'path': 'other'} with context_decorator as cassette: assert cassette._path == 'other' - arg_getter.return_value = ('', {'filter_headers': ('header_name',)}) + arg_getter.return_value = {'path': 'other', 'filter_headers': ('header_name',)} @context_decorator def function(): pass - with mock.patch.object(Cassette, 'load', return_value=mock.MagicMock(inject=False)) as cassette_load: + with mock.patch.object( + Cassette, 'load', + return_value=mock.MagicMock(inject=False) + ) as cassette_load: function() - cassette_load.assert_called_once_with(arg_getter.return_value[0], - **arg_getter.return_value[1]) + cassette_load.assert_called_once_with(**arg_getter.return_value) def test_cassette_not_all_played(): @@ -156,13 +159,13 @@ def test_nesting_cassette_context_managers(*args): second_response['body']['string'] = b'second_response' with contextlib2.ExitStack() as exit_stack: - first_cassette = exit_stack.enter_context(Cassette.use('test')) + first_cassette = exit_stack.enter_context(Cassette.use(path='test')) exit_stack.enter_context(mock.patch.object(first_cassette, 'play_response', return_value=first_response)) assert_get_response_body_is('first_response') # Make sure a second cassette can supercede the first - with Cassette.use('test') as second_cassette: + with Cassette.use(path='test') as second_cassette: with mock.patch.object(second_cassette, 'play_response', return_value=second_response): assert_get_response_body_is('second_response') @@ -172,12 +175,12 @@ def test_nesting_cassette_context_managers(*args): def test_nesting_context_managers_by_checking_references_of_http_connection(): original = httplib.HTTPConnection - with Cassette.use('test'): + with Cassette.use(path='test'): first_cassette_HTTPConnection = httplib.HTTPConnection - with Cassette.use('test'): + with Cassette.use(path='test'): second_cassette_HTTPConnection = httplib.HTTPConnection assert second_cassette_HTTPConnection is not first_cassette_HTTPConnection - with Cassette.use('test'): + with Cassette.use(path='test'): assert httplib.HTTPConnection is not second_cassette_HTTPConnection with force_reset(): assert httplib.HTTPConnection is original @@ -188,12 +191,14 @@ def test_nesting_context_managers_by_checking_references_of_http_connection(): def test_custom_patchers(): class Test(object): attribute = None - with Cassette.use('custom_patches', custom_patches=((Test, 'attribute', VCRHTTPSConnection),)): + with Cassette.use(path='custom_patches', + custom_patches=((Test, 'attribute', VCRHTTPSConnection),)): assert issubclass(Test.attribute, VCRHTTPSConnection) assert VCRHTTPSConnection is not Test.attribute old_attribute = Test.attribute - with Cassette.use('custom_patches', custom_patches=((Test, 'attribute', VCRHTTPSConnection),)): + with Cassette.use(path='custom_patches', + custom_patches=((Test, 'attribute', VCRHTTPSConnection),)): assert issubclass(Test.attribute, VCRHTTPSConnection) assert VCRHTTPSConnection is not Test.attribute assert Test.attribute is not old_attribute @@ -203,10 +208,10 @@ def test_custom_patchers(): assert Test.attribute is old_attribute -def test_use_cassette_decorated_functions_are_reentrant(): +def test_decorated_functions_are_reentrant(): info = {"second": False} original_conn = httplib.HTTPConnection - @Cassette.use('whatever', inject=True) + @Cassette.use(path='whatever', inject=True) def test_function(cassette): if info['second']: assert httplib.HTTPConnection is not info['first_conn'] @@ -217,3 +222,35 @@ def test_use_cassette_decorated_functions_are_reentrant(): assert httplib.HTTPConnection is info['first_conn'] test_function() assert httplib.HTTPConnection is original_conn + + +def test_cassette_use_called_without_path_uses_function_to_generate_path(): + @Cassette.use(inject=True) + def function_name(cassette): + assert cassette._path == 'function_name' + function_name() + + +def test_path_transformer_with_function_path(): + path_transformer = lambda path: os.path.join('a', path) + @Cassette.use(inject=True, path_transformer=path_transformer) + def function_name(cassette): + assert cassette._path == os.path.join('a', 'function_name') + function_name() + + +def test_path_transformer_with_context_manager(): + with Cassette.use( + path='b', path_transformer=lambda *args: 'a' + ) as cassette: + assert cassette._path == 'a' + + +def test_func_path_generator(): + def generator(function): + return os.path.join(os.path.dirname(inspect.getfile(function)), + function.__name__) + @Cassette.use(inject=True, func_path_generator=generator) + def function_name(cassette): + assert cassette._path == os.path.join(os.path.dirname(__file__), 'function_name') + function_name() diff --git a/tests/unit/test_vcr.py b/tests/unit/test_vcr.py index 94e7a9e..d6a88fc 100644 --- a/tests/unit/test_vcr.py +++ b/tests/unit/test_vcr.py @@ -1,3 +1,5 @@ +import os + import mock import pytest @@ -9,7 +11,10 @@ from vcr.stubs import VCRHTTPSConnection def test_vcr_use_cassette(): record_mode = mock.Mock() test_vcr = VCR(record_mode=record_mode) - with mock.patch('vcr.cassette.Cassette.load', return_value=mock.MagicMock(inject=False)) as mock_cassette_load: + with mock.patch( + 'vcr.cassette.Cassette.load', + return_value=mock.MagicMock(inject=False) + ) as mock_cassette_load: @test_vcr.use_cassette('test') def function(): pass @@ -87,7 +92,10 @@ def test_custom_patchers(): assert issubclass(Test.attribute, VCRHTTPSConnection) assert VCRHTTPSConnection is not Test.attribute - with test_vcr.use_cassette('custom_patches', custom_patches=((Test, 'attribute2', VCRHTTPSConnection),)): + with test_vcr.use_cassette( + 'custom_patches', + custom_patches=((Test, 'attribute2', VCRHTTPSConnection),) + ): assert issubclass(Test.attribute, VCRHTTPSConnection) assert VCRHTTPSConnection is not Test.attribute assert Test.attribute is Test.attribute2 @@ -128,3 +136,57 @@ def test_with_current_defaults(): vcr.record_mode = 'all' changing_defaults(assert_record_mode_all) current_defaults(assert_record_mode_once) + + +def test_cassette_library_dir_with_decoration_and_no_explicit_path(): + library_dir = '/libary_dir' + vcr = VCR(inject_cassette=True, cassette_library_dir=library_dir) + @vcr.use_cassette() + def function_name(cassette): + assert cassette._path == os.path.join(library_dir, 'function_name') + function_name() + + +def test_cassette_library_dir_with_path_transformer(): + library_dir = '/libary_dir' + vcr = VCR(inject_cassette=True, cassette_library_dir=library_dir, + path_transformer=lambda path: path + '.json') + @vcr.use_cassette() + def function_name(cassette): + assert cassette._path == os.path.join(library_dir, 'function_name.json') + function_name() + + +def test_use_cassette_with_no_extra_invocation(): + vcr = VCR(inject_cassette=True, cassette_library_dir='/') + @vcr.use_cassette + def function_name(cassette): + assert cassette._path == os.path.join('/', 'function_name') + function_name() + + +def test_path_transformer(): + vcr = VCR(inject_cassette=True, cassette_library_dir='/', + path_transformer=lambda x: x + '_test') + @vcr.use_cassette + def function_name(cassette): + assert cassette._path == os.path.join('/', 'function_name_test') + function_name() + + +def test_cassette_name_generator_defaults_to_using_module_function_defined_in(): + vcr = VCR(inject_cassette=True) + @vcr.use_cassette + def function_name(cassette): + assert cassette._path == os.path.join(os.path.dirname(__file__), + 'function_name') + function_name() + + +def test_ensure_suffix(): + vcr = VCR(inject_cassette=True, path_transformer=VCR.ensure_suffix('.yaml')) + @vcr.use_cassette + def function_name(cassette): + assert cassette._path == os.path.join(os.path.dirname(__file__), + 'function_name.yaml') + function_name() diff --git a/vcr/cassette.py b/vcr/cassette.py index c01a763..77b5395 100644 --- a/vcr/cassette.py +++ b/vcr/cassette.py @@ -1,4 +1,5 @@ """The container for recorded requests and responses""" +import functools import logging import contextlib2 @@ -9,11 +10,12 @@ except ImportError: from backport_collections import Counter # Internal imports +from .errors import UnhandledHTTPRequestError +from .matchers import requests_match, uri, method from .patch import CassettePatcherBuilder from .persist import load_cassette, save_cassette from .serializers import yamlserializer -from .matchers import requests_match, uri, method -from .errors import UnhandledHTTPRequestError +from .util import partition_dict log = logging.getLogger(__name__) @@ -29,9 +31,11 @@ class CassetteContextDecorator(object): from interfering with another. """ + _non_cassette_arguments = ('path_transformer', 'func_path_generator') + @classmethod - def from_args(cls, cassette_class, path, **kwargs): - return cls(cassette_class, lambda: (path, kwargs)) + def from_args(cls, cassette_class, **kwargs): + return cls(cassette_class, lambda: dict(kwargs)) def __init__(self, cls, args_getter): self.cls = cls @@ -49,6 +53,14 @@ class CassetteContextDecorator(object): # somewhere else. cassette._save() + @classmethod + def key_predicate(cls, key, value): + return key in cls._non_cassette_arguments + + @classmethod + def _split_keys(cls, kwargs): + return partition_dict(cls.key_predicate, kwargs) + def __enter__(self): # This assertion is here to prevent the dangerous behavior # that would result from forgetting about a __finish before @@ -59,8 +71,11 @@ class CassetteContextDecorator(object): # with context_decorator: # pass assert self.__finish is None, "Cassette already open." - path, kwargs = self._args_getter() - self.__finish = self._patch_generator(self.cls.load(path, **kwargs)) + other_kwargs, cassette_kwargs = self._split_keys(self._args_getter()) + if 'path_transformer' in other_kwargs: + transformer = other_kwargs['path_transformer'] + cassette_kwargs['path'] = transformer(cassette_kwargs['path']) + self.__finish = self._patch_generator(self.cls.load(**cassette_kwargs)) return next(self.__finish) def __exit__(self, *args): @@ -70,23 +85,42 @@ class CassetteContextDecorator(object): @wrapt.decorator def __call__(self, function, instance, args, kwargs): # This awkward cloning thing is done to ensure that decorated - # functions are reentrant. Reentrancy is required for thread + # functions are reentrant. This is required for thread # safety and the correct operation of recursive functions. - clone = type(self)(self.cls, self._args_getter) + args_getter = self._build_args_getter_for_decorator( + function, self._args_getter + ) + clone = type(self)(self.cls, args_getter) with clone as cassette: if cassette.inject: return function(cassette, *args, **kwargs) else: return function(*args, **kwargs) + @staticmethod + def get_function_name(function): + return function.__name__ + + @classmethod + def _build_args_getter_for_decorator(cls, function, args_getter): + def new_args_getter(): + kwargs = args_getter() + if 'path' not in kwargs: + name_generator = (kwargs.get('func_path_generator') or + cls.get_function_name) + path = name_generator(function) + kwargs['path'] = path + return kwargs + return new_args_getter + class Cassette(object): """A container for recorded requests and responses""" @classmethod - def load(cls, path, **kwargs): + def load(cls, **kwargs): """Instantiate and load the cassette stored at the specified path.""" - new_cassette = cls(path, **kwargs) + new_cassette = cls(**kwargs) new_cassette._load() return new_cassette @@ -95,8 +129,8 @@ class Cassette(object): return CassetteContextDecorator(cls, arg_getter) @classmethod - def use(cls, *args, **kwargs): - return CassetteContextDecorator.from_args(cls, *args, **kwargs) + def use(cls, **kwargs): + return CassetteContextDecorator.from_args(cls, **kwargs) def __init__(self, path, serializer=yamlserializer, record_mode='once', match_on=(uri, method), before_record_request=None, diff --git a/vcr/config.py b/vcr/config.py index 6a66356..6a5de5b 100644 --- a/vcr/config.py +++ b/vcr/config.py @@ -1,23 +1,35 @@ import collections import copy import functools +import inspect import os +import six + from .cassette import Cassette from .serializers import yamlserializer, jsonserializer +from .util import compose from . import matchers from . import filters class VCR(object): - def __init__(self, serializer='yaml', cassette_library_dir=None, - record_mode="once", filter_headers=(), ignore_localhost=False, - custom_patches=(), filter_query_parameters=(), - filter_post_data_parameters=(), before_record_request=None, - before_record_response=None, ignore_hosts=(), + @staticmethod + def ensure_suffix(suffix): + def ensure(path): + if not path.endswith(suffix): + return path + suffix + return path + return ensure + + def __init__(self, path_transformer=lambda x: x, before_record_request=None, + custom_patches=(), filter_query_parameters=(), ignore_hosts=(), + record_mode="once", ignore_localhost=False, filter_headers=(), + before_record_response=None, filter_post_data_parameters=(), match_on=('method', 'scheme', 'host', 'port', 'path', 'query'), - before_record=None, inject_cassette=False): + before_record=None, inject_cassette=False, serializer='yaml', + cassette_library_dir=None, func_path_generator=None): self.serializer = serializer self.match_on = match_on self.cassette_library_dir = cassette_library_dir @@ -46,6 +58,8 @@ class VCR(object): self.ignore_hosts = ignore_hosts self.ignore_localhost = ignore_localhost self.inject_cassette = inject_cassette + self.path_transformer = path_transformer + self.func_path_generator = func_path_generator self._custom_patches = tuple(custom_patches) def _get_serializer(self, serializer_name): @@ -69,27 +83,48 @@ class VCR(object): ) return matchers - def use_cassette(self, path, with_current_defaults=False, **kwargs): + def use_cassette(self, path=None, **kwargs): + if path is not None and not isinstance(path, six.string_types): + function = path + # Assume this is an attempt to decorate a function + return self._use_cassette(**kwargs)(function) + return self._use_cassette(path=path, **kwargs) + + def _use_cassette(self, with_current_defaults=False, **kwargs): if with_current_defaults: - path, config = self.get_path_and_merged_config(path, **kwargs) - return Cassette.use(path, **config) + config = self.get_merged_config(**kwargs) + return Cassette.use(**config) # This is made a function that evaluates every time a cassette # is made so that changes that are made to this VCR instance # that occur AFTER the `use_cassette` decorator is applied # still affect subsequent calls to the decorated function. - args_getter = functools.partial(self.get_path_and_merged_config, - path, **kwargs) + args_getter = functools.partial(self.get_merged_config, **kwargs) return Cassette.use_arg_getter(args_getter) - def get_path_and_merged_config(self, path, **kwargs): + def get_merged_config(self, **kwargs): serializer_name = kwargs.get('serializer', self.serializer) matcher_names = kwargs.get('match_on', self.match_on) + path_transformer = kwargs.get( + 'path_transformer', + self.path_transformer + ) + func_path_generator = kwargs.get( + 'func_path_generator', + self.func_path_generator + ) cassette_library_dir = kwargs.get( 'cassette_library_dir', self.cassette_library_dir ) if cassette_library_dir: - path = os.path.join(cassette_library_dir, path) + def add_cassette_library_dir(path): + if not path.startswith(cassette_library_dir): + return os.path.join(cassette_library_dir, path) + path_transformer = compose(add_cassette_library_dir, path_transformer) + elif not func_path_generator: + # If we don't have a library dir, use the functions + # location to build a full path for cassettes. + func_path_generator = self._build_path_from_func_using_module merged_config = { 'serializer': self._get_serializer(serializer_name), @@ -102,9 +137,14 @@ class VCR(object): 'custom_patches': self._custom_patches + kwargs.get( 'custom_patches', () ), - 'inject': kwargs.get('inject_cassette', self.inject_cassette) + 'inject': kwargs.get('inject_cassette', self.inject_cassette), + 'path_transformer': path_transformer, + 'func_path_generator': func_path_generator } - return path, merged_config + path = kwargs.get('path') + if path: + merged_config['path'] = path + return merged_config def _build_before_record_response(self, options): before_record_response = options.get( @@ -185,6 +225,11 @@ class VCR(object): return request return filter_ignored_hosts + @staticmethod + def _build_path_from_func_using_module(function): + return os.path.join(os.path.dirname(inspect.getfile(function)), + function.__name__) + def register_serializer(self, name, serializer): self.serializers[name] = serializer diff --git a/vcr/util.py b/vcr/util.py new file mode 100644 index 0000000..57f72b1 --- /dev/null +++ b/vcr/util.py @@ -0,0 +1,16 @@ +def partition_dict(predicate, dictionary): + true_dict = {} + false_dict = {} + for key, value in dictionary.items(): + this_dict = true_dict if predicate(key, value) else false_dict + this_dict[key] = value + return true_dict, false_dict + + +def compose(*functions): + def composed(incoming): + res = incoming + for function in functions[::-1]: + res = function(res) + return res + return composed