1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-08 16:53:23 +00:00
Files
vcrpy/vcr/patch.py

320 lines
13 KiB
Python

'''Utilities for patching in cassettes'''
import functools
import itertools
import contextlib2
import mock
from .stubs import VCRHTTPConnection, VCRHTTPSConnection
from six.moves import http_client as httplib
# Remember the untouched connection classes at import time so that they can
# be put back later when unpatching.
_HTTPConnection, _HTTPSConnection = (
    httplib.HTTPConnection,
    httplib.HTTPSConnection,
)

# requests: remember the classes exposed by its bundled urllib3 connection
# pool module, when requests can be imported.
try:
    import requests.packages.urllib3.connectionpool as cpool
except ImportError:  # pragma: no cover
    pass
else:
    _VerifiedHTTPSConnection = cpool.VerifiedHTTPSConnection
    _cpoolHTTPConnection = cpool.HTTPConnection
    _cpoolHTTPSConnection = cpool.HTTPSConnection

# urllib3: remember its verified HTTPS connection class, when importable.
# NOTE(review): this rebinds the same ``_VerifiedHTTPSConnection`` name that
# the requests branch above assigns, so when both imports succeed only the
# urllib3 class is kept -- confirm that this is intended.
try:
    import urllib3
except ImportError:  # pragma: no cover
    pass
else:
    _VerifiedHTTPSConnection = urllib3.connectionpool.VerifiedHTTPSConnection

# httplib2: remember both connection classes and the scheme lookup table.
try:
    import httplib2
except ImportError:  # pragma: no cover
    pass
else:
    _HTTPConnectionWithTimeout = httplib2.HTTPConnectionWithTimeout
    _HTTPSConnectionWithTimeout = httplib2.HTTPSConnectionWithTimeout
    _SCHEME_TO_CONNECTION = httplib2.SCHEME_TO_CONNECTION

# boto: remember its certificate-validating HTTPS connection class.
try:
    import boto.https_connection
except ImportError:  # pragma: no cover
    pass
else:
    _CertValidatingHTTPSConnection = boto.https_connection.CertValidatingHTTPSConnection
class CassettePatcherBuilder(object):
    """Build the patchers that route HTTP client libraries through a cassette.

    ``build()`` chains the patchers for httplib, requests, urllib3, httplib2
    and boto together with the cassette's own ``custom_patches``.  The
    branches for libraries that cannot be imported contribute no patchers.
    """

    def _build_patchers_from_mock_triples_decorator(function):
        """Wrap a method producing mock triples so that it returns patchers.

        This is applied as a decorator inside the class body only, which is
        why it takes the decorated function instead of ``self``.
        """
        @functools.wraps(function)
        def wrapped(self, *args, **kwargs):
            return self._build_patchers_from_mock_triples(
                function(self, *args, **kwargs)
            )
        return wrapped

    def __init__(self, cassette):
        """Store ``cassette`` and start with an empty subclass cache."""
        self._cassette = cassette
        # Maps a stub class to the cassette-bound subclass built for it, so
        # that every patch created by this builder shares one class object.
        self._class_to_cassette_subclass = {}

    def build(self):
        """Return an iterable over every patcher for this cassette."""
        return itertools.chain(
            self._httplib(), self._requests(), self._urllib3(), self._httplib2(),
            self._boto(), self._build_patchers_from_mock_triples(
                self._cassette.custom_patches
            )
        )

    def _build_patchers_from_mock_triples(self, mock_triples):
        """Yield a patcher for each ``(obj, attribute, replacement)`` triple.

        Triples whose target attribute does not exist are skipped.
        """
        for args in mock_triples:
            patcher = self._build_patcher(*args)
            if patcher:
                yield patcher

    def _build_patcher(self, obj, patched_attribute, replacement_class):
        """Return a ``mock.patch.object`` patcher, or ``None`` when ``obj``
        has no attribute named ``patched_attribute``.
        """
        if not hasattr(obj, patched_attribute):
            return
        return mock.patch.object(obj, patched_attribute,
                                 self._recursively_apply_get_cassette_subclass(
                                     replacement_class))

    def _recursively_apply_get_cassette_subclass(self, replacement_dict_or_obj):
        """Swap stub classes for their cassette-bound subclasses.

        One of the subtleties of this class is that it does not directly
        replace HTTPSConnection with VCRRequestsHTTPSConnection, but a
        subclass of this class that has cassette assigned to the
        appropriate value.  This behavior is necessary to properly
        support nested cassette contexts.

        This function exists to ensure that we use the same class
        object (reference) to patch everything that replaces
        VCRRequestHTTP[S]Connection, but that we can talk about
        patching them with the raw references instead.

        Dictionaries are processed recursively and a new dictionary is
        returned; the dictionary passed in is left untouched.  Objects
        without a ``cassette`` attribute are returned unchanged.
        """
        if isinstance(replacement_dict_or_obj, dict):
            # Fill a shallow copy instead of rewriting the caller's mapping.
            # ``_get_cassette_subclass`` returns classes whose ``cassette``
            # is already set as they are, so a mapping rewritten in place
            # would stay bound to this cassette if it were used again with
            # another builder.
            patched = replacement_dict_or_obj.copy()
            for key, replacement_obj in replacement_dict_or_obj.items():
                patched[key] = self._recursively_apply_get_cassette_subclass(
                    replacement_obj)
            return patched
        if hasattr(replacement_dict_or_obj, 'cassette'):
            replacement_dict_or_obj = self._get_cassette_subclass(
                replacement_dict_or_obj)
        return replacement_dict_or_obj

    def _get_cassette_subclass(self, klass):
        """Return the cached cassette-bound subclass of ``klass``.

        A class whose ``cassette`` attribute is already set is returned
        as is.
        """
        if klass.cassette is not None:
            return klass
        if klass not in self._class_to_cassette_subclass:
            subclass = self._build_cassette_subclass(klass)
            self._class_to_cassette_subclass[klass] = subclass
        return self._class_to_cassette_subclass[klass]

    def _build_cassette_subclass(self, base_class):
        """Create a subclass of ``base_class`` whose ``cassette`` class
        attribute is this builder's cassette.

        The subclass name is the base class name followed by the cassette
        path.
        """
        bases = (base_class,)
        if not issubclass(base_class, object):  # Check for old style class
            bases += (object,)
        return type('{0}{1}'.format(base_class.__name__, self._cassette._path),
                    bases, dict(cassette=self._cassette))

    @_build_patchers_from_mock_triples_decorator
    def _httplib(self):
        """Mock triples for the standard library HTTP client module."""
        yield httplib, 'HTTPConnection', VCRHTTPConnection
        yield httplib, 'HTTPSConnection', VCRHTTPSConnection

    def _requests(self):
        """Patchers for the urllib3 bundled with requests (none if missing)."""
        try:
            import requests.packages.urllib3.connectionpool as cpool
        except ImportError:  # pragma: no cover
            return ()
        from .stubs import requests_stubs
        return self._urllib3_patchers(cpool, requests_stubs)

    def _patched_get_conn(self, connection_pool_class, connection_class_getter):
        """Return a ``_get_conn`` replacement for ``connection_pool_class``.

        The replacement only returns connections that are instances of the
        pool's ``ConnectionCls`` or, when the pool has no such attribute, of
        the class returned by ``connection_class_getter()``.
        """
        get_conn = connection_pool_class._get_conn

        @functools.wraps(get_conn)
        def patched_get_conn(pool, timeout=None):
            connection = get_conn(pool, timeout)
            # The getter is only called when the pool has no ConnectionCls.
            connection_class = pool.ConnectionCls if hasattr(pool, 'ConnectionCls') \
                else connection_class_getter()
            # Keep asking the pool until it hands out a connection of the
            # expected class.
            while not isinstance(connection, connection_class):
                connection = get_conn(pool, timeout)
            return connection
        return patched_get_conn

    def _patched_new_conn(self, connection_pool_class, connection_remover):
        """Return a ``_new_conn`` replacement for ``connection_pool_class``
        that reports every new connection to ``connection_remover``.
        """
        new_conn = connection_pool_class._new_conn

        @functools.wraps(new_conn)
        def patched_new_conn(pool):
            new_connection = new_conn(pool)
            connection_remover.add_connection_to_pool_entry(pool, new_connection)
            return new_connection
        return patched_new_conn

    def _urllib3(self):
        """Patchers for a standalone urllib3 (none if it is missing)."""
        try:
            import urllib3.connectionpool as cpool
        except ImportError:  # pragma: no cover
            return ()
        from .stubs import urllib3_stubs
        return self._urllib3_patchers(cpool, urllib3_stubs)

    @_build_patchers_from_mock_triples_decorator
    def _httplib2(self):
        """Mock triples for httplib2 (none if it is missing)."""
        try:
            import httplib2 as cpool
        except ImportError:  # pragma: no cover
            pass
        else:
            from .stubs.httplib2_stubs import VCRHTTPConnectionWithTimeout
            from .stubs.httplib2_stubs import VCRHTTPSConnectionWithTimeout
            yield cpool, 'HTTPConnectionWithTimeout', VCRHTTPConnectionWithTimeout
            yield cpool, 'HTTPSConnectionWithTimeout', VCRHTTPSConnectionWithTimeout
            yield cpool, 'SCHEME_TO_CONNECTION', {'http': VCRHTTPConnectionWithTimeout,
                                                  'https': VCRHTTPSConnectionWithTimeout}

    @_build_patchers_from_mock_triples_decorator
    def _boto(self):
        """Mock triples for boto (none if it is missing)."""
        try:
            import boto.https_connection as cpool
        except ImportError:  # pragma: no cover
            pass
        else:
            from .stubs.boto_stubs import VCRCertValidatingHTTPSConnection
            yield cpool, 'CertValidatingHTTPSConnection', VCRCertValidatingHTTPSConnection

    def _urllib3_patchers(self, cpool, stubs):
        """Return the patchers for a urllib3-style connection pool module.

        ``cpool`` is the connection pool module to patch and ``stubs`` the
        module providing ``VCRRequestsHTTPConnection`` and
        ``VCRRequestsHTTPSConnection``.  The two ``ConnectionRemover``
        objects are appended to the returned iterable after the patchers.
        """
        http_connection_remover = ConnectionRemover(
            self._get_cassette_subclass(stubs.VCRRequestsHTTPConnection)
        )
        https_connection_remover = ConnectionRemover(
            self._get_cassette_subclass(stubs.VCRRequestsHTTPSConnection)
        )
        # Each target attribute is listed exactly once.
        mock_triples = (
            (cpool, 'VerifiedHTTPSConnection', stubs.VCRRequestsHTTPSConnection),
            (cpool, 'HTTPConnection', stubs.VCRRequestsHTTPConnection),
            (cpool, 'HTTPSConnection', stubs.VCRRequestsHTTPSConnection),
            (cpool, 'is_connection_dropped', mock.Mock(return_value=False)),  # Needed on Windows only
            (cpool.HTTPConnectionPool, 'ConnectionCls', stubs.VCRRequestsHTTPConnection),
            (cpool.HTTPSConnectionPool, 'ConnectionCls', stubs.VCRRequestsHTTPSConnection),
        )
        # These handle making sure that sessions only use the
        # connections of the appropriate type.
        mock_triples += ((cpool.HTTPConnectionPool, '_get_conn',
                          self._patched_get_conn(cpool.HTTPConnectionPool,
                                                 lambda: cpool.HTTPConnection)),
                         (cpool.HTTPSConnectionPool, '_get_conn',
                          self._patched_get_conn(cpool.HTTPSConnectionPool,
                                                 lambda: cpool.HTTPSConnection)),
                         (cpool.HTTPConnectionPool, '_new_conn',
                          self._patched_new_conn(cpool.HTTPConnectionPool,
                                                 http_connection_remover)),
                         (cpool.HTTPSConnectionPool, '_new_conn',
                          self._patched_new_conn(cpool.HTTPSConnectionPool,
                                                 https_connection_remover)))
        return itertools.chain(self._build_patchers_from_mock_triples(mock_triples),
                               (http_connection_remover, https_connection_remover))
class ConnectionRemover(object):
    """Context manager that purges tracked connections from their pools.

    Connections that are instances of ``connection_class`` are recorded per
    pool through ``add_connection_to_pool_entry``.  When the context exits,
    the queue of every recorded pool is drained for as long as tracked
    connections remain: instances of ``connection_class`` are dropped and
    every other connection is handed back through ``pool._put_conn``.
    """

    def __init__(self, connection_class):
        """Track only connections that are instances of ``connection_class``."""
        self._connection_class = connection_class
        # Maps a pool to the set of tracked connections created for it.
        self._connection_pool_to_connections = {}

    def add_connection_to_pool_entry(self, pool, connection):
        """Record ``connection`` under ``pool`` if it has the tracked class."""
        if isinstance(connection, self._connection_class):
            self._connection_pool_to_connections.setdefault(pool, set()).add(connection)

    def remove_connection_to_pool_entry(self, pool, connection):
        """Stop tracking ``connection`` for ``pool``.

        Connections that are not instances of the tracked class are ignored.
        Raises ``KeyError`` when ``pool`` or ``connection`` is not tracked.
        """
        if isinstance(connection, self._connection_class):
            # The mapping is keyed by pool (see add_connection_to_pool_entry),
            # not by the connection class.
            self._connection_pool_to_connections[pool].remove(connection)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        for pool, connections in self._connection_pool_to_connections.items():
            readd_connections = []
            while pool.pool and not pool.pool.empty() and connections:
                connection = pool.pool.get()
                if isinstance(connection, self._connection_class):
                    # discard() rather than remove(): an instance that is no
                    # longer tracked must still be dropped from the pool
                    # without raising in the middle of unpatching.
                    connections.discard(connection)
                else:
                    readd_connections.append(connection)
            # Give back the connections that are not of the tracked class.
            for connection in readd_connections:
                pool._put_conn(connection)
def reset_patchers():
    """Yield patchers that reinstate the connection classes saved at import.

    The patchers are produced in a fixed order: httplib, requests' bundled
    urllib3, standalone urllib3, httplib2 and boto.  The branch of a library
    that cannot be imported yields nothing.

    NOTE(review): the standalone urllib3 branch reinstates the httplib
    originals (``_HTTPConnection`` / ``_HTTPSConnection``) rather than
    classes saved from urllib3 itself -- confirm that this is intended.
    """
    def restore(target, attribute, original):
        # Thin alias so that each line below reads as "restore X.attr".
        return mock.patch.object(target, attribute, original)

    # Standard library
    yield restore(httplib, 'HTTPConnection', _HTTPConnection)
    yield restore(httplib, 'HTTPSConnection', _HTTPSConnection)

    # requests
    try:
        import requests.packages.urllib3.connectionpool as requests_pool
    except ImportError:  # pragma: no cover
        pass
    else:
        # unpatch requests v1.x
        yield restore(requests_pool, 'VerifiedHTTPSConnection', _VerifiedHTTPSConnection)
        yield restore(requests_pool, 'HTTPConnection', _cpoolHTTPConnection)
        # unpatch requests v2.x
        if hasattr(requests_pool.HTTPConnectionPool, 'ConnectionCls'):
            yield restore(requests_pool.HTTPConnectionPool, 'ConnectionCls',
                          _cpoolHTTPConnection)
            yield restore(requests_pool.HTTPSConnectionPool, 'ConnectionCls',
                          _cpoolHTTPSConnection)
        if hasattr(requests_pool, 'HTTPSConnection'):
            yield restore(requests_pool, 'HTTPSConnection', _cpoolHTTPSConnection)

    # urllib3
    try:
        import urllib3.connectionpool as urllib3_pool
    except ImportError:  # pragma: no cover
        pass
    else:
        yield restore(urllib3_pool, 'VerifiedHTTPSConnection', _VerifiedHTTPSConnection)
        yield restore(urllib3_pool, 'HTTPConnection', _HTTPConnection)
        yield restore(urllib3_pool, 'HTTPSConnection', _HTTPSConnection)
        if hasattr(urllib3_pool.HTTPConnectionPool, 'ConnectionCls'):
            yield restore(urllib3_pool.HTTPConnectionPool, 'ConnectionCls', _HTTPConnection)
            yield restore(urllib3_pool.HTTPSConnectionPool, 'ConnectionCls', _HTTPSConnection)

    # httplib2
    try:
        import httplib2 as httplib2_module
    except ImportError:  # pragma: no cover
        pass
    else:
        yield restore(httplib2_module, 'HTTPConnectionWithTimeout',
                      _HTTPConnectionWithTimeout)
        yield restore(httplib2_module, 'HTTPSConnectionWithTimeout',
                      _HTTPSConnectionWithTimeout)
        yield restore(httplib2_module, 'SCHEME_TO_CONNECTION', _SCHEME_TO_CONNECTION)

    # boto
    try:
        import boto.https_connection as boto_https
    except ImportError:  # pragma: no cover
        pass
    else:
        yield restore(boto_https, 'CertValidatingHTTPSConnection',
                      _CertValidatingHTTPSConnection)
@contextlib2.contextmanager
def force_reset():
    """Context manager that applies every patcher from ``reset_patchers()``.

    All patchers are entered on a single exit stack before control is handed
    to the ``with`` body, and the stack is unwound when the body finishes.
    """
    stack = contextlib2.ExitStack()
    with stack:
        for reset_patcher in reset_patchers():
            stack.enter_context(reset_patcher)
        yield