diff --git a/tests/integration/test_tornado.py b/tests/integration/test_tornado.py index 1230077..7116256 100644 --- a/tests/integration/test_tornado.py +++ b/tests/integration/test_tornado.py @@ -275,3 +275,12 @@ def test_cannot_overwrite_cassette_raise_error_disabled(get_client, tmpdir): ) assert isinstance(response.error, CannotOverwriteExistingCassetteException) + + +@pytest.mark.gen_test +@vcr.use_cassette +def test_tornado_with_decorator_use_cassette(get_client): + response = yield get_client().fetch( + http.HTTPRequest('http://www.google.com/', method='GET') + ) + assert response.body == "not actually google" diff --git a/tests/integration/test_tornado_with_decorator_use_cassette.yaml b/tests/integration/test_tornado_with_decorator_use_cassette.yaml new file mode 100644 index 0000000..ae05aca --- /dev/null +++ b/tests/integration/test_tornado_with_decorator_use_cassette.yaml @@ -0,0 +1,53 @@ +interactions: +- request: + body: null + headers: {} + method: GET + uri: http://www.google.com/ + response: + body: {string: !!python/unicode 'not actually google'} + headers: + - !!python/tuple + - Expires + - ['-1'] + - !!python/tuple + - Connection + - [close] + - !!python/tuple + - P3p + - ['CP="This is not a P3P policy! See http://www.google.com/support/accounts/bin/answer.py?hl=en&answer=151657 + for more info."'] + - !!python/tuple + - Alternate-Protocol + - ['80:quic,p=0'] + - !!python/tuple + - Accept-Ranges + - [none] + - !!python/tuple + - X-Xss-Protection + - [1; mode=block] + - !!python/tuple + - Vary + - [Accept-Encoding] + - !!python/tuple + - Date + - ['Thu, 30 Jul 2015 08:41:40 GMT'] + - !!python/tuple + - Cache-Control + - ['private, max-age=0'] + - !!python/tuple + - Content-Type + - [text/html; charset=ISO-8859-1] + - !!python/tuple + - Set-Cookie + - ['PREF=ID=1111111111111111:FF=0:TM=1438245700:LM=1438245700:V=1:S=GAzVO0ALebSpC_cJ; + expires=Sat, 29-Jul-2017 08:41:40 GMT; path=/; domain=.google.com', 'NID=69=Br7oRAwgmKoK__HC6FEnuxglTFDmFxqP6Md63lKhzW1w6WkDbp3U90CDxnUKvDP6wJH8yxY5Lk5ZnFf66Q1B0d4OsYoKgq0vjfBAYXuCIAWtOuGZEOsFXanXs7pt2Mjx; + expires=Fri, 29-Jan-2016 08:41:40 GMT; path=/; domain=.google.com; HttpOnly'] + - !!python/tuple + - X-Frame-Options + - [SAMEORIGIN] + - !!python/tuple + - Server + - [gws] + status: {code: 200, message: OK} +version: 1 diff --git a/tests/unit/test_cassettes.py b/tests/unit/test_cassettes.py index e7d8b7c..00916fb 100644 --- a/tests/unit/test_cassettes.py +++ b/tests/unit/test_cassettes.py @@ -253,3 +253,37 @@ def test_func_path_generator(): def function_name(cassette): assert cassette._path == os.path.join(os.path.dirname(__file__), 'function_name') function_name() + + +def test_use_as_decorator_on_coroutine(): + original_http_connetion = httplib.HTTPConnection + @Cassette.use(inject=True) + def test_function(cassette): + assert httplib.HTTPConnection.cassette is cassette + assert httplib.HTTPConnection is not original_http_connetion + value = yield 1 + assert value == 1 + assert httplib.HTTPConnection.cassette is cassette + assert httplib.HTTPConnection is not original_http_connetion + value = yield 2 + assert value == 2 + coroutine = test_function() + value = coroutine.next() + while True: + try: + value = coroutine.send(value) + except StopIteration: + break + + +def test_use_as_decorator_on_generator(): + original_http_connetion = httplib.HTTPConnection + @Cassette.use(inject=True) + def test_function(cassette): + assert httplib.HTTPConnection.cassette is cassette + assert httplib.HTTPConnection is not original_http_connetion + yield 1 + assert httplib.HTTPConnection.cassette is cassette + assert httplib.HTTPConnection is not original_http_connetion + yield 2 + assert list(test_function()) == [1, 2] diff --git a/vcr/cassette.py b/vcr/cassette.py index 87d2598..1a290cd 100644 --- a/vcr/cassette.py +++ b/vcr/cassette.py @@ -1,11 +1,8 @@ -"""The container for recorded requests and responses""" -import functools +import inspect import logging - import wrapt -# Internal imports from .compat import contextlib, collections from .errors import UnhandledHTTPRequestError from .matchers import requests_match, uri, method @@ -50,14 +47,6 @@ class CassetteContextDecorator(object): # somewhere else. cassette._save() - @classmethod - def key_predicate(cls, key, value): - return key in cls._non_cassette_arguments - - @classmethod - def _split_keys(cls, kwargs): - return partition_dict(cls.key_predicate, kwargs) - def __enter__(self): # This assertion is here to prevent the dangerous behavior # that would result from forgetting about a __finish before @@ -68,7 +57,10 @@ class CassetteContextDecorator(object): # with context_decorator: # pass assert self.__finish is None, "Cassette already open." - other_kwargs, cassette_kwargs = self._split_keys(self._args_getter()) + other_kwargs, cassette_kwargs = partition_dict( + lambda key, _: key in self._non_cassette_arguments, + self._args_getter() + ) if 'path_transformer' in other_kwargs: transformer = other_kwargs['path_transformer'] cassette_kwargs['path'] = transformer(cassette_kwargs['path']) @@ -84,27 +76,48 @@ class CassetteContextDecorator(object): # This awkward cloning thing is done to ensure that decorated # functions are reentrant. This is required for thread # safety and the correct operation of recursive functions. - args_getter = self._build_args_getter_for_decorator( - function, self._args_getter - ) - clone = type(self)(self.cls, args_getter) - with clone as cassette: - if cassette.inject: - return function(cassette, *args, **kwargs) - else: - return function(*args, **kwargs) + args_getter = self._build_args_getter_for_decorator(function) + return type(self)(self.cls, args_getter)._execute_function(function, args, kwargs) + + def _execute_function(self, function, args, kwargs): + if inspect.isgeneratorfunction(function): + handler = self._handle_coroutine + else: + handler = self._handle_function + return handler(function, args, kwargs) + + def _handle_coroutine(self, function, args, kwargs): + with self as cassette: + coroutine = self.__handle_function(cassette, function, args, kwargs) + to_send = None + while True: + try: + to_yield = coroutine.send(to_send) + except StopIteration: + break + else: + to_send = yield to_yield + + def __handle_function(self, cassette, function, args, kwargs): + if cassette.inject: + return function(cassette, *args, **kwargs) + else: + return function(*args, **kwargs) + + def _handle_function(self, function, args, kwargs): + with self as cassette: + self.__handle_function(cassette, function, args, kwargs) @staticmethod def get_function_name(function): return function.__name__ - @classmethod - def _build_args_getter_for_decorator(cls, function, args_getter): + def _build_args_getter_for_decorator(self, function): def new_args_getter(): - kwargs = args_getter() + kwargs = self._args_getter() if 'path' not in kwargs: name_generator = (kwargs.get('func_path_generator') or - cls.get_function_name) + self.get_function_name) path = name_generator(function) kwargs['path'] = path return kwargs diff --git a/vcr/config.py b/vcr/config.py index 7655a3a..09fbdd6 100644 --- a/vcr/config.py +++ b/vcr/config.py @@ -23,7 +23,7 @@ class VCR(object): return path return ensure - def __init__(self, path_transformer=lambda x: x, before_record_request=None, + def __init__(self, path_transformer=None, before_record_request=None, custom_patches=(), filter_query_parameters=(), ignore_hosts=(), record_mode="once", ignore_localhost=False, filter_headers=(), before_record_response=None, filter_post_data_parameters=(), @@ -59,7 +59,7 @@ class VCR(object): self.ignore_hosts = ignore_hosts self.ignore_localhost = ignore_localhost self.inject_cassette = inject_cassette - self.path_transformer = path_transformer + self.path_transformer = path_transformer or self.ensure_suffix('.yaml') self.func_path_generator = func_path_generator self._custom_patches = tuple(custom_patches) @@ -107,7 +107,7 @@ class VCR(object): matcher_names = kwargs.get('match_on', self.match_on) path_transformer = kwargs.get( 'path_transformer', - self.path_transformer + self.path_transformer or self.ensure_suffix('.yaml') ) func_path_generator = kwargs.get( 'func_path_generator',