1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-08 16:53:23 +00:00

Fix some of the issues from #109

This commit is contained in:
Ivan Malison
2014-09-19 17:06:53 -07:00
parent 1018867838
commit b1cdd50e9b
2 changed files with 110 additions and 23 deletions

View File

@@ -8,7 +8,7 @@ except ImportError:
from .compat.counter import Counter
# Internal imports
from .patch import PatcherBuilder
from .patch import CassettePatcherBuilder
from .persist import load_cassette, save_cassette
from .filters import filter_request
from .serializers import yamlserializer
@@ -40,7 +40,7 @@ class CassetteContextDecorator(contextlib2.ContextDecorator):
def _patch_generator(self, cassette):
with contextlib2.ExitStack() as exit_stack:
for patcher in PatcherBuilder(cassette).build_patchers():
for patcher in CassettePatcherBuilder(cassette).build():
exit_stack.enter_context(patcher)
log.debug('Entered context for cassette at {0}.'.format(cassette._path))
yield cassette

View File

@@ -1,4 +1,5 @@
'''Utilities for patching in cassettes'''
import functools
import itertools
import contextlib2
@@ -53,16 +54,25 @@ else:
_CertValidatingHTTPSConnection = boto.https_connection.CertValidatingHTTPSConnection
class PatcherBuilder(object):
class CassettePatcherBuilder(object):
def _build_patchers_from_mock_triples_decorator(function):
@functools.wraps(function)
def wrapped(self, *args, **kwargs):
return self._build_patchers_from_mock_triples(function(self, *args, **kwargs))
return wrapped
def __init__(self, cassette):
self._cassette = cassette
self._class_to_cassette_subclass = {}
def build_patchers(self):
patcher_args = itertools.chain(self._httplib(), self._requests(), self._urllib3(),
self._httplib2(), self._boto())
for args in patcher_args:
def build(self):
return itertools.chain(self._httplib(), self._requests(),
self._urllib3(), self._httplib2(),
self._boto())
def _build_patchers_from_mock_triples(self, mock_triples):
for args in mock_triples:
patcher = self._build_patcher(*args)
if patcher:
yield patcher
@@ -71,18 +81,28 @@ class PatcherBuilder(object):
if not hasattr(obj, patched_attribute):
return
if isinstance(replacement_class, dict):
for key in replacement_class:
replacement_class[key] = self._get_cassette_subclass(replacement_class[key])
else:
replacement_class = self._get_cassette_subclass(replacement_class)
return mock.patch.object(obj, patched_attribute, replacement_class)
return mock.patch.object(obj, patched_attribute,
self._recursively_apply_get_cassette_subclass(
replacement_class))
def _recursively_apply_get_cassette_subclass(self, replacement_dict_or_obj):
if isinstance(replacement_dict_or_obj, dict):
for key, replacement_obj in replacement_dict_or_obj:
replacement_obj = self._recursively_apply_get_cassette_subclass(
replacement_obj)
replacement_dict_or_obj[key] = replacement_obj
return replacement_dict_or_obj
if hasattr(replacement_dict_or_obj, 'cassette'):
replacement_dict_or_obj = self._get_cassette_subclass(
replacement_dict_or_obj)
return replacement_dict_or_obj
def _get_cassette_subclass(self, klass):
if klass.cassette is not None:
return klass
if klass not in self._class_to_cassette_subclass:
self._class_to_cassette_subclass[klass] = self._build_cassette_subclass(klass)
subclass = self._build_cassette_subclass(klass)
self._class_to_cassette_subclass[klass] = subclass
return self._class_to_cassette_subclass[klass]
def _build_cassette_subclass(self, base_class):
@@ -92,6 +112,7 @@ class PatcherBuilder(object):
return type('{0}{1}'.format(base_class.__name__, self._cassette._path),
bases, dict(cassette=self._cassette))
@_build_patchers_from_mock_triples_decorator
def _httplib(self):
yield httplib, 'HTTPConnection', VCRHTTPConnection
yield httplib, 'HTTPSConnection', VCRHTTPSConnection
@@ -100,17 +121,51 @@ class PatcherBuilder(object):
try:
import requests.packages.urllib3.connectionpool as cpool
except ImportError: # pragma: no cover
pass
else:
from .stubs.requests_stubs import VCRRequestsHTTPConnection, VCRRequestsHTTPSConnection
return
from .stubs.requests_stubs import VCRRequestsHTTPConnection, VCRRequestsHTTPSConnection
http_connection_remover = ConnectionRemover(
self._get_cassette_subclass(VCRHTTPConnection)
)
https_connection_remover = ConnectionRemover(
self._get_cassette_subclass(VCRHTTPSConnection)
)
mock_triples = (
(cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection),
(cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection),
(cpool, 'HTTPConnection', VCRRequestsHTTPConnection),
(cpool, 'HTTPConnection', VCRHTTPConnection),
(cpool.HTTPConnectionPool, 'ConnectionCls', VCRRequestsHTTPConnection),
(cpool.HTTPSConnectionPool, 'ConnectionCls', VCRRequestsHTTPSConnection),
# These handle making sure that sessions only use the
# connections of the appropriate type.
(cpool.HTTPConnectionPool, '_get_conn', self._patched_get_conn(cpool.HTTPConnectionPool)),
(cpool.HTTPSConnectionPool, '_get_conn', self._patched_get_conn(cpool.HTTPSConnectionPool)),
(cpool.HTTPConnectionPool, '_new_conn', self._patched_new_conn(cpool.HTTPConnectionPool, http_connection_remover)),
(cpool.HTTPSConnectionPool, '_new_conn', self._patched_new_conn(cpool.HTTPConnectionPool, https_connection_remover))
)
return itertools.chain(self._build_patchers_from_mock_triples(mock_triples),
(http_connection_remover, https_connection_remover))
yield cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection
yield cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection
yield cpool, 'HTTPConnection', VCRRequestsHTTPConnection
yield cpool, 'HTTPConnection', VCRHTTPConnection
yield cpool.HTTPConnectionPool, 'ConnectionCls', VCRRequestsHTTPConnection
yield cpool.HTTPSConnectionPool, 'ConnectionCls', VCRRequestsHTTPSConnection
def _patched_get_conn(self, connection_pool_class):
get_conn = connection_pool_class._get_conn
@functools.wraps(get_conn)
def patched_get_conn(pool, timeout=None):
connection = get_conn(pool, timeout)
while not isinstance(connection, pool.ConnectionCls):
connection = get_conn(pool, timeout)
return connection
return patched_get_conn
def _patched_new_conn(self, connection_pool_class, connection_remover):
new_conn = connection_pool_class._new_conn
@functools.wraps(new_conn)
def patched_new_conn(pool):
new_connection = new_conn(pool)
connection_remover.add_connection_to_pool_entry(pool, new_connection)
return new_connection
return patched_new_conn
@_build_patchers_from_mock_triples_decorator
def _urllib3(self):
try:
import urllib3.connectionpool as cpool
@@ -122,6 +177,7 @@ class PatcherBuilder(object):
yield cpool, 'VerifiedHTTPSConnection', VCRVerifiedHTTPSConnection
yield cpool, 'HTTPConnection', VCRHTTPConnection
@_build_patchers_from_mock_triples_decorator
def _httplib2(self):
try:
import httplib2 as cpool
@@ -136,6 +192,7 @@ class PatcherBuilder(object):
yield cpool, 'SCHEME_TO_CONNECTION', {'http': VCRHTTPConnectionWithTimeout,
'https': VCRHTTPSConnectionWithTimeout}
@_build_patchers_from_mock_triples_decorator
def _boto(self):
try:
import boto.https_connection as cpool
@@ -146,6 +203,36 @@ class PatcherBuilder(object):
yield cpool, 'CertValidatingHTTPSConnection', VCRCertValidatingHTTPSConnection
class ConnectionRemover(object):
def __init__(self, connection_class):
self._connection_class = connection_class
self._connection_pool_to_connections = {}
def add_connection_to_pool_entry(self, pool, connection):
if isinstance(connection, self._connection_class):
self._connection_pool_to_connection.setdefault(pool, set()).add(connection)
def remove_connection_to_pool_entry(self, pool, connection):
if isinstance(connection, self._connection_class):
self._connection_pool_to_connection[self._connection_class].remove(connection)
def __enter__(self):
return self
def __exit__(self, *args):
for pool, connections in self._connection_pool_to_connections.items():
readd_connections = []
while pool.not_empty() and connections:
connection = pool.get()
if isinstance(connection, self._connection_class):
connections.remove(connection)
else:
readd_connections.append(connection)
for connection in readd_connections:
self.pool._put_conn(connection)
def reset_patchers():
yield mock.patch.object(httplib, 'HTTPConnection', _HTTPConnection)
yield mock.patch.object(httplib, 'HTTPSConnection', _HTTPSConnection)