1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-09 17:15:35 +00:00

Format project with black (#467)

Format with line length 110 to match flake8

make black part of linting check

Update travis spec for updated black requirements

Add diff output for black on failure

update changelog
This commit is contained in:
Josh Peak
2019-08-24 11:36:35 +10:00
committed by GitHub
parent 75969de601
commit 7caf29735a
70 changed files with 2000 additions and 2217 deletions

View File

@@ -7,14 +7,16 @@ from .config import VCR
try: # Python 2.7+
from logging import NullHandler
except ImportError:
class NullHandler(logging.Handler):
def emit(self, record):
pass
if sys.version_info[0] == 2:
warnings.warn(
"Python 2.x support of vcrpy is deprecated and will be removed in an upcoming major release.",
DeprecationWarning
DeprecationWarning,
)
logging.getLogger(__name__).addHandler(NullHandler())

View File

@@ -1,3 +1,3 @@
async def handle_coroutine(vcr, fn): # noqa: E999
with vcr as cassette:
return (await fn(cassette)) # noqa: E999
return await fn(cassette) # noqa: E999

View File

@@ -17,14 +17,17 @@ from .util import partition_dict
try:
from asyncio import iscoroutinefunction
except ImportError:
def iscoroutinefunction(*args, **kwargs):
return False
if sys.version_info[:2] >= (3, 5):
from ._handle_coroutine import handle_coroutine
else:
def handle_coroutine(*args, **kwags):
raise NotImplementedError('Not implemented on Python 2')
raise NotImplementedError("Not implemented on Python 2")
log = logging.getLogger(__name__)
@@ -48,7 +51,7 @@ class CassetteContextDecorator(object):
this class as a context manager in ``__exit__``.
"""
_non_cassette_arguments = ('path_transformer', 'func_path_generator')
_non_cassette_arguments = ("path_transformer", "func_path_generator")
@classmethod
def from_args(cls, cassette_class, **kwargs):
@@ -63,14 +66,10 @@ class CassetteContextDecorator(object):
with contextlib.ExitStack() as exit_stack:
for patcher in CassettePatcherBuilder(cassette).build():
exit_stack.enter_context(patcher)
log_format = '{action} context for cassette at {path}.'
log.debug(log_format.format(
action="Entering", path=cassette._path
))
log_format = "{action} context for cassette at {path}."
log.debug(log_format.format(action="Entering", path=cassette._path))
yield cassette
log.debug(log_format.format(
action="Exiting", path=cassette._path
))
log.debug(log_format.format(action="Exiting", path=cassette._path))
# TODO(@IvanMalison): Hmmm. it kind of feels like this should be
# somewhere else.
cassette._save()
@@ -86,12 +85,11 @@ class CassetteContextDecorator(object):
# pass
assert self.__finish is None, "Cassette already open."
other_kwargs, cassette_kwargs = partition_dict(
lambda key, _: key in self._non_cassette_arguments,
self._args_getter()
lambda key, _: key in self._non_cassette_arguments, self._args_getter()
)
if other_kwargs.get('path_transformer'):
transformer = other_kwargs['path_transformer']
cassette_kwargs['path'] = transformer(cassette_kwargs['path'])
if other_kwargs.get("path_transformer"):
transformer = other_kwargs["path_transformer"]
cassette_kwargs["path"] = transformer(cassette_kwargs["path"])
self.__finish = self._patch_generator(self.cls.load(**cassette_kwargs))
return next(self.__finish)
@@ -105,9 +103,7 @@ class CassetteContextDecorator(object):
# functions are reentrant. This is required for thread
# safety and the correct operation of recursive functions.
args_getter = self._build_args_getter_for_decorator(function)
return type(self)(self.cls, args_getter)._execute_function(
function, args, kwargs
)
return type(self)(self.cls, args_getter)._execute_function(function, args, kwargs)
def _execute_function(self, function, args, kwargs):
def handle_function(cassette):
@@ -154,12 +150,12 @@ class CassetteContextDecorator(object):
def _build_args_getter_for_decorator(self, function):
def new_args_getter():
kwargs = self._args_getter()
if 'path' not in kwargs:
name_generator = (kwargs.get('func_path_generator') or
self.get_function_name)
if "path" not in kwargs:
name_generator = kwargs.get("func_path_generator") or self.get_function_name
path = name_generator(function)
kwargs['path'] = path
kwargs["path"] = path
return kwargs
return new_args_getter
@@ -181,10 +177,18 @@ class Cassette(object):
def use(cls, **kwargs):
return CassetteContextDecorator.from_args(cls, **kwargs)
def __init__(self, path, serializer=None, persister=None, record_mode='once',
match_on=(uri, method), before_record_request=None,
before_record_response=None, custom_patches=(),
inject=False):
def __init__(
self,
path,
serializer=None,
persister=None,
record_mode="once",
match_on=(uri, method),
before_record_request=None,
before_record_response=None,
custom_patches=(),
inject=False,
):
self._persister = persister or FilesystemPersister
self._path = path
self._serializer = serializer or yamlserializer
@@ -221,8 +225,7 @@ class Cassette(object):
@property
def write_protected(self):
return self.rewound and self.record_mode == 'once' or \
self.record_mode == 'none'
return self.rewound and self.record_mode == "once" or self.record_mode == "none"
def append(self, request, response):
"""Add a request, response pair to this cassette"""
@@ -254,9 +257,7 @@ class Cassette(object):
def can_play_response_for(self, request):
request = self._before_record_request(request)
return request and request in self and \
self.record_mode != 'all' and \
self.rewound
return request and request in self and self.record_mode != "all" and self.rewound
def play_response(self, request):
"""
@@ -269,8 +270,7 @@ class Cassette(object):
return response
# The cassette doesn't contain the request asked for.
raise UnhandledHTTPRequestError(
"The cassette (%r) doesn't contain the request (%r) asked for"
% (self._path, request)
"The cassette (%r) doesn't contain the request (%r) asked for" % (self._path, request)
)
def responses_of(self, request):
@@ -285,8 +285,7 @@ class Cassette(object):
return responses
# The cassette doesn't contain the request asked for.
raise UnhandledHTTPRequestError(
"The cassette (%r) doesn't contain the request (%r) asked for"
% (self._path, request)
"The cassette (%r) doesn't contain the request (%r) asked for" % (self._path, request)
)
def rewind(self):
@@ -333,19 +332,12 @@ class Cassette(object):
def _save(self, force=False):
if force or self.dirty:
self._persister.save_cassette(
self._path,
self._as_dict(),
serializer=self._serializer,
)
self._persister.save_cassette(self._path, self._as_dict(), serializer=self._serializer)
self.dirty = False
def _load(self):
try:
requests, responses = self._persister.load_cassette(
self._path,
serializer=self._serializer,
)
requests, responses = self._persister.load_cassette(self._path, serializer=self._serializer)
for request, response in zip(requests, responses):
self.append(request, response)
self.dirty = False
@@ -354,9 +346,7 @@ class Cassette(object):
pass
def __str__(self):
return "<Cassette containing {} recorded response(s)>".format(
len(self)
)
return "<Cassette containing {} recorded response(s)>".format(len(self))
def __len__(self):
"""Return the number of request,response pairs stored in here"""

View File

@@ -8,7 +8,7 @@ try:
except ImportError:
import contextlib2 as contextlib
else:
if not hasattr(contextlib, 'ExitStack'):
if not hasattr(contextlib, "ExitStack"):
import contextlib2 as contextlib
__all__ = ['mock', 'contextlib']
__all__ = ["mock", "contextlib"]

View File

@@ -1,4 +1,5 @@
import copy
try:
from collections import abc as collections_abc # only works on python 3.3+
except ImportError:
@@ -19,11 +20,9 @@ from . import filters
class VCR(object):
@staticmethod
def is_test_method(method_name, function):
return method_name.startswith('test') and \
isinstance(function, types.FunctionType)
return method_name.startswith("test") and isinstance(function, types.FunctionType)
@staticmethod
def ensure_suffix(suffix):
@@ -31,35 +30,45 @@ class VCR(object):
if not path.endswith(suffix):
return path + suffix
return path
return ensure
def __init__(self, path_transformer=None, before_record_request=None,
custom_patches=(), filter_query_parameters=(), ignore_hosts=(),
record_mode="once", ignore_localhost=False, filter_headers=(),
before_record_response=None, filter_post_data_parameters=(),
match_on=('method', 'scheme', 'host', 'port', 'path', 'query'),
before_record=None, inject_cassette=False, serializer='yaml',
cassette_library_dir=None, func_path_generator=None,
decode_compressed_response=False):
def __init__(
self,
path_transformer=None,
before_record_request=None,
custom_patches=(),
filter_query_parameters=(),
ignore_hosts=(),
record_mode="once",
ignore_localhost=False,
filter_headers=(),
before_record_response=None,
filter_post_data_parameters=(),
match_on=("method", "scheme", "host", "port", "path", "query"),
before_record=None,
inject_cassette=False,
serializer="yaml",
cassette_library_dir=None,
func_path_generator=None,
decode_compressed_response=False,
):
self.serializer = serializer
self.match_on = match_on
self.cassette_library_dir = cassette_library_dir
self.serializers = {
'yaml': yamlserializer,
'json': jsonserializer,
}
self.serializers = {"yaml": yamlserializer, "json": jsonserializer}
self.matchers = {
'method': matchers.method,
'uri': matchers.uri,
'url': matchers.uri, # matcher for backwards compatibility
'scheme': matchers.scheme,
'host': matchers.host,
'port': matchers.port,
'path': matchers.path,
'query': matchers.query,
'headers': matchers.headers,
'raw_body': matchers.raw_body,
'body': matchers.body,
"method": matchers.method,
"uri": matchers.uri,
"url": matchers.uri, # matcher for backwards compatibility
"scheme": matchers.scheme,
"host": matchers.host,
"port": matchers.port,
"path": matchers.path,
"query": matchers.query,
"headers": matchers.headers,
"raw_body": matchers.raw_body,
"body": matchers.body,
}
self.persister = FilesystemPersister
self.record_mode = record_mode
@@ -80,11 +89,7 @@ class VCR(object):
try:
serializer = self.serializers[serializer_name]
except KeyError:
raise KeyError(
"Serializer {} doesn't exist or isn't registered".format(
serializer_name
)
)
raise KeyError("Serializer {} doesn't exist or isn't registered".format(serializer_name))
return serializer
def _get_matchers(self, matcher_names):
@@ -93,9 +98,7 @@ class VCR(object):
for m in matcher_names:
matchers.append(self.matchers[m])
except KeyError:
raise KeyError(
"Matcher {} doesn't exist or isn't registered".format(m)
)
raise KeyError("Matcher {} doesn't exist or isn't registered".format(m))
return matchers
def use_cassette(self, path=None, **kwargs):
@@ -117,62 +120,47 @@ class VCR(object):
return Cassette.use_arg_getter(args_getter)
def get_merged_config(self, **kwargs):
serializer_name = kwargs.get('serializer', self.serializer)
matcher_names = kwargs.get('match_on', self.match_on)
path_transformer = kwargs.get(
'path_transformer',
self.path_transformer
)
func_path_generator = kwargs.get(
'func_path_generator',
self.func_path_generator
)
cassette_library_dir = kwargs.get(
'cassette_library_dir',
self.cassette_library_dir
)
additional_matchers = kwargs.get('additional_matchers', ())
serializer_name = kwargs.get("serializer", self.serializer)
matcher_names = kwargs.get("match_on", self.match_on)
path_transformer = kwargs.get("path_transformer", self.path_transformer)
func_path_generator = kwargs.get("func_path_generator", self.func_path_generator)
cassette_library_dir = kwargs.get("cassette_library_dir", self.cassette_library_dir)
additional_matchers = kwargs.get("additional_matchers", ())
if cassette_library_dir:
def add_cassette_library_dir(path):
if not path.startswith(cassette_library_dir):
return os.path.join(cassette_library_dir, path)
return path
path_transformer = compose(
add_cassette_library_dir, path_transformer
)
path_transformer = compose(add_cassette_library_dir, path_transformer)
elif not func_path_generator:
# If we don't have a library dir, use the functions
# location to build a full path for cassettes.
func_path_generator = self._build_path_from_func_using_module
merged_config = {
'serializer': self._get_serializer(serializer_name),
'persister': self.persister,
'match_on': self._get_matchers(
tuple(matcher_names) + tuple(additional_matchers)
),
'record_mode': kwargs.get('record_mode', self.record_mode),
'before_record_request': self._build_before_record_request(kwargs),
'before_record_response': self._build_before_record_response(kwargs),
'custom_patches': self._custom_patches + kwargs.get(
'custom_patches', ()
),
'inject': kwargs.get('inject_cassette', self.inject_cassette),
'path_transformer': path_transformer,
'func_path_generator': func_path_generator
"serializer": self._get_serializer(serializer_name),
"persister": self.persister,
"match_on": self._get_matchers(tuple(matcher_names) + tuple(additional_matchers)),
"record_mode": kwargs.get("record_mode", self.record_mode),
"before_record_request": self._build_before_record_request(kwargs),
"before_record_response": self._build_before_record_response(kwargs),
"custom_patches": self._custom_patches + kwargs.get("custom_patches", ()),
"inject": kwargs.get("inject_cassette", self.inject_cassette),
"path_transformer": path_transformer,
"func_path_generator": func_path_generator,
}
path = kwargs.get('path')
path = kwargs.get("path")
if path:
merged_config['path'] = path
merged_config["path"] = path
return merged_config
def _build_before_record_response(self, options):
before_record_response = options.get(
'before_record_response', self.before_record_response
)
before_record_response = options.get("before_record_response", self.before_record_response)
decode_compressed_response = options.get(
'decode_compressed_response', self.decode_compressed_response
"decode_compressed_response", self.decode_compressed_response
)
filter_functions = []
if decode_compressed_response:
@@ -188,58 +176,38 @@ class VCR(object):
break
response = function(response)
return response
return before_record_response
def _build_before_record_request(self, options):
filter_functions = []
filter_headers = options.get(
'filter_headers', self.filter_headers
)
filter_query_parameters = options.get(
'filter_query_parameters', self.filter_query_parameters
)
filter_headers = options.get("filter_headers", self.filter_headers)
filter_query_parameters = options.get("filter_query_parameters", self.filter_query_parameters)
filter_post_data_parameters = options.get(
'filter_post_data_parameters', self.filter_post_data_parameters
"filter_post_data_parameters", self.filter_post_data_parameters
)
before_record_request = options.get(
"before_record_request",
options.get("before_record", self.before_record_request)
)
ignore_hosts = options.get(
'ignore_hosts', self.ignore_hosts
)
ignore_localhost = options.get(
'ignore_localhost', self.ignore_localhost
"before_record_request", options.get("before_record", self.before_record_request)
)
ignore_hosts = options.get("ignore_hosts", self.ignore_hosts)
ignore_localhost = options.get("ignore_localhost", self.ignore_localhost)
if filter_headers:
replacements = [h if isinstance(h, tuple) else (h, None)
for h in filter_headers]
filter_functions.append(
functools.partial(
filters.replace_headers,
replacements=replacements,
)
)
replacements = [h if isinstance(h, tuple) else (h, None) for h in filter_headers]
filter_functions.append(functools.partial(filters.replace_headers, replacements=replacements))
if filter_query_parameters:
replacements = [p if isinstance(p, tuple) else (p, None)
for p in filter_query_parameters]
filter_functions.append(functools.partial(
filters.replace_query_parameters,
replacements=replacements,
))
if filter_post_data_parameters:
replacements = [p if isinstance(p, tuple) else (p, None)
for p in filter_post_data_parameters]
replacements = [p if isinstance(p, tuple) else (p, None) for p in filter_query_parameters]
filter_functions.append(
functools.partial(
filters.replace_post_data_parameters,
replacements=replacements,
)
functools.partial(filters.replace_query_parameters, replacements=replacements)
)
if filter_post_data_parameters:
replacements = [p if isinstance(p, tuple) else (p, None) for p in filter_post_data_parameters]
filter_functions.append(
functools.partial(filters.replace_post_data_parameters, replacements=replacements)
)
hosts_to_ignore = set(ignore_hosts)
if ignore_localhost:
hosts_to_ignore.update(('localhost', '0.0.0.0', '127.0.0.1'))
hosts_to_ignore.update(("localhost", "0.0.0.0", "127.0.0.1"))
if hosts_to_ignore:
filter_functions.append(self._build_ignore_hosts(hosts_to_ignore))
@@ -255,20 +223,21 @@ class VCR(object):
break
request = function(request)
return request
return before_record_request
@staticmethod
def _build_ignore_hosts(hosts_to_ignore):
def filter_ignored_hosts(request):
if hasattr(request, 'host') and request.host in hosts_to_ignore:
if hasattr(request, "host") and request.host in hosts_to_ignore:
return
return request
return filter_ignored_hosts
@staticmethod
def _build_path_from_func_using_module(function):
return os.path.join(os.path.dirname(inspect.getfile(function)),
function.__name__)
return os.path.join(os.path.dirname(inspect.getfile(function)), function.__name__)
def register_serializer(self, name, serializer):
self.serializers[name] = serializer

View File

@@ -14,27 +14,29 @@ class CannotOverwriteExistingCassetteException(Exception):
if best_matches:
# Build a comprehensible message to put in the exception.
best_matches_msg = "Found {} similar requests with {} different matcher(s) :\n".format(
len(best_matches), len(best_matches[0][2]))
len(best_matches), len(best_matches[0][2])
)
for idx, best_match in enumerate(best_matches, start=1):
request, succeeded_matchers, failed_matchers_assertion_msgs = best_match
best_matches_msg += "\n%s - (%r).\n" \
"Matchers succeeded : %s\n" \
"Matchers failed :\n" % (idx, request, succeeded_matchers)
best_matches_msg += (
"\n%s - (%r).\n"
"Matchers succeeded : %s\n"
"Matchers failed :\n" % (idx, request, succeeded_matchers)
)
for failed_matcher, assertion_msg in failed_matchers_assertion_msgs:
best_matches_msg += "%s - assertion failure :\n" \
"%s\n" % (failed_matcher, assertion_msg)
best_matches_msg += "%s - assertion failure :\n" "%s\n" % (failed_matcher, assertion_msg)
else:
best_matches_msg = "No similar requests, that have not been played, found."
return (
"Can't overwrite existing cassette (%r) in "
"your current record mode (%r).\n"
"No match for the request (%r) was found.\n"
"%s"
% (cassette._path, cassette.record_mode, failed_request, best_matches_msg)
"%s" % (cassette._path, cassette.record_mode, failed_request, best_matches_msg)
)
class UnhandledHTTPRequestError(KeyError):
"""Raised when a cassette does not contain the request we want."""
pass

View File

@@ -83,9 +83,9 @@ def replace_post_data_parameters(request, replacements):
value or None.
"""
replacements = dict(replacements)
if request.method == 'POST' and not isinstance(request.body, BytesIO):
if request.headers.get('Content-Type') == 'application/json':
json_data = json.loads(request.body.decode('utf-8'))
if request.method == "POST" and not isinstance(request.body, BytesIO):
if request.headers.get("Content-Type") == "application/json":
json_data = json.loads(request.body.decode("utf-8"))
for k, rv in replacements.items():
if k in json_data:
ov = json_data.pop(k)
@@ -93,28 +93,26 @@ def replace_post_data_parameters(request, replacements):
rv = rv(key=k, value=ov, request=request)
if rv is not None:
json_data[k] = rv
request.body = json.dumps(json_data).encode('utf-8')
request.body = json.dumps(json_data).encode("utf-8")
else:
if isinstance(request.body, text_type):
request.body = request.body.encode('utf-8')
splits = [p.partition(b'=') for p in request.body.split(b'&')]
request.body = request.body.encode("utf-8")
splits = [p.partition(b"=") for p in request.body.split(b"&")]
new_splits = []
for k, sep, ov in splits:
if sep is None:
new_splits.append((k, sep, ov))
else:
rk = k.decode('utf-8')
rk = k.decode("utf-8")
if rk not in replacements:
new_splits.append((k, sep, ov))
else:
rv = replacements[rk]
if callable(rv):
rv = rv(key=rk, value=ov.decode('utf-8'),
request=request)
rv = rv(key=rk, value=ov.decode("utf-8"), request=request)
if rv is not None:
new_splits.append((k, sep, rv.encode('utf-8')))
request.body = b'&'.join(k if sep is None else b''.join([k, sep, v])
for k, sep, v in new_splits)
new_splits.append((k, sep, rv.encode("utf-8")))
request.body = b"&".join(k if sep is None else b"".join([k, sep, v]) for k, sep, v in new_splits)
return request
@@ -133,15 +131,16 @@ def decode_response(response):
2. delete the content-encoding header
3. update content-length header to decompressed length
"""
def is_compressed(headers):
encoding = headers.get('content-encoding', [])
return encoding and encoding[0] in ('gzip', 'deflate')
encoding = headers.get("content-encoding", [])
return encoding and encoding[0] in ("gzip", "deflate")
def decompress_body(body, encoding):
"""Returns decompressed body according to encoding using zlib.
to (de-)compress gzip format, use wbits = zlib.MAX_WBITS | 16
"""
if encoding == 'gzip':
if encoding == "gzip":
return zlib.decompress(body, zlib.MAX_WBITS | 16)
else: # encoding == 'deflate'
return zlib.decompress(body)
@@ -149,15 +148,15 @@ def decode_response(response):
# Deepcopy here in case `headers` contain objects that could
# be mutated by a shallow copy and corrupt the real response.
response = copy.deepcopy(response)
headers = CaseInsensitiveDict(response['headers'])
headers = CaseInsensitiveDict(response["headers"])
if is_compressed(headers):
encoding = headers['content-encoding'][0]
headers['content-encoding'].remove(encoding)
if not headers['content-encoding']:
del headers['content-encoding']
encoding = headers["content-encoding"][0]
headers["content-encoding"].remove(encoding)
if not headers["content-encoding"]:
del headers["content-encoding"]
new_body = decompress_body(response['body']['string'], encoding)
response['body']['string'] = new_body
headers['content-length'] = [str(len(new_body))]
response['headers'] = dict(headers)
new_body = decompress_body(response["body"]["string"], encoding)
response["body"]["string"] = new_body
headers["content-length"] = [str(len(new_body))]
response["headers"] = dict(headers)
return response

View File

@@ -51,9 +51,10 @@ def headers(r1, r2):
assert r1.headers == r2.headers, "{} != {}".format(r1.headers, r2.headers)
def _header_checker(value, header='Content-Type'):
def _header_checker(value, header="Content-Type"):
def checker(headers):
return value in headers.get(header, '').lower()
return value in headers.get(header, "").lower()
return checker
@@ -62,18 +63,18 @@ def _transform_json(body):
# string. RFC 7159 says the default encoding is UTF-8 (although UTF-16
# and UTF-32 are also allowed: hmmmmm).
if body:
return json.loads(body.decode('utf-8'))
return json.loads(body.decode("utf-8"))
_xml_header_checker = _header_checker('text/xml')
_xmlrpc_header_checker = _header_checker('xmlrpc', header='User-Agent')
_xml_header_checker = _header_checker("text/xml")
_xmlrpc_header_checker = _header_checker("xmlrpc", header="User-Agent")
_checker_transformer_pairs = (
(_header_checker('application/x-www-form-urlencoded'),
lambda body: urllib.parse.parse_qs(body.decode('ascii'))),
(_header_checker('application/json'),
_transform_json),
(lambda request: _xml_header_checker(request) and _xmlrpc_header_checker(request),
xmlrpc_client.loads),
(
_header_checker("application/x-www-form-urlencoded"),
lambda body: urllib.parse.parse_qs(body.decode("ascii")),
),
(_header_checker("application/json"), _transform_json),
(lambda request: _xml_header_checker(request) and _xmlrpc_header_checker(request), xmlrpc_client.loads),
)
@@ -92,11 +93,7 @@ def _get_transformer(request):
def requests_match(r1, r2, matchers):
successes, failures = get_matchers_results(r1, r2, matchers)
if failures:
log.debug(
"Requests {} and {} differ.\n"
"Failure details:\n"
"{}".format(r1, r2, failures)
)
log.debug("Requests {} and {} differ.\n" "Failure details:\n" "{}".format(r1, r2, failures))
return len(failures) == 0

View File

@@ -38,55 +38,46 @@ def preprocess_yaml(cassette):
# versions. So this just strips the tags before deserializing.
STRINGS_TO_NUKE = [
'!!python/object:vcr.request.Request',
'!!python/object/apply:__builtin__.frozenset',
'!!python/object/apply:builtins.frozenset',
"!!python/object:vcr.request.Request",
"!!python/object/apply:__builtin__.frozenset",
"!!python/object/apply:builtins.frozenset",
]
for s in STRINGS_TO_NUKE:
cassette = cassette.replace(s, '')
cassette = cassette.replace(s, "")
return cassette
PARTS = [
'protocol',
'host',
'port',
'path',
]
PARTS = ["protocol", "host", "port", "path"]
def build_uri(**parts):
port = parts['port']
scheme = parts['protocol']
default_port = {'https': 443, 'http': 80}[scheme]
parts['port'] = ':{}'.format(port) if port != default_port else ''
port = parts["port"]
scheme = parts["protocol"]
default_port = {"https": 443, "http": 80}[scheme]
parts["port"] = ":{}".format(port) if port != default_port else ""
return "{protocol}://{host}{port}{path}".format(**parts)
def _migrate(data):
interactions = []
for item in data:
req = item['request']
res = item['response']
req = item["request"]
res = item["response"]
uri = {k: req.pop(k) for k in PARTS}
req['uri'] = build_uri(**uri)
req["uri"] = build_uri(**uri)
# convert headers to dict of lists
headers = req['headers']
headers = req["headers"]
for k in headers:
headers[k] = [headers[k]]
response_headers = {}
for k, v in get_httpmessage(
b"".join(h.encode('utf-8') for h in res['headers'])
).items():
for k, v in get_httpmessage(b"".join(h.encode("utf-8") for h in res["headers"])).items():
response_headers.setdefault(k, [])
response_headers[k].append(v)
res['headers'] = response_headers
interactions.append({'request': req, 'response': res})
res["headers"] = response_headers
interactions.append({"request": req, "response": res})
return {
'requests': [
request.Request._from_dict(i['request']) for i in interactions
],
'responses': [i['response'] for i in interactions],
"requests": [request.Request._from_dict(i["request"]) for i in interactions],
"responses": [i["response"] for i in interactions],
}
@@ -105,7 +96,7 @@ def _list_of_tuples_to_dict(fs):
def _already_migrated(data):
try:
if data.get('version') == 1:
if data.get("version") == 1:
return True
except AttributeError:
return False
@@ -116,9 +107,7 @@ def migrate_yml(in_fp, out_fp):
if _already_migrated(data):
return False
for i in range(len(data)):
data[i]['request']['headers'] = _list_of_tuples_to_dict(
data[i]['request']['headers']
)
data[i]["request"]["headers"] = _list_of_tuples_to_dict(data[i]["request"]["headers"])
interactions = _migrate(data)
out_fp.write(serialize(interactions, yamlserializer))
return True
@@ -127,43 +116,42 @@ def migrate_yml(in_fp, out_fp):
def migrate(file_path, migration_fn):
# because we assume that original files can be reverted
# we will try to copy the content. (os.rename not needed)
with tempfile.TemporaryFile(mode='w+') as out_fp:
with open(file_path, 'r') as in_fp:
with tempfile.TemporaryFile(mode="w+") as out_fp:
with open(file_path, "r") as in_fp:
if not migration_fn(in_fp, out_fp):
return False
with open(file_path, 'w') as in_fp:
with open(file_path, "w") as in_fp:
out_fp.seek(0)
shutil.copyfileobj(out_fp, in_fp)
return True
def try_migrate(path):
if path.endswith('.json'):
if path.endswith(".json"):
return migrate(path, migrate_json)
elif path.endswith('.yaml') or path.endswith('.yml'):
elif path.endswith(".yaml") or path.endswith(".yml"):
return migrate(path, migrate_yml)
return False
def main():
if len(sys.argv) != 2:
raise SystemExit("Please provide path to cassettes directory or file. "
"Usage: python -m vcr.migration PATH")
raise SystemExit(
"Please provide path to cassettes directory or file. " "Usage: python -m vcr.migration PATH"
)
path = sys.argv[1]
if not os.path.isabs(path):
path = os.path.abspath(path)
files = [path]
if os.path.isdir(path):
files = (os.path.join(root, name)
for (root, dirs, files) in os.walk(path)
for name in files)
files = (os.path.join(root, name) for (root, dirs, files) in os.walk(path) for name in files)
for file_path in files:
migrated = try_migrate(file_path)
status = 'OK' if migrated else 'FAIL'
status = "OK" if migrated else "FAIL"
sys.stderr.write("[{}] {}\n".format(status, file_path))
sys.stderr.write("Done.\n")
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@@ -1,4 +1,4 @@
'''Utilities for patching in cassettes'''
"""Utilities for patching in cassettes"""
import functools
import itertools
@@ -76,16 +76,14 @@ try:
except ImportError: # pragma: no cover
pass
else:
_SimpleAsyncHTTPClient_fetch_impl = \
tornado.simple_httpclient.SimpleAsyncHTTPClient.fetch_impl
_SimpleAsyncHTTPClient_fetch_impl = tornado.simple_httpclient.SimpleAsyncHTTPClient.fetch_impl
try:
import tornado.curl_httpclient
except ImportError: # pragma: no cover
pass
else:
_CurlAsyncHTTPClient_fetch_impl = \
tornado.curl_httpclient.CurlAsyncHTTPClient.fetch_impl
_CurlAsyncHTTPClient_fetch_impl = tornado.curl_httpclient.CurlAsyncHTTPClient.fetch_impl
try:
import aiohttp.client
@@ -96,13 +94,10 @@ else:
class CassettePatcherBuilder(object):
def _build_patchers_from_mock_triples_decorator(function):
@functools.wraps(function)
def wrapped(self, *args, **kwargs):
return self._build_patchers_from_mock_triples(
function(self, *args, **kwargs)
)
return self._build_patchers_from_mock_triples(function(self, *args, **kwargs))
return wrapped
@@ -112,11 +107,15 @@ class CassettePatcherBuilder(object):
def build(self):
return itertools.chain(
self._httplib(), self._requests(), self._boto3(), self._urllib3(),
self._httplib2(), self._boto(), self._tornado(), self._aiohttp(),
self._build_patchers_from_mock_triples(
self._cassette.custom_patches
),
self._httplib(),
self._requests(),
self._boto3(),
self._urllib3(),
self._httplib2(),
self._boto(),
self._tornado(),
self._aiohttp(),
self._build_patchers_from_mock_triples(self._cassette.custom_patches),
)
def _build_patchers_from_mock_triples(self, mock_triples):
@@ -129,9 +128,9 @@ class CassettePatcherBuilder(object):
if not hasattr(obj, patched_attribute):
return
return mock.patch.object(obj, patched_attribute,
self._recursively_apply_get_cassette_subclass(
replacement_class))
return mock.patch.object(
obj, patched_attribute, self._recursively_apply_get_cassette_subclass(replacement_class)
)
def _recursively_apply_get_cassette_subclass(self, replacement_dict_or_obj):
"""One of the subtleties of this class is that it does not directly
@@ -153,13 +152,11 @@ class CassettePatcherBuilder(object):
"""
if isinstance(replacement_dict_or_obj, dict):
for key, replacement_obj in replacement_dict_or_obj.items():
replacement_obj = self._recursively_apply_get_cassette_subclass(
replacement_obj)
replacement_obj = self._recursively_apply_get_cassette_subclass(replacement_obj)
replacement_dict_or_obj[key] = replacement_obj
return replacement_dict_or_obj
if hasattr(replacement_dict_or_obj, 'cassette'):
replacement_dict_or_obj = self._get_cassette_subclass(
replacement_dict_or_obj)
if hasattr(replacement_dict_or_obj, "cassette"):
replacement_dict_or_obj = self._get_cassette_subclass(replacement_dict_or_obj)
return replacement_dict_or_obj
def _get_cassette_subclass(self, klass):
@@ -174,13 +171,14 @@ class CassettePatcherBuilder(object):
bases = (base_class,)
if not issubclass(base_class, object): # Check for old style class
bases += (object,)
return type('{}{}'.format(base_class.__name__, self._cassette._path),
bases, dict(cassette=self._cassette))
return type(
"{}{}".format(base_class.__name__, self._cassette._path), bases, dict(cassette=self._cassette)
)
@_build_patchers_from_mock_triples_decorator
def _httplib(self):
yield httplib, 'HTTPConnection', VCRHTTPConnection
yield httplib, 'HTTPSConnection', VCRHTTPSConnection
yield httplib, "HTTPConnection", VCRHTTPConnection
yield httplib, "HTTPSConnection", VCRHTTPSConnection
def _requests(self):
try:
@@ -203,12 +201,14 @@ class CassettePatcherBuilder(object):
pass
else:
from .stubs import boto3_stubs
yield self._urllib3_patchers(cpool, boto3_stubs)
else:
from .stubs import boto3_stubs
log.debug("Patching boto3 cpool with %s", cpool)
yield cpool.AWSHTTPConnectionPool, 'ConnectionCls', boto3_stubs.VCRRequestsHTTPConnection
yield cpool.AWSHTTPSConnectionPool, 'ConnectionCls', boto3_stubs.VCRRequestsHTTPSConnection
yield cpool.AWSHTTPConnectionPool, "ConnectionCls", boto3_stubs.VCRRequestsHTTPConnection
yield cpool.AWSHTTPSConnectionPool, "ConnectionCls", boto3_stubs.VCRRequestsHTTPSConnection
def _patched_get_conn(self, connection_pool_class, connection_class_getter):
get_conn = connection_pool_class._get_conn
@@ -217,8 +217,8 @@ class CassettePatcherBuilder(object):
def patched_get_conn(pool, timeout=None):
connection = get_conn(pool, timeout)
connection_class = (
pool.ConnectionCls if hasattr(pool, 'ConnectionCls')
else connection_class_getter())
pool.ConnectionCls if hasattr(pool, "ConnectionCls") else connection_class_getter()
)
# We need to make sure that we are actually providing a
# patched version of the connection class. This might not
# always be the case because the pool keeps previously
@@ -248,6 +248,7 @@ class CassettePatcherBuilder(object):
except ImportError: # pragma: no cover
return ()
from .stubs import urllib3_stubs
return self._urllib3_patchers(cpool, urllib3_stubs)
@_build_patchers_from_mock_triples_decorator
@@ -260,10 +261,12 @@ class CassettePatcherBuilder(object):
from .stubs.httplib2_stubs import VCRHTTPConnectionWithTimeout
from .stubs.httplib2_stubs import VCRHTTPSConnectionWithTimeout
yield cpool, 'HTTPConnectionWithTimeout', VCRHTTPConnectionWithTimeout
yield cpool, 'HTTPSConnectionWithTimeout', VCRHTTPSConnectionWithTimeout
yield cpool, 'SCHEME_TO_CONNECTION', {'http': VCRHTTPConnectionWithTimeout,
'https': VCRHTTPSConnectionWithTimeout}
yield cpool, "HTTPConnectionWithTimeout", VCRHTTPConnectionWithTimeout
yield cpool, "HTTPSConnectionWithTimeout", VCRHTTPSConnectionWithTimeout
yield cpool, "SCHEME_TO_CONNECTION", {
"http": VCRHTTPConnectionWithTimeout,
"https": VCRHTTPSConnectionWithTimeout,
}
@_build_patchers_from_mock_triples_decorator
def _boto(self):
@@ -273,7 +276,8 @@ class CassettePatcherBuilder(object):
pass
else:
from .stubs.boto_stubs import VCRCertValidatingHTTPSConnection
yield cpool, 'CertValidatingHTTPSConnection', VCRCertValidatingHTTPSConnection
yield cpool, "CertValidatingHTTPSConnection", VCRCertValidatingHTTPSConnection
@_build_patchers_from_mock_triples_decorator
def _tornado(self):
@@ -284,10 +288,8 @@ class CassettePatcherBuilder(object):
else:
from .stubs.tornado_stubs import vcr_fetch_impl
new_fetch_impl = vcr_fetch_impl(
self._cassette, _SimpleAsyncHTTPClient_fetch_impl
)
yield simple.SimpleAsyncHTTPClient, 'fetch_impl', new_fetch_impl
new_fetch_impl = vcr_fetch_impl(self._cassette, _SimpleAsyncHTTPClient_fetch_impl)
yield simple.SimpleAsyncHTTPClient, "fetch_impl", new_fetch_impl
try:
import tornado.curl_httpclient as curl
except ImportError: # pragma: no cover
@@ -295,10 +297,8 @@ class CassettePatcherBuilder(object):
else:
from .stubs.tornado_stubs import vcr_fetch_impl
new_fetch_impl = vcr_fetch_impl(
self._cassette, _CurlAsyncHTTPClient_fetch_impl
)
yield curl.CurlAsyncHTTPClient, 'fetch_impl', new_fetch_impl
new_fetch_impl = vcr_fetch_impl(self._cassette, _CurlAsyncHTTPClient_fetch_impl)
yield curl.CurlAsyncHTTPClient, "fetch_impl", new_fetch_impl
@_build_patchers_from_mock_triples_decorator
def _aiohttp(self):
@@ -308,10 +308,9 @@ class CassettePatcherBuilder(object):
pass
else:
from .stubs.aiohttp_stubs import vcr_request
new_request = vcr_request(
self._cassette, _AiohttpClientSessionRequest
)
yield client.ClientSession, '_request', new_request
new_request = vcr_request(self._cassette, _AiohttpClientSessionRequest)
yield client.ClientSession, "_request", new_request
def _urllib3_patchers(self, cpool, stubs):
http_connection_remover = ConnectionRemover(
@@ -321,34 +320,45 @@ class CassettePatcherBuilder(object):
self._get_cassette_subclass(stubs.VCRRequestsHTTPSConnection)
)
mock_triples = (
(cpool, 'VerifiedHTTPSConnection', stubs.VCRRequestsHTTPSConnection),
(cpool, 'HTTPConnection', stubs.VCRRequestsHTTPConnection),
(cpool, 'HTTPSConnection', stubs.VCRRequestsHTTPSConnection),
(cpool, 'is_connection_dropped', mock.Mock(return_value=False)), # Needed on Windows only
(cpool.HTTPConnectionPool, 'ConnectionCls', stubs.VCRRequestsHTTPConnection),
(cpool.HTTPSConnectionPool, 'ConnectionCls', stubs.VCRRequestsHTTPSConnection),
(cpool, "VerifiedHTTPSConnection", stubs.VCRRequestsHTTPSConnection),
(cpool, "HTTPConnection", stubs.VCRRequestsHTTPConnection),
(cpool, "HTTPSConnection", stubs.VCRRequestsHTTPSConnection),
(cpool, "is_connection_dropped", mock.Mock(return_value=False)), # Needed on Windows only
(cpool.HTTPConnectionPool, "ConnectionCls", stubs.VCRRequestsHTTPConnection),
(cpool.HTTPSConnectionPool, "ConnectionCls", stubs.VCRRequestsHTTPSConnection),
)
# These handle making sure that sessions only use the
# connections of the appropriate type.
mock_triples += ((cpool.HTTPConnectionPool, '_get_conn',
self._patched_get_conn(cpool.HTTPConnectionPool,
lambda: cpool.HTTPConnection)),
(cpool.HTTPSConnectionPool, '_get_conn',
self._patched_get_conn(cpool.HTTPSConnectionPool,
lambda: cpool.HTTPSConnection)),
(cpool.HTTPConnectionPool, '_new_conn',
self._patched_new_conn(cpool.HTTPConnectionPool,
http_connection_remover)),
(cpool.HTTPSConnectionPool, '_new_conn',
self._patched_new_conn(cpool.HTTPSConnectionPool,
https_connection_remover)))
mock_triples += (
(
cpool.HTTPConnectionPool,
"_get_conn",
self._patched_get_conn(cpool.HTTPConnectionPool, lambda: cpool.HTTPConnection),
),
(
cpool.HTTPSConnectionPool,
"_get_conn",
self._patched_get_conn(cpool.HTTPSConnectionPool, lambda: cpool.HTTPSConnection),
),
(
cpool.HTTPConnectionPool,
"_new_conn",
self._patched_new_conn(cpool.HTTPConnectionPool, http_connection_remover),
),
(
cpool.HTTPSConnectionPool,
"_new_conn",
self._patched_new_conn(cpool.HTTPSConnectionPool, https_connection_remover),
),
)
return itertools.chain(self._build_patchers_from_mock_triples(mock_triples),
(http_connection_remover, https_connection_remover))
return itertools.chain(
self._build_patchers_from_mock_triples(mock_triples),
(http_connection_remover, https_connection_remover),
)
class ConnectionRemover(object):
def __init__(self, connection_class):
self._connection_class = connection_class
self._connection_pool_to_connections = {}
@@ -378,11 +388,12 @@ class ConnectionRemover(object):
def reset_patchers():
yield mock.patch.object(httplib, 'HTTPConnection', _HTTPConnection)
yield mock.patch.object(httplib, 'HTTPSConnection', _HTTPSConnection)
yield mock.patch.object(httplib, "HTTPConnection", _HTTPConnection)
yield mock.patch.object(httplib, "HTTPSConnection", _HTTPSConnection)
try:
import requests
if requests.__build__ < 0x021603:
# Avoid double unmock if requests 2.16.3
# First, this is pointless, requests.packages.urllib3 *IS* urllib3 (see packages.py)
@@ -400,29 +411,27 @@ def reset_patchers():
pass
else:
# unpatch requests v1.x
yield mock.patch.object(cpool, 'VerifiedHTTPSConnection', _VerifiedHTTPSConnection)
yield mock.patch.object(cpool, 'HTTPConnection', _cpoolHTTPConnection)
yield mock.patch.object(cpool, "VerifiedHTTPSConnection", _VerifiedHTTPSConnection)
yield mock.patch.object(cpool, "HTTPConnection", _cpoolHTTPConnection)
# unpatch requests v2.x
if hasattr(cpool.HTTPConnectionPool, 'ConnectionCls'):
yield mock.patch.object(cpool.HTTPConnectionPool, 'ConnectionCls',
_cpoolHTTPConnection)
yield mock.patch.object(cpool.HTTPSConnectionPool, 'ConnectionCls',
_cpoolHTTPSConnection)
if hasattr(cpool.HTTPConnectionPool, "ConnectionCls"):
yield mock.patch.object(cpool.HTTPConnectionPool, "ConnectionCls", _cpoolHTTPConnection)
yield mock.patch.object(cpool.HTTPSConnectionPool, "ConnectionCls", _cpoolHTTPSConnection)
if hasattr(cpool, 'HTTPSConnection'):
yield mock.patch.object(cpool, 'HTTPSConnection', _cpoolHTTPSConnection)
if hasattr(cpool, "HTTPSConnection"):
yield mock.patch.object(cpool, "HTTPSConnection", _cpoolHTTPSConnection)
try:
import urllib3.connectionpool as cpool
except ImportError: # pragma: no cover
pass
else:
yield mock.patch.object(cpool, 'VerifiedHTTPSConnection', _VerifiedHTTPSConnection)
yield mock.patch.object(cpool, 'HTTPConnection', _cpoolHTTPConnection)
yield mock.patch.object(cpool, 'HTTPSConnection', _cpoolHTTPSConnection)
if hasattr(cpool.HTTPConnectionPool, 'ConnectionCls'):
yield mock.patch.object(cpool.HTTPConnectionPool, 'ConnectionCls', _cpoolHTTPConnection)
yield mock.patch.object(cpool.HTTPSConnectionPool, 'ConnectionCls', _cpoolHTTPSConnection)
yield mock.patch.object(cpool, "VerifiedHTTPSConnection", _VerifiedHTTPSConnection)
yield mock.patch.object(cpool, "HTTPConnection", _cpoolHTTPConnection)
yield mock.patch.object(cpool, "HTTPSConnection", _cpoolHTTPSConnection)
if hasattr(cpool.HTTPConnectionPool, "ConnectionCls"):
yield mock.patch.object(cpool.HTTPConnectionPool, "ConnectionCls", _cpoolHTTPConnection)
yield mock.patch.object(cpool.HTTPSConnectionPool, "ConnectionCls", _cpoolHTTPSConnection)
try:
# unpatch botocore with awsrequest
@@ -435,64 +444,53 @@ def reset_patchers():
pass
else:
# unpatch requests v1.x
yield mock.patch.object(cpool, 'VerifiedHTTPSConnection', _Boto3VerifiedHTTPSConnection)
yield mock.patch.object(cpool, 'HTTPConnection', _cpoolBoto3HTTPConnection)
yield mock.patch.object(cpool, "VerifiedHTTPSConnection", _Boto3VerifiedHTTPSConnection)
yield mock.patch.object(cpool, "HTTPConnection", _cpoolBoto3HTTPConnection)
# unpatch requests v2.x
if hasattr(cpool.HTTPConnectionPool, 'ConnectionCls'):
yield mock.patch.object(cpool.HTTPConnectionPool, 'ConnectionCls',
_cpoolBoto3HTTPConnection)
yield mock.patch.object(cpool.HTTPSConnectionPool, 'ConnectionCls',
_cpoolBoto3HTTPSConnection)
if hasattr(cpool.HTTPConnectionPool, "ConnectionCls"):
yield mock.patch.object(cpool.HTTPConnectionPool, "ConnectionCls", _cpoolBoto3HTTPConnection)
yield mock.patch.object(
cpool.HTTPSConnectionPool, "ConnectionCls", _cpoolBoto3HTTPSConnection
)
if hasattr(cpool, 'HTTPSConnection'):
yield mock.patch.object(cpool, 'HTTPSConnection', _cpoolBoto3HTTPSConnection)
if hasattr(cpool, "HTTPSConnection"):
yield mock.patch.object(cpool, "HTTPSConnection", _cpoolBoto3HTTPSConnection)
else:
if hasattr(cpool.AWSHTTPConnectionPool, 'ConnectionCls'):
yield mock.patch.object(cpool.AWSHTTPConnectionPool, 'ConnectionCls',
_cpoolBoto3HTTPConnection)
yield mock.patch.object(cpool.AWSHTTPSConnectionPool, 'ConnectionCls',
_cpoolBoto3HTTPSConnection)
if hasattr(cpool.AWSHTTPConnectionPool, "ConnectionCls"):
yield mock.patch.object(cpool.AWSHTTPConnectionPool, "ConnectionCls", _cpoolBoto3HTTPConnection)
yield mock.patch.object(cpool.AWSHTTPSConnectionPool, "ConnectionCls", _cpoolBoto3HTTPSConnection)
if hasattr(cpool, 'AWSHTTPSConnection'):
yield mock.patch.object(cpool, 'AWSHTTPSConnection', _cpoolBoto3HTTPSConnection)
if hasattr(cpool, "AWSHTTPSConnection"):
yield mock.patch.object(cpool, "AWSHTTPSConnection", _cpoolBoto3HTTPSConnection)
try:
import httplib2 as cpool
except ImportError: # pragma: no cover
pass
else:
yield mock.patch.object(cpool, 'HTTPConnectionWithTimeout', _HTTPConnectionWithTimeout)
yield mock.patch.object(cpool, 'HTTPSConnectionWithTimeout', _HTTPSConnectionWithTimeout)
yield mock.patch.object(cpool, 'SCHEME_TO_CONNECTION', _SCHEME_TO_CONNECTION)
yield mock.patch.object(cpool, "HTTPConnectionWithTimeout", _HTTPConnectionWithTimeout)
yield mock.patch.object(cpool, "HTTPSConnectionWithTimeout", _HTTPSConnectionWithTimeout)
yield mock.patch.object(cpool, "SCHEME_TO_CONNECTION", _SCHEME_TO_CONNECTION)
try:
import boto.https_connection as cpool
except ImportError: # pragma: no cover
pass
else:
yield mock.patch.object(cpool, 'CertValidatingHTTPSConnection',
_CertValidatingHTTPSConnection)
yield mock.patch.object(cpool, "CertValidatingHTTPSConnection", _CertValidatingHTTPSConnection)
try:
import tornado.simple_httpclient as simple
except ImportError: # pragma: no cover
pass
else:
yield mock.patch.object(
simple.SimpleAsyncHTTPClient,
'fetch_impl',
_SimpleAsyncHTTPClient_fetch_impl,
)
yield mock.patch.object(simple.SimpleAsyncHTTPClient, "fetch_impl", _SimpleAsyncHTTPClient_fetch_impl)
try:
import tornado.curl_httpclient as curl
except ImportError: # pragma: no cover
pass
else:
yield mock.patch.object(
curl.CurlAsyncHTTPClient,
'fetch_impl',
_CurlAsyncHTTPClient_fetch_impl,
)
yield mock.patch.object(curl.CurlAsyncHTTPClient, "fetch_impl", _CurlAsyncHTTPClient_fetch_impl)
@contextlib.contextmanager

View File

@@ -5,14 +5,13 @@ from ..serialize import serialize, deserialize
class FilesystemPersister(object):
@classmethod
def load_cassette(cls, cassette_path, serializer):
try:
with open(cassette_path) as f:
cassette_content = f.read()
except IOError:
raise ValueError('Cassette not found.')
raise ValueError("Cassette not found.")
cassette = deserialize(cassette_content, serializer)
return cassette
@@ -22,5 +21,5 @@ class FilesystemPersister(object):
dirname, filename = os.path.split(cassette_path)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname)
with open(cassette_path, 'w') as f:
with open(cassette_path, "w") as f:
f.write(data)

View File

@@ -15,7 +15,7 @@ class Request(object):
def __init__(self, method, uri, body, headers):
self.method = method
self.uri = uri
self._was_file = hasattr(body, 'read')
self._was_file = hasattr(body, "read")
if self._was_file:
self.body = body.read()
else:
@@ -40,13 +40,14 @@ class Request(object):
@body.setter
def body(self, value):
if isinstance(value, text_type):
value = value.encode('utf-8')
value = value.encode("utf-8")
self._body = value
def add_header(self, key, value):
warnings.warn("Request.add_header is deprecated. "
"Please assign to request.headers instead.",
DeprecationWarning)
warnings.warn(
"Request.add_header is deprecated. " "Please assign to request.headers instead.",
DeprecationWarning,
)
self.headers[key] = value
@property
@@ -63,7 +64,7 @@ class Request(object):
port = parse_uri.port
if port is None:
try:
port = {'https': 443, 'http': 80}[parse_uri.scheme]
port = {"https": 443, "http": 80}[parse_uri.scheme]
except KeyError:
pass
return port
@@ -95,10 +96,10 @@ class Request(object):
def _to_dict(self):
return {
'method': self.method,
'uri': self.uri,
'body': self.body,
'headers': {k: [v] for k, v in self.headers.items()},
"method": self.method,
"uri": self.uri,
"body": self.body,
"headers": {k: [v] for k, v in self.headers.items()},
}
@classmethod

View File

@@ -20,7 +20,7 @@ Deserializing: string (yaml converts from utf-8) -> bytestring
def _looks_like_an_old_cassette(data):
return isinstance(data, list) and len(data) and 'request' in data[0]
return isinstance(data, list) and len(data) and "request" in data[0]
def _warn_about_old_cassette_format():
@@ -41,23 +41,18 @@ def deserialize(cassette_string, serializer):
if _looks_like_an_old_cassette(data):
_warn_about_old_cassette_format()
requests = [Request._from_dict(r['request']) for r in data['interactions']]
responses = [
compat.convert_to_bytes(r['response']) for r in data['interactions']
]
requests = [Request._from_dict(r["request"]) for r in data["interactions"]]
responses = [compat.convert_to_bytes(r["response"]) for r in data["interactions"]]
return requests, responses
def serialize(cassette_dict, serializer):
interactions = ([{
'request': compat.convert_to_unicode(request._to_dict()),
'response': compat.convert_to_unicode(response),
} for request, response in zip(
cassette_dict['requests'],
cassette_dict['responses'],
)])
data = {
'version': CASSETTE_FORMAT_VERSION,
'interactions': interactions,
}
interactions = [
{
"request": compat.convert_to_unicode(request._to_dict()),
"response": compat.convert_to_unicode(response),
}
for request, response in zip(cassette_dict["requests"], cassette_dict["responses"])
]
data = {"version": CASSETTE_FORMAT_VERSION, "interactions": interactions}
return serializer.serialize(data)

View File

@@ -24,8 +24,8 @@ def convert_body_to_bytes(resp):
http://pyyaml.org/wiki/PyYAMLDocumentation#Python3support
"""
try:
if resp['body']['string'] is not None and not isinstance(resp['body']['string'], six.binary_type):
resp['body']['string'] = resp['body']['string'].encode('utf-8')
if resp["body"]["string"] is not None and not isinstance(resp["body"]["string"], six.binary_type):
resp["body"]["string"] = resp["body"]["string"].encode("utf-8")
except (KeyError, TypeError, UnicodeEncodeError):
# The thing we were converting either wasn't a dictionary or didn't
# have the keys we were expecting. Some of the tests just serialize
@@ -45,7 +45,7 @@ def _convert_string_to_unicode(string):
try:
if string is not None and not isinstance(string, six.text_type):
result = string.decode('utf-8')
result = string.decode("utf-8")
except (TypeError, UnicodeDecodeError, AttributeError):
# Sometimes the string actually is binary or StringIO object,
# so if you can't decode it, just give up.
@@ -63,17 +63,15 @@ def convert_body_to_unicode(resp):
# Some of the tests just serialize and deserialize a string.
return _convert_string_to_unicode(resp)
else:
body = resp.get('body')
body = resp.get("body")
if body is not None:
try:
body['string'] = _convert_string_to_unicode(
body['string']
)
body["string"] = _convert_string_to_unicode(body["string"])
except (KeyError, TypeError, AttributeError):
# The thing we were converting either wasn't a dictionary or
# didn't have the keys we were expecting.
# For example request object has no 'string' key.
resp['body'] = _convert_string_to_unicode(body)
resp["body"] = _convert_string_to_unicode(body)
return resp

View File

@@ -23,7 +23,7 @@ def serialize(cassette_dict):
b"Error serializing cassette to JSON",
original.start,
original.end,
original.args[-1] + error_message
original.args[-1] + error_message,
)
except TypeError: # py3
raise TypeError(error_message)

View File

@@ -1,12 +1,8 @@
'''Stubs for patching HTTP and HTTPS requests'''
"""Stubs for patching HTTP and HTTPS requests"""
import logging
import six
from six.moves.http_client import (
HTTPConnection,
HTTPSConnection,
HTTPResponse,
)
from six.moves.http_client import HTTPConnection, HTTPSConnection, HTTPResponse
from six import BytesIO
from vcr.request import Request
from vcr.errors import CannotOverwriteExistingCassetteException
@@ -45,8 +41,7 @@ def parse_headers(header_list):
header_string = b""
for key, values in header_list.items():
for v in values:
header_string += \
key.encode('utf-8') + b":" + v.encode('utf-8') + b"\r\n"
header_string += key.encode("utf-8") + b":" + v.encode("utf-8") + b"\r\n"
return compat.get_httpmessage(header_string)
@@ -62,27 +57,28 @@ class VCRHTTPResponse(HTTPResponse):
"""
Stub response class that gets returned instead of a HTTPResponse
"""
def __init__(self, recorded_response):
self.fp = None
self.recorded_response = recorded_response
self.reason = recorded_response['status']['message']
self.status = self.code = recorded_response['status']['code']
self.reason = recorded_response["status"]["message"]
self.status = self.code = recorded_response["status"]["code"]
self.version = None
self._content = BytesIO(self.recorded_response['body']['string'])
self._content = BytesIO(self.recorded_response["body"]["string"])
self._closed = False
headers = self.recorded_response['headers']
headers = self.recorded_response["headers"]
# Since we are loading a response that has already been serialized, our
# response is no longer chunked. That means we don't want any
# libraries trying to process a chunked response. By removing the
# transfer-encoding: chunked header, this should cause the downstream
# libraries to process this as a non-chunked response.
te_key = [h for h in headers.keys() if h.upper() == 'TRANSFER-ENCODING']
te_key = [h for h in headers.keys() if h.upper() == "TRANSFER-ENCODING"]
if te_key:
del headers[te_key[0]]
self.headers = self.msg = parse_headers(headers)
self.length = compat.get_header(self.msg, 'content-length') or None
self.length = compat.get_header(self.msg, "content-length") or None
@property
def closed(self):
@@ -129,17 +125,17 @@ class VCRHTTPResponse(HTTPResponse):
return self.closed
def info(self):
return parse_headers(self.recorded_response['headers'])
return parse_headers(self.recorded_response["headers"])
def getheaders(self):
message = parse_headers(self.recorded_response['headers'])
message = parse_headers(self.recorded_response["headers"])
return list(compat.get_header_items(message))
def getheader(self, header, default=None):
values = [v for (k, v) in self.getheaders() if k.lower() == header.lower()]
if values:
return ', '.join(values)
return ", ".join(values)
else:
return default
@@ -156,41 +152,27 @@ class VCRConnection(object):
Returns empty string for the default port and ':port' otherwise
"""
port = self.real_connection.port
default_port = {'https': 443, 'http': 80}[self._protocol]
return ':{}'.format(port) if port != default_port else ''
default_port = {"https": 443, "http": 80}[self._protocol]
return ":{}".format(port) if port != default_port else ""
def _uri(self, url):
"""Returns request absolute URI"""
if url and not url.startswith('/'):
if url and not url.startswith("/"):
# Then this must be a proxy request.
return url
uri = "{}://{}{}{}".format(
self._protocol,
self.real_connection.host,
self._port_postfix(),
url,
)
uri = "{}://{}{}{}".format(self._protocol, self.real_connection.host, self._port_postfix(), url)
log.debug("Absolute URI: %s", uri)
return uri
def _url(self, uri):
"""Returns request selector url from absolute URI"""
prefix = "{}://{}{}".format(
self._protocol,
self.real_connection.host,
self._port_postfix(),
)
return uri.replace(prefix, '', 1)
prefix = "{}://{}{}".format(self._protocol, self.real_connection.host, self._port_postfix())
return uri.replace(prefix, "", 1)
def request(self, method, url, body=None, headers=None, *args, **kwargs):
'''Persist the request metadata in self._vcr_request'''
self._vcr_request = Request(
method=method,
uri=self._uri(url),
body=body,
headers=headers or {}
)
log.debug('Got {}'.format(self._vcr_request))
"""Persist the request metadata in self._vcr_request"""
self._vcr_request = Request(method=method, uri=self._uri(url), body=body, headers=headers or {})
log.debug("Got {}".format(self._vcr_request))
# Note: The request may not actually be finished at this point, so
# I'm not sending the actual request until getresponse(). This
@@ -205,25 +187,19 @@ class VCRConnection(object):
to start building up a request. Usually followed by a bunch
of putheader() calls.
"""
self._vcr_request = Request(
method=method,
uri=self._uri(url),
body="",
headers={}
)
log.debug('Got {}'.format(self._vcr_request))
self._vcr_request = Request(method=method, uri=self._uri(url), body="", headers={})
log.debug("Got {}".format(self._vcr_request))
def putheader(self, header, *values):
self._vcr_request.headers[header] = values
def send(self, data):
'''
"""
This method is called after request(), to add additional data to the
body of the request. So if that happens, let's just append the data
onto the most recent request in the cassette.
'''
self._vcr_request.body = self._vcr_request.body + data \
if self._vcr_request.body else data
"""
self._vcr_request.body = self._vcr_request.body + data if self._vcr_request.body else data
def close(self):
# Note: the real connection will only close if it's open, so
@@ -240,37 +216,27 @@ class VCRConnection(object):
self._vcr_request.body = message_body
def getresponse(self, _=False, **kwargs):
'''Retrieve the response'''
"""Retrieve the response"""
# Check to see if the cassette has a response for this request. If so,
# then return it
if self.cassette.can_play_response_for(self._vcr_request):
log.info(
"Playing response for {} from cassette".format(
self._vcr_request
)
)
log.info("Playing response for {} from cassette".format(self._vcr_request))
response = self.cassette.play_response(self._vcr_request)
return VCRHTTPResponse(response)
else:
if self.cassette.write_protected and self.cassette.filter_request(
self._vcr_request
):
if self.cassette.write_protected and self.cassette.filter_request(self._vcr_request):
raise CannotOverwriteExistingCassetteException(
cassette=self.cassette,
failed_request=self._vcr_request
cassette=self.cassette, failed_request=self._vcr_request
)
# Otherwise, we should send the request, then get the response
# and return it.
log.info(
"{} not in cassette, sending to real server".format(
self._vcr_request
)
)
log.info("{} not in cassette, sending to real server".format(self._vcr_request))
# This is imported here to avoid circular import.
# TODO(@IvanMalison): Refactor to allow normal import.
from vcr.patch import force_reset
with force_reset():
self.real_connection.request(
method=self._vcr_request.method,
@@ -284,12 +250,9 @@ class VCRConnection(object):
# put the response into the cassette
response = {
'status': {
'code': response.status,
'message': response.reason
},
'headers': serialize_headers(response),
'body': {'string': response.read()},
"status": {"code": response.status, "message": response.reason},
"headers": serialize_headers(response),
"body": {"string": response.read()},
}
self.cassette.append(self._vcr_request, response)
return VCRHTTPResponse(response)
@@ -305,8 +268,7 @@ class VCRConnection(object):
and are not write-protected.
"""
if hasattr(self, '_vcr_request') and \
self.cassette.can_play_response_for(self._vcr_request):
if hasattr(self, "_vcr_request") and self.cassette.can_play_response_for(self._vcr_request):
# We already have a response we are going to play, don't
# actually connect
return
@@ -316,6 +278,7 @@ class VCRConnection(object):
return
from vcr.patch import force_reset
with force_reset():
return self.real_connection.connect(*args, **kwargs)
@@ -334,12 +297,13 @@ class VCRConnection(object):
def __init__(self, *args, **kwargs):
if six.PY3:
kwargs.pop('strict', None) # apparently this is gone in py3
kwargs.pop("strict", None) # apparently this is gone in py3
# need to temporarily reset here because the real connection
# inherits from the thing that we are mocking out. Take out
# the reset if you want to see what I mean :)
from vcr.patch import force_reset
with force_reset():
self.real_connection = self._baseclass(*args, **kwargs)
@@ -371,7 +335,7 @@ class VCRConnection(object):
Send requests for weird attributes up to the real connection
(counterpart to __setattr above)
"""
if self.__dict__.get('real_connection'):
if self.__dict__.get("real_connection"):
# check in case real_connection has not been set yet, such as when
# we're setting the real_connection itself for the first time
return getattr(self.real_connection, name)
@@ -385,13 +349,15 @@ for k, v in HTTPConnection.__dict__.items():
class VCRHTTPConnection(VCRConnection):
'''A Mocked class for HTTP requests'''
"""A Mocked class for HTTP requests"""
_baseclass = HTTPConnection
_protocol = 'http'
_protocol = "http"
class VCRHTTPSConnection(VCRConnection):
'''A Mocked class for HTTPS requests'''
"""A Mocked class for HTTPS requests"""
_baseclass = HTTPSConnection
_protocol = 'https'
_protocol = "https"
is_verified = True

View File

@@ -1,4 +1,4 @@
'''Stubs for aiohttp HTTP clients'''
"""Stubs for aiohttp HTTP clients"""
from __future__ import absolute_import
import asyncio
@@ -33,14 +33,14 @@ class MockClientResponse(ClientResponse):
session=None,
)
async def json(self, *, encoding='utf-8', loads=json.loads, **kwargs): # NOQA: E999
async def json(self, *, encoding="utf-8", loads=json.loads, **kwargs): # NOQA: E999
stripped = self._body.strip()
if not stripped:
return None
return loads(stripped.decode(encoding))
async def text(self, encoding='utf-8', errors='strict'):
async def text(self, encoding="utf-8", errors="strict"):
return self._body.decode(encoding, errors=errors)
async def read(self):
@@ -58,11 +58,11 @@ class MockClientResponse(ClientResponse):
def build_response(vcr_request, vcr_response, history):
response = MockClientResponse(vcr_request.method, URL(vcr_response.get('url')))
response.status = vcr_response['status']['code']
response._body = vcr_response['body'].get('string', b'')
response.reason = vcr_response['status']['message']
response._headers = CIMultiDictProxy(CIMultiDict(vcr_response['headers']))
response = MockClientResponse(vcr_request.method, URL(vcr_response.get("url")))
response.status = vcr_response["status"]["code"]
response._body = vcr_response["body"].get("string", b"")
response.reason = vcr_response["status"]["message"]
response._headers = CIMultiDictProxy(CIMultiDict(vcr_response["headers"]))
response._history = tuple(history)
response.close()
@@ -83,17 +83,14 @@ def play_responses(cassette, vcr_request):
async def record_response(cassette, vcr_request, response, past=False):
body = {} if past else {'string': (await response.read())}
body = {} if past else {"string": (await response.read())}
headers = {str(key): value for key, value in response.headers.items()}
vcr_response = {
'status': {
'code': response.status,
'message': response.reason,
},
'headers': headers,
'body': body, # NOQA: E999
'url': str(response.url),
"status": {"code": response.status, "message": response.reason},
"headers": headers,
"body": body, # NOQA: E999
"url": str(response.url),
}
cassette.append(vcr_request, vcr_response)
@@ -108,14 +105,14 @@ async def record_responses(cassette, vcr_request, response):
def vcr_request(cassette, real_request):
@functools.wraps(real_request)
async def new_request(self, method, url, **kwargs):
headers = kwargs.get('headers')
auth = kwargs.get('auth')
headers = kwargs.get("headers")
auth = kwargs.get("auth")
headers = self._prepare_headers(headers)
data = kwargs.get('data', kwargs.get('json'))
params = kwargs.get('params')
data = kwargs.get("data", kwargs.get("json"))
params = kwargs.get("params")
if auth is not None:
headers['AUTHORIZATION'] = auth.encode()
headers["AUTHORIZATION"] = auth.encode()
request_url = URL(url)
if params:
@@ -131,14 +128,16 @@ def vcr_request(cassette, real_request):
if cassette.write_protected and cassette.filter_request(vcr_request):
response = MockClientResponse(method, URL(url))
response.status = 599
msg = ("No match for the request {!r} was found. Can't overwrite "
"existing cassette {!r} in your current record mode {!r}.")
msg = (
"No match for the request {!r} was found. Can't overwrite "
"existing cassette {!r} in your current record mode {!r}."
)
msg = msg.format(vcr_request, cassette._path, cassette.record_mode)
response._body = msg.encode()
response.close()
return response
log.info('%s not in cassette, sending to real server', vcr_request)
log.info("%s not in cassette, sending to real server", vcr_request)
response = await real_request(self, method, url, **kwargs) # NOQA: E999
await record_responses(cassette, vcr_request, response)

View File

@@ -27,17 +27,18 @@ class VCRRequestsHTTPSConnection(VCRHTTPSConnection, VerifiedHTTPSConnection):
def __init__(self, *args, **kwargs):
if six.PY3:
kwargs.pop('strict', None) # apparently this is gone in py3
kwargs.pop("strict", None) # apparently this is gone in py3
# need to temporarily reset here because the real connection
# inherits from the thing that we are mocking out. Take out
# the reset if you want to see what I mean :)
from vcr.patch import force_reset
with force_reset():
self.real_connection = self._baseclass(*args, **kwargs)
# Make sure to set those attributes as it seems `AWSHTTPConnection` does not
# set them, making the connection to fail !
self.real_connection.assert_hostname = kwargs.get("assert_hostname", False)
self.real_connection.cert_reqs = kwargs.get("cert_reqs", 'CERT_NONE')
self.real_connection.cert_reqs = kwargs.get("cert_reqs", "CERT_NONE")
self._sock = None

View File

@@ -1,4 +1,4 @@
'''Stubs for boto'''
"""Stubs for boto"""
from boto.https_connection import CertValidatingHTTPSConnection
from ..stubs import VCRHTTPSConnection

View File

@@ -1,6 +1,7 @@
import six
from six import BytesIO
from six.moves.http_client import HTTPMessage
try:
import http.client
except ImportError:

View File

@@ -1,62 +1,60 @@
'''Stubs for httplib2'''
"""Stubs for httplib2"""
from httplib2 import HTTPConnectionWithTimeout, HTTPSConnectionWithTimeout
from ..stubs import VCRHTTPConnection, VCRHTTPSConnection
class VCRHTTPConnectionWithTimeout(VCRHTTPConnection,
HTTPConnectionWithTimeout):
class VCRHTTPConnectionWithTimeout(VCRHTTPConnection, HTTPConnectionWithTimeout):
_baseclass = HTTPConnectionWithTimeout
def __init__(self, *args, **kwargs):
'''I overrode the init because I need to clean kwargs before calling
HTTPConnection.__init__.'''
"""I overrode the init because I need to clean kwargs before calling
HTTPConnection.__init__."""
# Delete the keyword arguments that HTTPConnection would not recognize
safe_keys = {'host', 'port', 'strict', 'timeout', 'source_address'}
safe_keys = {"host", "port", "strict", "timeout", "source_address"}
unknown_keys = set(kwargs.keys()) - safe_keys
safe_kwargs = kwargs.copy()
for kw in unknown_keys:
del safe_kwargs[kw]
self.proxy_info = kwargs.pop('proxy_info', None)
self.proxy_info = kwargs.pop("proxy_info", None)
VCRHTTPConnection.__init__(self, *args, **safe_kwargs)
self.sock = self.real_connection.sock
class VCRHTTPSConnectionWithTimeout(VCRHTTPSConnection,
HTTPSConnectionWithTimeout):
class VCRHTTPSConnectionWithTimeout(VCRHTTPSConnection, HTTPSConnectionWithTimeout):
_baseclass = HTTPSConnectionWithTimeout
def __init__(self, *args, **kwargs):
# Delete the keyword arguments that HTTPSConnection would not recognize
safe_keys = {
'host',
'port',
'key_file',
'cert_file',
'strict',
'timeout',
'source_address',
'ca_certs',
'disable_ssl_certificate_validation',
"host",
"port",
"key_file",
"cert_file",
"strict",
"timeout",
"source_address",
"ca_certs",
"disable_ssl_certificate_validation",
}
unknown_keys = set(kwargs.keys()) - safe_keys
safe_kwargs = kwargs.copy()
for kw in unknown_keys:
del safe_kwargs[kw]
self.proxy_info = kwargs.pop('proxy_info', None)
if 'ca_certs' not in kwargs or kwargs['ca_certs'] is None:
self.proxy_info = kwargs.pop("proxy_info", None)
if "ca_certs" not in kwargs or kwargs["ca_certs"] is None:
try:
import httplib2
self.ca_certs = httplib2.CA_CERTS
except ImportError:
self.ca_certs = None
else:
self.ca_certs = kwargs['ca_certs']
self.ca_certs = kwargs["ca_certs"]
self.disable_ssl_certificate_validation = kwargs.pop(
'disable_ssl_certificate_validation', None)
self.disable_ssl_certificate_validation = kwargs.pop("disable_ssl_certificate_validation", None)
VCRHTTPSConnection.__init__(self, *args, **safe_kwargs)
self.sock = self.real_connection.sock

View File

@@ -1,4 +1,4 @@
'''Stubs for requests'''
"""Stubs for requests"""
try:
from urllib3.connectionpool import HTTPConnection, VerifiedHTTPSConnection

View File

@@ -1,4 +1,4 @@
'''Stubs for tornado HTTP clients'''
"""Stubs for tornado HTTP clients"""
from __future__ import absolute_import
import functools
@@ -12,20 +12,19 @@ from vcr.request import Request
def vcr_fetch_impl(cassette, real_fetch_impl):
@functools.wraps(real_fetch_impl)
def new_fetch_impl(self, request, callback):
headers = request.headers.copy()
if request.user_agent:
headers.setdefault('User-Agent', request.user_agent)
headers.setdefault("User-Agent", request.user_agent)
# TODO body_producer, header_callback, and streaming_callback are not
# yet supported.
unsupported_call = (
getattr(request, 'body_producer', None) is not None or
request.header_callback is not None or
request.streaming_callback is not None
getattr(request, "body_producer", None) is not None
or request.header_callback is not None
or request.streaming_callback is not None
)
if unsupported_call:
response = HTTPResponse(
@@ -40,18 +39,13 @@ def vcr_fetch_impl(cassette, real_fetch_impl):
)
return callback(response)
vcr_request = Request(
request.method,
request.url,
request.body,
headers,
)
vcr_request = Request(request.method, request.url, request.body, headers)
if cassette.can_play_response_for(vcr_request):
vcr_response = cassette.play_response(vcr_request)
headers = httputil.HTTPHeaders()
recorded_headers = vcr_response['headers']
recorded_headers = vcr_response["headers"]
if isinstance(recorded_headers, dict):
recorded_headers = recorded_headers.items()
for k, vs in recorded_headers:
@@ -59,43 +53,34 @@ def vcr_fetch_impl(cassette, real_fetch_impl):
headers.add(k, v)
response = HTTPResponse(
request,
code=vcr_response['status']['code'],
reason=vcr_response['status']['message'],
code=vcr_response["status"]["code"],
reason=vcr_response["status"]["message"],
headers=headers,
buffer=BytesIO(vcr_response['body']['string']),
effective_url=vcr_response.get('url'),
buffer=BytesIO(vcr_response["body"]["string"]),
effective_url=vcr_response.get("url"),
request_time=self.io_loop.time() - request.start_time,
)
return callback(response)
else:
if cassette.write_protected and cassette.filter_request(
vcr_request
):
if cassette.write_protected and cassette.filter_request(vcr_request):
response = HTTPResponse(
request,
599,
error=CannotOverwriteExistingCassetteException(
cassette=cassette,
failed_request=vcr_request
cassette=cassette, failed_request=vcr_request
),
request_time=self.io_loop.time() - request.start_time,
)
return callback(response)
def new_callback(response):
headers = [
(k, response.headers.get_list(k))
for k in response.headers.keys()
]
headers = [(k, response.headers.get_list(k)) for k in response.headers.keys()]
vcr_response = {
'status': {
'code': response.code,
'message': response.reason,
},
'headers': headers,
'body': {'string': response.body},
'url': response.effective_url,
"status": {"code": response.code, "message": response.reason},
"headers": headers,
"body": {"string": response.body},
"url": response.effective_url,
}
cassette.append(vcr_request, vcr_response)
return callback(response)

View File

@@ -1,4 +1,4 @@
'''Stubs for urllib3'''
"""Stubs for urllib3"""
from urllib3.connectionpool import HTTPConnection, VerifiedHTTPSConnection
from ..stubs import VCRHTTPConnection, VCRHTTPSConnection

View File

@@ -29,6 +29,7 @@ class CaseInsensitiveDict(MutableMapping):
operations are given keys that have equal ``.lower()``s, the
behavior is undefined.
"""
def __init__(self, data=None, **kwargs):
self._store = dict()
if data is None:
@@ -54,11 +55,7 @@ class CaseInsensitiveDict(MutableMapping):
def lower_items(self):
"""Like iteritems(), but with all lowercase keys."""
return (
(lowerkey, keyval[1])
for (lowerkey, keyval)
in self._store.items()
)
return ((lowerkey, keyval[1]) for (lowerkey, keyval) in self._store.items())
def __eq__(self, other):
if isinstance(other, Mapping):
@@ -92,37 +89,30 @@ def compose(*functions):
if function:
res = function(res)
return res
return composed
def read_body(request):
if hasattr(request.body, 'read'):
if hasattr(request.body, "read"):
return request.body.read()
return request.body
def auto_decorate(
decorator,
predicate=lambda name, value: isinstance(value, types.FunctionType)
):
def auto_decorate(decorator, predicate=lambda name, value: isinstance(value, types.FunctionType)):
def maybe_decorate(attribute, value):
if predicate(attribute, value):
value = decorator(value)
return value
class DecorateAll(type):
def __setattr__(cls, attribute, value):
return super(DecorateAll, cls).__setattr__(
attribute, maybe_decorate(attribute, value)
)
return super(DecorateAll, cls).__setattr__(attribute, maybe_decorate(attribute, value))
def __new__(cls, name, bases, attributes_dict):
new_attributes_dict = {
attribute: maybe_decorate(attribute, value)
for attribute, value in attributes_dict.items()
attribute: maybe_decorate(attribute, value) for attribute, value in attributes_dict.items()
}
return super(DecorateAll, cls).__new__(
cls, name, bases, new_attributes_dict
)
return super(DecorateAll, cls).__new__(cls, name, bases, new_attributes_dict)
return DecorateAll