From b9bdc6401dedd471eb7908afc8edadf849f6aed4 Mon Sep 17 00:00:00 2001 From: Ivan Malison Date: Wed, 1 Apr 2015 17:06:02 -0700 Subject: [PATCH] inject_cassette kwarg. --- tests/unit/test_vcr.py | 23 +++++++++++++++++++---- vcr/cassette.py | 11 ++++++++--- vcr/config.py | 15 +++++++++------ 3 files changed, 36 insertions(+), 13 deletions(-) diff --git a/tests/unit/test_vcr.py b/tests/unit/test_vcr.py index 0c1c6c7..d85b079 100644 --- a/tests/unit/test_vcr.py +++ b/tests/unit/test_vcr.py @@ -70,10 +70,11 @@ def test_fixtures_with_use_cassette(random_fixture): # problems if the decorator does not preserve the signature of the original # test function. - # This test ensures that use_cassette preserves the signature of the original - # test function, and thus that use_cassette is compatible with py.test - # fixtures. It is admittedly a bit strange because the test would never even - # run if the relevant feature were broken. + # This test ensures that use_cassette preserves the signature of + # the original test function, and thus that use_cassette is + # compatible with py.test fixtures. It is admittedly a bit strange + # because the test would never even run if the relevant feature + # were broken. pass @@ -90,3 +91,17 @@ def test_custom_patchers(): assert issubclass(Test.attribute, VCRHTTPSConnection) assert VCRHTTPSConnection is not Test.attribute assert Test.attribute is Test.attribute2 + + +def test_inject_cassette(): + vcr = VCR(inject_cassette=True) + @vcr.use_cassette('test', record_mode='once') + def with_cassette_injected(cassette): + assert cassette.record_mode == 'once' + + @vcr.use_cassette('test', record_mode='once', inject_cassette=False) + def without_cassette_injected(): + pass + + with_cassette_injected() + without_cassette_injected() diff --git a/vcr/cassette.py b/vcr/cassette.py index 5851280..1194fcb 100644 --- a/vcr/cassette.py +++ b/vcr/cassette.py @@ -61,8 +61,11 @@ class CassetteContextDecorator(object): @wrapt.decorator def __call__(self, function, instance, args, kwargs): - with self: - return function(*args, **kwargs) + with self as cassette: + if cassette.inject: + return function(cassette, *args, **kwargs) + else: + return function(*args, **kwargs) class Cassette(object): @@ -85,13 +88,15 @@ class Cassette(object): def __init__(self, path, serializer=yamlserializer, record_mode='once', match_on=(uri, method), before_record_request=None, - before_record_response=None, custom_patches=()): + before_record_response=None, custom_patches=(), + inject=False): self._path = path self._serializer = serializer self._match_on = match_on self._before_record_request = before_record_request or (lambda x: x) self._before_record_response = before_record_response or (lambda x: x) + self.inject = inject self.record_mode = record_mode self.custom_patches = custom_patches diff --git a/vcr/config.py b/vcr/config.py index dcb59a5..6a66356 100644 --- a/vcr/config.py +++ b/vcr/config.py @@ -12,11 +12,12 @@ from . import filters class VCR(object): def __init__(self, serializer='yaml', cassette_library_dir=None, - record_mode="once", filter_headers=(), custom_patches=(), - filter_query_parameters=(), filter_post_data_parameters=(), - before_record_request=None, before_record_response=None, - ignore_hosts=(), ignore_localhost=False, before_record=None, - match_on=('method', 'scheme', 'host', 'port', 'path', 'query')): + record_mode="once", filter_headers=(), ignore_localhost=False, + custom_patches=(), filter_query_parameters=(), + filter_post_data_parameters=(), before_record_request=None, + before_record_response=None, ignore_hosts=(), + match_on=('method', 'scheme', 'host', 'port', 'path', 'query'), + before_record=None, inject_cassette=False): self.serializer = serializer self.match_on = match_on self.cassette_library_dir = cassette_library_dir @@ -44,6 +45,7 @@ class VCR(object): self.before_record_response = before_record_response self.ignore_hosts = ignore_hosts self.ignore_localhost = ignore_localhost + self.inject_cassette = inject_cassette self._custom_patches = tuple(custom_patches) def _get_serializer(self, serializer_name): @@ -99,7 +101,8 @@ class VCR(object): ), 'custom_patches': self._custom_patches + kwargs.get( 'custom_patches', () - ) + ), + 'inject': kwargs.get('inject_cassette', self.inject_cassette) } return path, merged_config