diff --git a/vcr/cassette.py b/vcr/cassette.py index e5e58ad..f03c01b 100644 --- a/vcr/cassette.py +++ b/vcr/cassette.py @@ -7,7 +7,7 @@ except ImportError: from .compat.counter import Counter # Internal imports -from .patch import build_patchers +from .patch import PatcherBuilder from .persist import load_cassette, save_cassette from .filters import filter_request from .serializers import yamlserializer @@ -36,7 +36,7 @@ class CassetteContextDecorator(object): def _patch_generator(self, cassette): with contextlib2.ExitStack() as exit_stack: - for patcher in build_patchers(cassette): + for patcher in PatcherBuilder(cassette).build_patchers(): exit_stack.enter_context(patcher) yield cassette # TODO(@IvanMalison): Hmmm. it kind of feels like this should be somewhere else. diff --git a/vcr/patch.py b/vcr/patch.py index 35f9b23..e38cea0 100644 --- a/vcr/patch.py +++ b/vcr/patch.py @@ -1,4 +1,6 @@ '''Utilities for patching in cassettes''' +import itertools + import contextlib2 import mock @@ -51,89 +53,96 @@ else: _CertValidatingHTTPSConnection = boto.https_connection.CertValidatingHTTPSConnection +class PatcherBuilder(object): -def cassette_subclass(base_class, cassette): - bases = (base_class,) - if not issubclass(base_class, object): # Check for old style class - bases += (object,) - return type('{0}{1}'.format(base_class.__name__, cassette._path), bases, dict(cassette=cassette)) + def __init__(self, cassette): + self._cassette = cassette + self._class_to_cassette_subclass = {} + def build_patchers(self): + patcher_args = itertools.chain(self._httplib(), self._requests(), self._urllib3(), + self._httplib2(), self._boto()) + for args in patcher_args: + patcher = self._build_patcher(*args) + if patcher: + yield patcher -def build_patchers(cassette): - """ - Build patches for all the HTTPConnections references we can find! - This replaces the actual HTTPConnection with a VCRHTTPConnection - object which knows how to save to / read from cassettes - """ - _VCRHTTPConnection = cassette_subclass(VCRHTTPConnection, cassette) - _VCRHTTPSConnection = cassette_subclass(VCRHTTPSConnection, cassette) + def _build_patcher(self, obj, patched_attribute, replacement_class): + if not hasattr(obj, patched_attribute): + return + if isinstance(replacement_class, dict): + for key in replacement_class: + replacement_class[key] = self._get_cassette_subclass(replacement_class[key]) + else: + replacement_class = self._get_cassette_subclass(replacement_class) + return mock.patch.object(obj, patched_attribute, replacement_class) - yield mock.patch.object(httplib, 'HTTPConnection', _VCRHTTPConnection) - yield mock.patch.object(httplib, 'HTTPSConnection', _VCRHTTPSConnection) + def _get_cassette_subclass(self, klass): + if klass not in self._class_to_cassette_subclass: + self._class_to_cassette_subclass[klass] = self._cassette_subclass(klass) + return self._class_to_cassette_subclass[klass] - # requests - try: - import requests.packages.urllib3.connectionpool as cpool - except ImportError: # pragma: no cover - pass - else: - from .stubs.requests_stubs import VCRRequestsHTTPConnection, VCRRequestsHTTPSConnection + def _cassette_subclass(self, base_class): + bases = (base_class,) + if not issubclass(base_class, object): # Check for old style class + bases += (object,) + return type('{0}{1}'.format(base_class.__name__, self._cassette._path), + bases, dict(cassette=self._cassette)) - # patch requests v1.x - yield mock.patch.object(cpool, 'VerifiedHTTPSConnection', - cassette_subclass(VCRRequestsHTTPSConnection, cassette)) - yield mock.patch.object(cpool, 'HTTPConnection', - cassette_subclass(VCRRequestsHTTPConnection, cassette)) - yield mock.patch.object(cpool, 'HTTPConnection', _VCRHTTPConnection) + def _httplib(self): + yield httplib, 'HTTPConnection', VCRHTTPConnection + yield httplib, 'HTTPSConnection', VCRHTTPSConnection - # patch requests v2.x - if hasattr(cpool.HTTPConnectionPool, 'ConnectionCls'): - yield mock.patch.object(cpool.HTTPConnectionPool, 'ConnectionCls', - cassette_subclass(VCRRequestsHTTPConnection, cassette)) - yield mock.patch.object(cpool.HTTPSConnectionPool, 'ConnectionCls', - cassette_subclass(VCRRequestsHTTPSConnection, cassette)) + def _requests(self): + try: + import requests.packages.urllib3.connectionpool as cpool + except ImportError: # pragma: no cover + pass + else: + from .stubs.requests_stubs import VCRRequestsHTTPConnection, VCRRequestsHTTPSConnection - # patch urllib3 - try: - import urllib3.connectionpool as cpool - except ImportError: # pragma: no cover - pass - else: - from .stubs.urllib3_stubs import VCRVerifiedHTTPSConnection - yield mock.patch.object(cpool, 'VerifiedHTTPSConnection', - cassette_subclass(VCRVerifiedHTTPSConnection, cassette)) - yield mock.patch.object(cpool, 'HTTPConnection', _VCRHTTPConnection) + yield cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection + yield cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection + yield cpool, 'HTTPConnection', VCRRequestsHTTPConnection + yield cpool, 'HTTPConnection', VCRHTTPConnection + yield cpool.HTTPConnectionPool, 'ConnectionCls', VCRRequestsHTTPConnection + yield cpool.HTTPSConnectionPool, 'ConnectionCls', VCRRequestsHTTPSConnection - # patch httplib2 - try: - import httplib2 as cpool - except ImportError: # pragma: no cover - pass - else: - from .stubs.httplib2_stubs import VCRHTTPConnectionWithTimeout - from .stubs.httplib2_stubs import VCRHTTPSConnectionWithTimeout - _VCRHTTPConnectionWithTimeout = cassette_subclass(VCRHTTPConnectionWithTimeout, - cassette) - _VCRHTTPSConnectionWithTimeout = cassette_subclass(VCRHTTPSConnectionWithTimeout, - cassette) - yield mock.patch.object(cpool, 'HTTPConnectionWithTimeout', - _VCRHTTPConnectionWithTimeout) - yield mock.patch.object(cpool, 'HTTPSConnectionWithTimeout', - _VCRHTTPSConnectionWithTimeout) - yield mock.patch.object(cpool, 'SCHEME_TO_CONNECTION', - {'http': _VCRHTTPConnectionWithTimeout, - 'https': _VCRHTTPSConnectionWithTimeout}) + def _urllib3(self): + try: + import urllib3.connectionpool as cpool + except ImportError: # pragma: no cover + pass + else: + from .stubs.urllib3_stubs import VCRVerifiedHTTPSConnection + + yield cpool, 'VerifiedHTTPSConnection', VCRVerifiedHTTPSConnection + yield cpool, 'HTTPConnection', VCRHTTPConnection + + def _httplib2(self): + try: + import httplib2 as cpool + except ImportError: # pragma: no cover + pass + else: + from .stubs.httplib2_stubs import VCRHTTPConnectionWithTimeout + from .stubs.httplib2_stubs import VCRHTTPSConnectionWithTimeout + + yield cpool, 'HTTPConnectionWithTimeout', VCRHTTPConnectionWithTimeout + yield cpool, 'HTTPSConnectionWithTimeout', VCRHTTPSConnectionWithTimeout + yield cpool, 'SCHEME_TO_CONNECTION', {'http': VCRHTTPConnectionWithTimeout, + 'https': VCRHTTPSConnectionWithTimeout} + + def _boto(self): + try: + import boto.https_connection as cpool + except ImportError: # pragma: no cover + pass + else: + from .stubs.boto_stubs import VCRCertValidatingHTTPSConnection + yield cpool, 'CertValidatingHTTPSConnection', VCRCertValidatingHTTPSConnection - # patch boto - try: - import boto.https_connection as cpool - except ImportError: # pragma: no cover - pass - else: - from .stubs.boto_stubs import VCRCertValidatingHTTPSConnection - yield mock.patch.object(cpool, 'CertValidatingHTTPSConnection', - cassette_subclass(VCRCertValidatingHTTPSConnection, cassette)) def reset_patchers(): @@ -148,11 +157,14 @@ def reset_patchers(): yield mock.patch.object(cpool, 'VerifiedHTTPSConnection', _VerifiedHTTPSConnection) yield mock.patch.object(cpool, 'HTTPConnection', _cpoolHTTPConnection) # unpatch requests v2.x - yield mock.patch.object(cpool.HTTPConnectionPool, 'ConnectionCls', - _cpoolHTTPConnection) - yield mock.patch.object(cpool, 'HTTPSConnection', _cpoolHTTPSConnection) - yield mock.patch.object(cpool.HTTPSConnectionPool, 'ConnectionCls', - _cpoolHTTPSConnection) + if hasattr(cpool.HTTPConnectionPool, 'ConnectionCls'): + yield mock.patch.object(cpool.HTTPConnectionPool, 'ConnectionCls', + _cpoolHTTPConnection) + yield mock.patch.object(cpool.HTTPSConnectionPool, 'ConnectionCls', + _cpoolHTTPSConnection) + + if hasattr(cpool, 'HTTPSConnection'): + yield mock.patch.object(cpool, 'HTTPSConnection', _cpoolHTTPSConnection) try: import urllib3.connectionpool as cpool