diff --git a/vcr/cassette.py b/vcr/cassette.py index ddbb0f1..d5c33b8 100644 --- a/vcr/cassette.py +++ b/vcr/cassette.py @@ -8,7 +8,7 @@ except ImportError: from .compat.counter import Counter # Internal imports -from .patch import PatcherBuilder +from .patch import CassettePatcherBuilder from .persist import load_cassette, save_cassette from .filters import filter_request from .serializers import yamlserializer @@ -40,7 +40,7 @@ class CassetteContextDecorator(contextlib2.ContextDecorator): def _patch_generator(self, cassette): with contextlib2.ExitStack() as exit_stack: - for patcher in PatcherBuilder(cassette).build_patchers(): + for patcher in CassettePatcherBuilder(cassette).build(): exit_stack.enter_context(patcher) log.debug('Entered context for cassette at {0}.'.format(cassette._path)) yield cassette diff --git a/vcr/patch.py b/vcr/patch.py index dc4e72e..d033b55 100644 --- a/vcr/patch.py +++ b/vcr/patch.py @@ -1,4 +1,5 @@ '''Utilities for patching in cassettes''' +import functools import itertools import contextlib2 @@ -53,16 +54,25 @@ else: _CertValidatingHTTPSConnection = boto.https_connection.CertValidatingHTTPSConnection -class PatcherBuilder(object): +class CassettePatcherBuilder(object): + + def _build_patchers_from_mock_triples_decorator(function): + @functools.wraps(function) + def wrapped(self, *args, **kwargs): + return self._build_patchers_from_mock_triples(function(self, *args, **kwargs)) + return wrapped def __init__(self, cassette): self._cassette = cassette self._class_to_cassette_subclass = {} - def build_patchers(self): - patcher_args = itertools.chain(self._httplib(), self._requests(), self._urllib3(), - self._httplib2(), self._boto()) - for args in patcher_args: + def build(self): + return itertools.chain(self._httplib(), self._requests(), + self._urllib3(), self._httplib2(), + self._boto()) + + def _build_patchers_from_mock_triples(self, mock_triples): + for args in mock_triples: patcher = self._build_patcher(*args) if patcher: yield patcher @@ -71,18 +81,28 @@ class PatcherBuilder(object): if not hasattr(obj, patched_attribute): return - if isinstance(replacement_class, dict): - for key in replacement_class: - replacement_class[key] = self._get_cassette_subclass(replacement_class[key]) - else: - replacement_class = self._get_cassette_subclass(replacement_class) - return mock.patch.object(obj, patched_attribute, replacement_class) + return mock.patch.object(obj, patched_attribute, + self._recursively_apply_get_cassette_subclass( + replacement_class)) + + def _recursively_apply_get_cassette_subclass(self, replacement_dict_or_obj): + if isinstance(replacement_dict_or_obj, dict): + for key, replacement_obj in replacement_dict_or_obj: + replacement_obj = self._recursively_apply_get_cassette_subclass( + replacement_obj) + replacement_dict_or_obj[key] = replacement_obj + return replacement_dict_or_obj + if hasattr(replacement_dict_or_obj, 'cassette'): + replacement_dict_or_obj = self._get_cassette_subclass( + replacement_dict_or_obj) + return replacement_dict_or_obj def _get_cassette_subclass(self, klass): if klass.cassette is not None: return klass if klass not in self._class_to_cassette_subclass: - self._class_to_cassette_subclass[klass] = self._build_cassette_subclass(klass) + subclass = self._build_cassette_subclass(klass) + self._class_to_cassette_subclass[klass] = subclass return self._class_to_cassette_subclass[klass] def _build_cassette_subclass(self, base_class): @@ -92,6 +112,7 @@ class PatcherBuilder(object): return type('{0}{1}'.format(base_class.__name__, self._cassette._path), bases, dict(cassette=self._cassette)) + @_build_patchers_from_mock_triples_decorator def _httplib(self): yield httplib, 'HTTPConnection', VCRHTTPConnection yield httplib, 'HTTPSConnection', VCRHTTPSConnection @@ -100,17 +121,51 @@ class PatcherBuilder(object): try: import requests.packages.urllib3.connectionpool as cpool except ImportError: # pragma: no cover - pass - else: - from .stubs.requests_stubs import VCRRequestsHTTPConnection, VCRRequestsHTTPSConnection + return + from .stubs.requests_stubs import VCRRequestsHTTPConnection, VCRRequestsHTTPSConnection + http_connection_remover = ConnectionRemover( + self._get_cassette_subclass(VCRHTTPConnection) + ) + https_connection_remover = ConnectionRemover( + self._get_cassette_subclass(VCRHTTPSConnection) + ) + mock_triples = ( + (cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection), + (cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection), + (cpool, 'HTTPConnection', VCRRequestsHTTPConnection), + (cpool, 'HTTPConnection', VCRHTTPConnection), + (cpool.HTTPConnectionPool, 'ConnectionCls', VCRRequestsHTTPConnection), + (cpool.HTTPSConnectionPool, 'ConnectionCls', VCRRequestsHTTPSConnection), + # These handle making sure that sessions only use the + # connections of the appropriate type. + (cpool.HTTPConnectionPool, '_get_conn', self._patched_get_conn(cpool.HTTPConnectionPool)), + (cpool.HTTPSConnectionPool, '_get_conn', self._patched_get_conn(cpool.HTTPSConnectionPool)), + (cpool.HTTPConnectionPool, '_new_conn', self._patched_new_conn(cpool.HTTPConnectionPool, http_connection_remover)), + (cpool.HTTPSConnectionPool, '_new_conn', self._patched_new_conn(cpool.HTTPConnectionPool, https_connection_remover)) + ) + return itertools.chain(self._build_patchers_from_mock_triples(mock_triples), + (http_connection_remover, https_connection_remover)) - yield cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection - yield cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection - yield cpool, 'HTTPConnection', VCRRequestsHTTPConnection - yield cpool, 'HTTPConnection', VCRHTTPConnection - yield cpool.HTTPConnectionPool, 'ConnectionCls', VCRRequestsHTTPConnection - yield cpool.HTTPSConnectionPool, 'ConnectionCls', VCRRequestsHTTPSConnection + def _patched_get_conn(self, connection_pool_class): + get_conn = connection_pool_class._get_conn + @functools.wraps(get_conn) + def patched_get_conn(pool, timeout=None): + connection = get_conn(pool, timeout) + while not isinstance(connection, pool.ConnectionCls): + connection = get_conn(pool, timeout) + return connection + return patched_get_conn + def _patched_new_conn(self, connection_pool_class, connection_remover): + new_conn = connection_pool_class._new_conn + @functools.wraps(new_conn) + def patched_new_conn(pool): + new_connection = new_conn(pool) + connection_remover.add_connection_to_pool_entry(pool, new_connection) + return new_connection + return patched_new_conn + + @_build_patchers_from_mock_triples_decorator def _urllib3(self): try: import urllib3.connectionpool as cpool @@ -122,6 +177,7 @@ class PatcherBuilder(object): yield cpool, 'VerifiedHTTPSConnection', VCRVerifiedHTTPSConnection yield cpool, 'HTTPConnection', VCRHTTPConnection + @_build_patchers_from_mock_triples_decorator def _httplib2(self): try: import httplib2 as cpool @@ -136,6 +192,7 @@ class PatcherBuilder(object): yield cpool, 'SCHEME_TO_CONNECTION', {'http': VCRHTTPConnectionWithTimeout, 'https': VCRHTTPSConnectionWithTimeout} + @_build_patchers_from_mock_triples_decorator def _boto(self): try: import boto.https_connection as cpool @@ -146,6 +203,36 @@ class PatcherBuilder(object): yield cpool, 'CertValidatingHTTPSConnection', VCRCertValidatingHTTPSConnection +class ConnectionRemover(object): + + def __init__(self, connection_class): + self._connection_class = connection_class + self._connection_pool_to_connections = {} + + def add_connection_to_pool_entry(self, pool, connection): + if isinstance(connection, self._connection_class): + self._connection_pool_to_connection.setdefault(pool, set()).add(connection) + + def remove_connection_to_pool_entry(self, pool, connection): + if isinstance(connection, self._connection_class): + self._connection_pool_to_connection[self._connection_class].remove(connection) + + def __enter__(self): + return self + + def __exit__(self, *args): + for pool, connections in self._connection_pool_to_connections.items(): + readd_connections = [] + while pool.not_empty() and connections: + connection = pool.get() + if isinstance(connection, self._connection_class): + connections.remove(connection) + else: + readd_connections.append(connection) + for connection in readd_connections: + self.pool._put_conn(connection) + + def reset_patchers(): yield mock.patch.object(httplib, 'HTTPConnection', _HTTPConnection) yield mock.patch.object(httplib, 'HTTPSConnection', _HTTPSConnection)