diff --git a/tests/unit/test_vcr.py b/tests/unit/test_vcr.py index cbdcde4..a17535f 100644 --- a/tests/unit/test_vcr.py +++ b/tests/unit/test_vcr.py @@ -1,11 +1,13 @@ import os import pytest +from six.moves import http_client as httplib from vcr import VCR, use_cassette from vcr.compat import mock from vcr.request import Request from vcr.stubs import VCRHTTPSConnection +from vcr.patch import _HTTPConnection, force_reset def test_vcr_use_cassette(): @@ -243,6 +245,7 @@ def test_path_transformer(): def test_cassette_name_generator_defaults_to_using_module_function_defined_in(): vcr = VCR(inject_cassette=True) + @vcr.use_cassette def function_name(cassette): assert cassette._path == os.path.join(os.path.dirname(__file__), @@ -274,3 +277,19 @@ def test_additional_matchers(): function_defaults() function_additional() + + +class TestVCRClass(VCR().test_case()): + + def no_decoration(self): + assert httplib.HTTPConnection == _HTTPConnection + + def test_one(self): + with force_reset(): + self.no_decoration() + with force_reset(): + self.test_two() + assert httplib.HTTPConnection != _HTTPConnection + + def test_two(self): + assert httplib.HTTPConnection != _HTTPConnection diff --git a/vcr/config.py b/vcr/config.py index e2389db..455d074 100644 --- a/vcr/config.py +++ b/vcr/config.py @@ -2,19 +2,25 @@ import copy import functools import inspect import os +import types import six from .compat import collections from .cassette import Cassette from .serializers import yamlserializer, jsonserializer -from .util import compose +from .util import compose, auto_decorate from . import matchers from . import filters class VCR(object): + @staticmethod + def is_test_method(method_name, function): + return method_name.startswith('test') and \ + isinstance(function, types.FunctionType) + @staticmethod def ensure_suffix(suffix): def ensure(path): @@ -202,7 +208,7 @@ class VCR(object): if filter_query_parameters: filter_functions.append(functools.partial( filters.remove_query_parameters, - query_parameters_to_remove=filter_query_parameters + query_parameters_to_remove=filter_query_parameters )) if filter_post_data_parameters: filter_functions.append( @@ -250,3 +256,7 @@ class VCR(object): def register_matcher(self, name, matcher): self.matchers[name] = matcher + + def test_case(self, predicate=None): + predicate = predicate or self.is_test_method + return six.with_metaclass(auto_decorate(self.use_cassette, predicate)) diff --git a/vcr/util.py b/vcr/util.py index 8c5bd94..9a44f9b 100644 --- a/vcr/util.py +++ b/vcr/util.py @@ -1,4 +1,6 @@ import collections +import types + # Shamelessly stolen from https://github.com/kennethreitz/requests/blob/master/requests/structures.py class CaseInsensitiveDict(collections.MutableMapping): @@ -90,3 +92,19 @@ def read_body(request): if hasattr(request.body, 'read'): return request.body.read() return request.body + + +def auto_decorate( + decorator, + predicate=lambda name, value: isinstance(value, types.FunctionType) +): + class DecorateAll(type): + + def __new__(cls, name, bases, attributes_dict): + for attribute, value in attributes_dict.items(): + if predicate(attribute, value): + attributes_dict[attribute] = decorator(value) + return super(DecorateAll, cls).__new__( + cls, name, bases, attributes_dict + ) + return DecorateAll