mirror of
https://github.com/kevin1024/vcrpy.git
synced 2025-12-09 01:03:24 +00:00
Fix some of the issues from #109
This commit is contained in:
@@ -8,7 +8,7 @@ except ImportError:
|
|||||||
from .compat.counter import Counter
|
from .compat.counter import Counter
|
||||||
|
|
||||||
# Internal imports
|
# Internal imports
|
||||||
from .patch import PatcherBuilder
|
from .patch import CassettePatcherBuilder
|
||||||
from .persist import load_cassette, save_cassette
|
from .persist import load_cassette, save_cassette
|
||||||
from .filters import filter_request
|
from .filters import filter_request
|
||||||
from .serializers import yamlserializer
|
from .serializers import yamlserializer
|
||||||
@@ -40,7 +40,7 @@ class CassetteContextDecorator(contextlib2.ContextDecorator):
|
|||||||
|
|
||||||
def _patch_generator(self, cassette):
|
def _patch_generator(self, cassette):
|
||||||
with contextlib2.ExitStack() as exit_stack:
|
with contextlib2.ExitStack() as exit_stack:
|
||||||
for patcher in PatcherBuilder(cassette).build_patchers():
|
for patcher in CassettePatcherBuilder(cassette).build():
|
||||||
exit_stack.enter_context(patcher)
|
exit_stack.enter_context(patcher)
|
||||||
log.debug('Entered context for cassette at {0}.'.format(cassette._path))
|
log.debug('Entered context for cassette at {0}.'.format(cassette._path))
|
||||||
yield cassette
|
yield cassette
|
||||||
|
|||||||
129
vcr/patch.py
129
vcr/patch.py
@@ -1,4 +1,5 @@
|
|||||||
'''Utilities for patching in cassettes'''
|
'''Utilities for patching in cassettes'''
|
||||||
|
import functools
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
import contextlib2
|
import contextlib2
|
||||||
@@ -53,16 +54,25 @@ else:
|
|||||||
_CertValidatingHTTPSConnection = boto.https_connection.CertValidatingHTTPSConnection
|
_CertValidatingHTTPSConnection = boto.https_connection.CertValidatingHTTPSConnection
|
||||||
|
|
||||||
|
|
||||||
class PatcherBuilder(object):
|
class CassettePatcherBuilder(object):
|
||||||
|
|
||||||
|
def _build_patchers_from_mock_triples_decorator(function):
|
||||||
|
@functools.wraps(function)
|
||||||
|
def wrapped(self, *args, **kwargs):
|
||||||
|
return self._build_patchers_from_mock_triples(function(self, *args, **kwargs))
|
||||||
|
return wrapped
|
||||||
|
|
||||||
def __init__(self, cassette):
|
def __init__(self, cassette):
|
||||||
self._cassette = cassette
|
self._cassette = cassette
|
||||||
self._class_to_cassette_subclass = {}
|
self._class_to_cassette_subclass = {}
|
||||||
|
|
||||||
def build_patchers(self):
|
def build(self):
|
||||||
patcher_args = itertools.chain(self._httplib(), self._requests(), self._urllib3(),
|
return itertools.chain(self._httplib(), self._requests(),
|
||||||
self._httplib2(), self._boto())
|
self._urllib3(), self._httplib2(),
|
||||||
for args in patcher_args:
|
self._boto())
|
||||||
|
|
||||||
|
def _build_patchers_from_mock_triples(self, mock_triples):
|
||||||
|
for args in mock_triples:
|
||||||
patcher = self._build_patcher(*args)
|
patcher = self._build_patcher(*args)
|
||||||
if patcher:
|
if patcher:
|
||||||
yield patcher
|
yield patcher
|
||||||
@@ -71,18 +81,28 @@ class PatcherBuilder(object):
|
|||||||
if not hasattr(obj, patched_attribute):
|
if not hasattr(obj, patched_attribute):
|
||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(replacement_class, dict):
|
return mock.patch.object(obj, patched_attribute,
|
||||||
for key in replacement_class:
|
self._recursively_apply_get_cassette_subclass(
|
||||||
replacement_class[key] = self._get_cassette_subclass(replacement_class[key])
|
replacement_class))
|
||||||
else:
|
|
||||||
replacement_class = self._get_cassette_subclass(replacement_class)
|
def _recursively_apply_get_cassette_subclass(self, replacement_dict_or_obj):
|
||||||
return mock.patch.object(obj, patched_attribute, replacement_class)
|
if isinstance(replacement_dict_or_obj, dict):
|
||||||
|
for key, replacement_obj in replacement_dict_or_obj:
|
||||||
|
replacement_obj = self._recursively_apply_get_cassette_subclass(
|
||||||
|
replacement_obj)
|
||||||
|
replacement_dict_or_obj[key] = replacement_obj
|
||||||
|
return replacement_dict_or_obj
|
||||||
|
if hasattr(replacement_dict_or_obj, 'cassette'):
|
||||||
|
replacement_dict_or_obj = self._get_cassette_subclass(
|
||||||
|
replacement_dict_or_obj)
|
||||||
|
return replacement_dict_or_obj
|
||||||
|
|
||||||
def _get_cassette_subclass(self, klass):
|
def _get_cassette_subclass(self, klass):
|
||||||
if klass.cassette is not None:
|
if klass.cassette is not None:
|
||||||
return klass
|
return klass
|
||||||
if klass not in self._class_to_cassette_subclass:
|
if klass not in self._class_to_cassette_subclass:
|
||||||
self._class_to_cassette_subclass[klass] = self._build_cassette_subclass(klass)
|
subclass = self._build_cassette_subclass(klass)
|
||||||
|
self._class_to_cassette_subclass[klass] = subclass
|
||||||
return self._class_to_cassette_subclass[klass]
|
return self._class_to_cassette_subclass[klass]
|
||||||
|
|
||||||
def _build_cassette_subclass(self, base_class):
|
def _build_cassette_subclass(self, base_class):
|
||||||
@@ -92,6 +112,7 @@ class PatcherBuilder(object):
|
|||||||
return type('{0}{1}'.format(base_class.__name__, self._cassette._path),
|
return type('{0}{1}'.format(base_class.__name__, self._cassette._path),
|
||||||
bases, dict(cassette=self._cassette))
|
bases, dict(cassette=self._cassette))
|
||||||
|
|
||||||
|
@_build_patchers_from_mock_triples_decorator
|
||||||
def _httplib(self):
|
def _httplib(self):
|
||||||
yield httplib, 'HTTPConnection', VCRHTTPConnection
|
yield httplib, 'HTTPConnection', VCRHTTPConnection
|
||||||
yield httplib, 'HTTPSConnection', VCRHTTPSConnection
|
yield httplib, 'HTTPSConnection', VCRHTTPSConnection
|
||||||
@@ -100,17 +121,51 @@ class PatcherBuilder(object):
|
|||||||
try:
|
try:
|
||||||
import requests.packages.urllib3.connectionpool as cpool
|
import requests.packages.urllib3.connectionpool as cpool
|
||||||
except ImportError: # pragma: no cover
|
except ImportError: # pragma: no cover
|
||||||
pass
|
return
|
||||||
else:
|
from .stubs.requests_stubs import VCRRequestsHTTPConnection, VCRRequestsHTTPSConnection
|
||||||
from .stubs.requests_stubs import VCRRequestsHTTPConnection, VCRRequestsHTTPSConnection
|
http_connection_remover = ConnectionRemover(
|
||||||
|
self._get_cassette_subclass(VCRHTTPConnection)
|
||||||
|
)
|
||||||
|
https_connection_remover = ConnectionRemover(
|
||||||
|
self._get_cassette_subclass(VCRHTTPSConnection)
|
||||||
|
)
|
||||||
|
mock_triples = (
|
||||||
|
(cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection),
|
||||||
|
(cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection),
|
||||||
|
(cpool, 'HTTPConnection', VCRRequestsHTTPConnection),
|
||||||
|
(cpool, 'HTTPConnection', VCRHTTPConnection),
|
||||||
|
(cpool.HTTPConnectionPool, 'ConnectionCls', VCRRequestsHTTPConnection),
|
||||||
|
(cpool.HTTPSConnectionPool, 'ConnectionCls', VCRRequestsHTTPSConnection),
|
||||||
|
# These handle making sure that sessions only use the
|
||||||
|
# connections of the appropriate type.
|
||||||
|
(cpool.HTTPConnectionPool, '_get_conn', self._patched_get_conn(cpool.HTTPConnectionPool)),
|
||||||
|
(cpool.HTTPSConnectionPool, '_get_conn', self._patched_get_conn(cpool.HTTPSConnectionPool)),
|
||||||
|
(cpool.HTTPConnectionPool, '_new_conn', self._patched_new_conn(cpool.HTTPConnectionPool, http_connection_remover)),
|
||||||
|
(cpool.HTTPSConnectionPool, '_new_conn', self._patched_new_conn(cpool.HTTPConnectionPool, https_connection_remover))
|
||||||
|
)
|
||||||
|
return itertools.chain(self._build_patchers_from_mock_triples(mock_triples),
|
||||||
|
(http_connection_remover, https_connection_remover))
|
||||||
|
|
||||||
yield cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection
|
def _patched_get_conn(self, connection_pool_class):
|
||||||
yield cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection
|
get_conn = connection_pool_class._get_conn
|
||||||
yield cpool, 'HTTPConnection', VCRRequestsHTTPConnection
|
@functools.wraps(get_conn)
|
||||||
yield cpool, 'HTTPConnection', VCRHTTPConnection
|
def patched_get_conn(pool, timeout=None):
|
||||||
yield cpool.HTTPConnectionPool, 'ConnectionCls', VCRRequestsHTTPConnection
|
connection = get_conn(pool, timeout)
|
||||||
yield cpool.HTTPSConnectionPool, 'ConnectionCls', VCRRequestsHTTPSConnection
|
while not isinstance(connection, pool.ConnectionCls):
|
||||||
|
connection = get_conn(pool, timeout)
|
||||||
|
return connection
|
||||||
|
return patched_get_conn
|
||||||
|
|
||||||
|
def _patched_new_conn(self, connection_pool_class, connection_remover):
|
||||||
|
new_conn = connection_pool_class._new_conn
|
||||||
|
@functools.wraps(new_conn)
|
||||||
|
def patched_new_conn(pool):
|
||||||
|
new_connection = new_conn(pool)
|
||||||
|
connection_remover.add_connection_to_pool_entry(pool, new_connection)
|
||||||
|
return new_connection
|
||||||
|
return patched_new_conn
|
||||||
|
|
||||||
|
@_build_patchers_from_mock_triples_decorator
|
||||||
def _urllib3(self):
|
def _urllib3(self):
|
||||||
try:
|
try:
|
||||||
import urllib3.connectionpool as cpool
|
import urllib3.connectionpool as cpool
|
||||||
@@ -122,6 +177,7 @@ class PatcherBuilder(object):
|
|||||||
yield cpool, 'VerifiedHTTPSConnection', VCRVerifiedHTTPSConnection
|
yield cpool, 'VerifiedHTTPSConnection', VCRVerifiedHTTPSConnection
|
||||||
yield cpool, 'HTTPConnection', VCRHTTPConnection
|
yield cpool, 'HTTPConnection', VCRHTTPConnection
|
||||||
|
|
||||||
|
@_build_patchers_from_mock_triples_decorator
|
||||||
def _httplib2(self):
|
def _httplib2(self):
|
||||||
try:
|
try:
|
||||||
import httplib2 as cpool
|
import httplib2 as cpool
|
||||||
@@ -136,6 +192,7 @@ class PatcherBuilder(object):
|
|||||||
yield cpool, 'SCHEME_TO_CONNECTION', {'http': VCRHTTPConnectionWithTimeout,
|
yield cpool, 'SCHEME_TO_CONNECTION', {'http': VCRHTTPConnectionWithTimeout,
|
||||||
'https': VCRHTTPSConnectionWithTimeout}
|
'https': VCRHTTPSConnectionWithTimeout}
|
||||||
|
|
||||||
|
@_build_patchers_from_mock_triples_decorator
|
||||||
def _boto(self):
|
def _boto(self):
|
||||||
try:
|
try:
|
||||||
import boto.https_connection as cpool
|
import boto.https_connection as cpool
|
||||||
@@ -146,6 +203,36 @@ class PatcherBuilder(object):
|
|||||||
yield cpool, 'CertValidatingHTTPSConnection', VCRCertValidatingHTTPSConnection
|
yield cpool, 'CertValidatingHTTPSConnection', VCRCertValidatingHTTPSConnection
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionRemover(object):
|
||||||
|
|
||||||
|
def __init__(self, connection_class):
|
||||||
|
self._connection_class = connection_class
|
||||||
|
self._connection_pool_to_connections = {}
|
||||||
|
|
||||||
|
def add_connection_to_pool_entry(self, pool, connection):
|
||||||
|
if isinstance(connection, self._connection_class):
|
||||||
|
self._connection_pool_to_connection.setdefault(pool, set()).add(connection)
|
||||||
|
|
||||||
|
def remove_connection_to_pool_entry(self, pool, connection):
|
||||||
|
if isinstance(connection, self._connection_class):
|
||||||
|
self._connection_pool_to_connection[self._connection_class].remove(connection)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
for pool, connections in self._connection_pool_to_connections.items():
|
||||||
|
readd_connections = []
|
||||||
|
while pool.not_empty() and connections:
|
||||||
|
connection = pool.get()
|
||||||
|
if isinstance(connection, self._connection_class):
|
||||||
|
connections.remove(connection)
|
||||||
|
else:
|
||||||
|
readd_connections.append(connection)
|
||||||
|
for connection in readd_connections:
|
||||||
|
self.pool._put_conn(connection)
|
||||||
|
|
||||||
|
|
||||||
def reset_patchers():
|
def reset_patchers():
|
||||||
yield mock.patch.object(httplib, 'HTTPConnection', _HTTPConnection)
|
yield mock.patch.object(httplib, 'HTTPConnection', _HTTPConnection)
|
||||||
yield mock.patch.object(httplib, 'HTTPSConnection', _HTTPSConnection)
|
yield mock.patch.object(httplib, 'HTTPSConnection', _HTTPSConnection)
|
||||||
|
|||||||
Reference in New Issue
Block a user