From 757ad9c8362aa233ce753de604218cba5902feea Mon Sep 17 00:00:00 2001 From: Ivan Malison Date: Sat, 20 Sep 2014 11:59:25 -0700 Subject: [PATCH] Revert "Remove ConnectionRemover class that tried to get rid of vcr connections in ConnectionPools." This reverts commit dc249b09656acb27651e04071592e276ecf4dca8. Conflicts: vcr/patch.py --- vcr/patch.py | 61 ++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 57 insertions(+), 4 deletions(-) diff --git a/vcr/patch.py b/vcr/patch.py index 74bb355..1a09801 100644 --- a/vcr/patch.py +++ b/vcr/patch.py @@ -123,6 +123,12 @@ class CassettePatcherBuilder(object): except ImportError: # pragma: no cover return () from .stubs.requests_stubs import VCRRequestsHTTPConnection, VCRRequestsHTTPSConnection + http_connection_remover = ConnectionRemover( + self._get_cassette_subclass(VCRHTTPConnection) + ) + https_connection_remover = ConnectionRemover( + self._get_cassette_subclass(VCRHTTPSConnection) + ) mock_triples = ( (cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection), (cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection), @@ -130,16 +136,24 @@ class CassettePatcherBuilder(object): (cpool, 'HTTPSConnection', VCRRequestsHTTPSConnection), (cpool.HTTPConnectionPool, 'ConnectionCls', VCRRequestsHTTPConnection), (cpool.HTTPSConnectionPool, 'ConnectionCls', VCRRequestsHTTPSConnection), - # These handle making sure that sessions only use the - # connections of the appropriate type. ) + # These handle making sure that sessions only use the + # connections of the appropriate type. mock_triples += ((cpool.HTTPConnectionPool, '_get_conn', self._patched_get_conn(cpool.HTTPConnectionPool, lambda : cpool.HTTPConnection)), (cpool.HTTPSConnectionPool, '_get_conn', self._patched_get_conn(cpool.HTTPSConnectionPool, - lambda : cpool.HTTPSConnection))) - return self._build_patchers_from_mock_triples(mock_triples) + lambda : cpool.HTTPSConnection)), + (cpool.HTTPConnectionPool, '_new_conn', + self._patched_new_conn(cpool.HTTPConnectionPool, + http_connection_remover)), + (cpool.HTTPSConnectionPool, '_new_conn', + self._patched_new_conn(cpool.HTTPConnectionPool, + https_connection_remover))) + + return itertools.chain(self._build_patchers_from_mock_triples(mock_triples), + (http_connection_remover, https_connection_remover)) def _patched_get_conn(self, connection_pool_class, connection_class_getter): get_conn = connection_pool_class._get_conn @@ -153,6 +167,15 @@ class CassettePatcherBuilder(object): return connection return patched_get_conn + def _patched_new_conn(self, connection_pool_class, connection_remover): + new_conn = connection_pool_class._new_conn + @functools.wraps(new_conn) + def patched_new_conn(pool): + new_connection = new_conn(pool) + connection_remover.add_connection_to_pool_entry(pool, new_connection) + return new_connection + return patched_new_conn + @_build_patchers_from_mock_triples_decorator def _urllib3(self): try: @@ -191,6 +214,36 @@ class CassettePatcherBuilder(object): yield cpool, 'CertValidatingHTTPSConnection', VCRCertValidatingHTTPSConnection +class ConnectionRemover(object): + + def __init__(self, connection_class): + self._connection_class = connection_class + self._connection_pool_to_connections = {} + + def add_connection_to_pool_entry(self, pool, connection): + if isinstance(connection, self._connection_class): + self._connection_pool_to_connection.setdefault(pool, set()).add(connection) + + def remove_connection_to_pool_entry(self, pool, connection): + if isinstance(connection, self._connection_class): + self._connection_pool_to_connection[self._connection_class].remove(connection) + + def __enter__(self): + return self + + def __exit__(self, *args): + for pool, connections in self._connection_pool_to_connections.items(): + readd_connections = [] + while pool.not_empty() and connections: + connection = pool.get() + if isinstance(connection, self._connection_class): + connections.remove(connection) + else: + readd_connections.append(connection) + for connection in readd_connections: + self.pool._put_conn(connection) + + def reset_patchers(): yield mock.patch.object(httplib, 'HTTPConnection', _HTTPConnection) yield mock.patch.object(httplib, 'HTTPSConnection', _HTTPSConnection)