1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-09 01:03:24 +00:00

Fix some of the issues from #109

This commit is contained in:
Ivan Malison
2014-09-19 17:06:53 -07:00
parent 1018867838
commit b1cdd50e9b
2 changed files with 110 additions and 23 deletions

View File

@@ -8,7 +8,7 @@ except ImportError:
from .compat.counter import Counter from .compat.counter import Counter
# Internal imports # Internal imports
from .patch import PatcherBuilder from .patch import CassettePatcherBuilder
from .persist import load_cassette, save_cassette from .persist import load_cassette, save_cassette
from .filters import filter_request from .filters import filter_request
from .serializers import yamlserializer from .serializers import yamlserializer
@@ -40,7 +40,7 @@ class CassetteContextDecorator(contextlib2.ContextDecorator):
def _patch_generator(self, cassette): def _patch_generator(self, cassette):
with contextlib2.ExitStack() as exit_stack: with contextlib2.ExitStack() as exit_stack:
for patcher in PatcherBuilder(cassette).build_patchers(): for patcher in CassettePatcherBuilder(cassette).build():
exit_stack.enter_context(patcher) exit_stack.enter_context(patcher)
log.debug('Entered context for cassette at {0}.'.format(cassette._path)) log.debug('Entered context for cassette at {0}.'.format(cassette._path))
yield cassette yield cassette

View File

@@ -1,4 +1,5 @@
'''Utilities for patching in cassettes''' '''Utilities for patching in cassettes'''
import functools
import itertools import itertools
import contextlib2 import contextlib2
@@ -53,16 +54,25 @@ else:
_CertValidatingHTTPSConnection = boto.https_connection.CertValidatingHTTPSConnection _CertValidatingHTTPSConnection = boto.https_connection.CertValidatingHTTPSConnection
class PatcherBuilder(object): class CassettePatcherBuilder(object):
def _build_patchers_from_mock_triples_decorator(function):
@functools.wraps(function)
def wrapped(self, *args, **kwargs):
return self._build_patchers_from_mock_triples(function(self, *args, **kwargs))
return wrapped
def __init__(self, cassette): def __init__(self, cassette):
self._cassette = cassette self._cassette = cassette
self._class_to_cassette_subclass = {} self._class_to_cassette_subclass = {}
def build_patchers(self): def build(self):
patcher_args = itertools.chain(self._httplib(), self._requests(), self._urllib3(), return itertools.chain(self._httplib(), self._requests(),
self._httplib2(), self._boto()) self._urllib3(), self._httplib2(),
for args in patcher_args: self._boto())
def _build_patchers_from_mock_triples(self, mock_triples):
for args in mock_triples:
patcher = self._build_patcher(*args) patcher = self._build_patcher(*args)
if patcher: if patcher:
yield patcher yield patcher
@@ -71,18 +81,28 @@ class PatcherBuilder(object):
if not hasattr(obj, patched_attribute): if not hasattr(obj, patched_attribute):
return return
if isinstance(replacement_class, dict): return mock.patch.object(obj, patched_attribute,
for key in replacement_class: self._recursively_apply_get_cassette_subclass(
replacement_class[key] = self._get_cassette_subclass(replacement_class[key]) replacement_class))
else:
replacement_class = self._get_cassette_subclass(replacement_class) def _recursively_apply_get_cassette_subclass(self, replacement_dict_or_obj):
return mock.patch.object(obj, patched_attribute, replacement_class) if isinstance(replacement_dict_or_obj, dict):
for key, replacement_obj in replacement_dict_or_obj:
replacement_obj = self._recursively_apply_get_cassette_subclass(
replacement_obj)
replacement_dict_or_obj[key] = replacement_obj
return replacement_dict_or_obj
if hasattr(replacement_dict_or_obj, 'cassette'):
replacement_dict_or_obj = self._get_cassette_subclass(
replacement_dict_or_obj)
return replacement_dict_or_obj
def _get_cassette_subclass(self, klass): def _get_cassette_subclass(self, klass):
if klass.cassette is not None: if klass.cassette is not None:
return klass return klass
if klass not in self._class_to_cassette_subclass: if klass not in self._class_to_cassette_subclass:
self._class_to_cassette_subclass[klass] = self._build_cassette_subclass(klass) subclass = self._build_cassette_subclass(klass)
self._class_to_cassette_subclass[klass] = subclass
return self._class_to_cassette_subclass[klass] return self._class_to_cassette_subclass[klass]
def _build_cassette_subclass(self, base_class): def _build_cassette_subclass(self, base_class):
@@ -92,6 +112,7 @@ class PatcherBuilder(object):
return type('{0}{1}'.format(base_class.__name__, self._cassette._path), return type('{0}{1}'.format(base_class.__name__, self._cassette._path),
bases, dict(cassette=self._cassette)) bases, dict(cassette=self._cassette))
@_build_patchers_from_mock_triples_decorator
def _httplib(self): def _httplib(self):
yield httplib, 'HTTPConnection', VCRHTTPConnection yield httplib, 'HTTPConnection', VCRHTTPConnection
yield httplib, 'HTTPSConnection', VCRHTTPSConnection yield httplib, 'HTTPSConnection', VCRHTTPSConnection
@@ -100,17 +121,51 @@ class PatcherBuilder(object):
try: try:
import requests.packages.urllib3.connectionpool as cpool import requests.packages.urllib3.connectionpool as cpool
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
pass return
else: from .stubs.requests_stubs import VCRRequestsHTTPConnection, VCRRequestsHTTPSConnection
from .stubs.requests_stubs import VCRRequestsHTTPConnection, VCRRequestsHTTPSConnection http_connection_remover = ConnectionRemover(
self._get_cassette_subclass(VCRHTTPConnection)
)
https_connection_remover = ConnectionRemover(
self._get_cassette_subclass(VCRHTTPSConnection)
)
mock_triples = (
(cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection),
(cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection),
(cpool, 'HTTPConnection', VCRRequestsHTTPConnection),
(cpool, 'HTTPConnection', VCRHTTPConnection),
(cpool.HTTPConnectionPool, 'ConnectionCls', VCRRequestsHTTPConnection),
(cpool.HTTPSConnectionPool, 'ConnectionCls', VCRRequestsHTTPSConnection),
# These handle making sure that sessions only use the
# connections of the appropriate type.
(cpool.HTTPConnectionPool, '_get_conn', self._patched_get_conn(cpool.HTTPConnectionPool)),
(cpool.HTTPSConnectionPool, '_get_conn', self._patched_get_conn(cpool.HTTPSConnectionPool)),
(cpool.HTTPConnectionPool, '_new_conn', self._patched_new_conn(cpool.HTTPConnectionPool, http_connection_remover)),
(cpool.HTTPSConnectionPool, '_new_conn', self._patched_new_conn(cpool.HTTPConnectionPool, https_connection_remover))
)
return itertools.chain(self._build_patchers_from_mock_triples(mock_triples),
(http_connection_remover, https_connection_remover))
yield cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection def _patched_get_conn(self, connection_pool_class):
yield cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection get_conn = connection_pool_class._get_conn
yield cpool, 'HTTPConnection', VCRRequestsHTTPConnection @functools.wraps(get_conn)
yield cpool, 'HTTPConnection', VCRHTTPConnection def patched_get_conn(pool, timeout=None):
yield cpool.HTTPConnectionPool, 'ConnectionCls', VCRRequestsHTTPConnection connection = get_conn(pool, timeout)
yield cpool.HTTPSConnectionPool, 'ConnectionCls', VCRRequestsHTTPSConnection while not isinstance(connection, pool.ConnectionCls):
connection = get_conn(pool, timeout)
return connection
return patched_get_conn
def _patched_new_conn(self, connection_pool_class, connection_remover):
new_conn = connection_pool_class._new_conn
@functools.wraps(new_conn)
def patched_new_conn(pool):
new_connection = new_conn(pool)
connection_remover.add_connection_to_pool_entry(pool, new_connection)
return new_connection
return patched_new_conn
@_build_patchers_from_mock_triples_decorator
def _urllib3(self): def _urllib3(self):
try: try:
import urllib3.connectionpool as cpool import urllib3.connectionpool as cpool
@@ -122,6 +177,7 @@ class PatcherBuilder(object):
yield cpool, 'VerifiedHTTPSConnection', VCRVerifiedHTTPSConnection yield cpool, 'VerifiedHTTPSConnection', VCRVerifiedHTTPSConnection
yield cpool, 'HTTPConnection', VCRHTTPConnection yield cpool, 'HTTPConnection', VCRHTTPConnection
@_build_patchers_from_mock_triples_decorator
def _httplib2(self): def _httplib2(self):
try: try:
import httplib2 as cpool import httplib2 as cpool
@@ -136,6 +192,7 @@ class PatcherBuilder(object):
yield cpool, 'SCHEME_TO_CONNECTION', {'http': VCRHTTPConnectionWithTimeout, yield cpool, 'SCHEME_TO_CONNECTION', {'http': VCRHTTPConnectionWithTimeout,
'https': VCRHTTPSConnectionWithTimeout} 'https': VCRHTTPSConnectionWithTimeout}
@_build_patchers_from_mock_triples_decorator
def _boto(self): def _boto(self):
try: try:
import boto.https_connection as cpool import boto.https_connection as cpool
@@ -146,6 +203,36 @@ class PatcherBuilder(object):
yield cpool, 'CertValidatingHTTPSConnection', VCRCertValidatingHTTPSConnection yield cpool, 'CertValidatingHTTPSConnection', VCRCertValidatingHTTPSConnection
class ConnectionRemover(object):
def __init__(self, connection_class):
self._connection_class = connection_class
self._connection_pool_to_connections = {}
def add_connection_to_pool_entry(self, pool, connection):
if isinstance(connection, self._connection_class):
self._connection_pool_to_connection.setdefault(pool, set()).add(connection)
def remove_connection_to_pool_entry(self, pool, connection):
if isinstance(connection, self._connection_class):
self._connection_pool_to_connection[self._connection_class].remove(connection)
def __enter__(self):
return self
def __exit__(self, *args):
for pool, connections in self._connection_pool_to_connections.items():
readd_connections = []
while pool.not_empty() and connections:
connection = pool.get()
if isinstance(connection, self._connection_class):
connections.remove(connection)
else:
readd_connections.append(connection)
for connection in readd_connections:
self.pool._put_conn(connection)
def reset_patchers(): def reset_patchers():
yield mock.patch.object(httplib, 'HTTPConnection', _HTTPConnection) yield mock.patch.object(httplib, 'HTTPConnection', _HTTPConnection)
yield mock.patch.object(httplib, 'HTTPSConnection', _HTTPSConnection) yield mock.patch.object(httplib, 'HTTPSConnection', _HTTPSConnection)