1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-09 01:03:24 +00:00

Merge branch 'master' into unicode-match-on-body

This commit is contained in:
Martin Valgur
2019-07-01 22:09:57 +03:00
committed by GitHub
36 changed files with 804 additions and 201 deletions

View File

@@ -1,7 +1,3 @@
import asyncio
@asyncio.coroutine
def handle_coroutine(vcr, fn):
async def handle_coroutine(vcr, fn): # noqa: E999
with vcr as cassette:
return (yield from fn(cassette)) # noqa: E999
return (await fn(cassette)) # noqa: E999

View File

@@ -8,7 +8,7 @@ import wrapt
from .compat import contextlib
from .errors import UnhandledHTTPRequestError
from .matchers import requests_match, uri, method
from .matchers import requests_match, uri, method, get_matchers_results
from .patch import CassettePatcherBuilder
from .serializers import yamlserializer
from .persisters.filesystem import FilesystemPersister
@@ -16,11 +16,13 @@ from .util import partition_dict
try:
from asyncio import iscoroutinefunction
from ._handle_coroutine import handle_coroutine
except ImportError:
def iscoroutinefunction(*args, **kwargs):
return False
if sys.version_info[:2] >= (3, 5):
from ._handle_coroutine import handle_coroutine
else:
def handle_coroutine(*args, **kwags):
raise NotImplementedError('Not implemented on Python 2')
@@ -136,7 +138,10 @@ class CassetteContextDecorator(object):
except Exception:
to_yield = coroutine.throw(*sys.exc_info())
else:
to_yield = coroutine.send(to_send)
try:
to_yield = coroutine.send(to_send)
except StopIteration:
break
def _handle_function(self, fn):
with self as cassette:
@@ -282,6 +287,41 @@ class Cassette(object):
% (self._path, request)
)
def rewind(self):
self.play_counts = collections.Counter()
def find_requests_with_most_matches(self, request):
"""
Get the most similar request(s) stored in the cassette
of a given request as a list of tuples like this:
- the request object
- the successful matchers as string
- the failed matchers and the related assertion message with the difference details as strings tuple
This is useful when a request failed to be found,
we can get the similar request(s) in order to know what have changed in the request parts.
"""
best_matches = []
request = self._before_record_request(request)
for index, (stored_request, response) in enumerate(self.data):
successes, fails = get_matchers_results(request, stored_request, self._match_on)
best_matches.append((len(successes), stored_request, successes, fails))
best_matches.sort(key=lambda t: t[0], reverse=True)
# Get the first best matches (multiple if equal matches)
final_best_matches = []
previous_nb_success = best_matches[0][0]
for best_match in best_matches:
nb_success = best_match[0]
# Do not keep matches that have 0 successes,
# it means that the request is totally different from
# the ones stored in the cassette
if nb_success < 1 or previous_nb_success != nb_success:
break
previous_nb_success = nb_success
final_best_matches.append(best_match[1:])
return final_best_matches
def _as_dict(self):
return {"requests": self.requests, "responses": self.responses}

View File

@@ -1,5 +1,8 @@
import copy
import collections
try:
from collections import abc as collections_abc # only works on python 3.3+
except ImportError:
import collections as collections_abc
import functools
import inspect
import os
@@ -175,7 +178,7 @@ class VCR(object):
if decode_compressed_response:
filter_functions.append(filters.decode_response)
if before_record_response:
if not isinstance(before_record_response, collections.Iterable):
if not isinstance(before_record_response, collections_abc.Iterable):
before_record_response = (before_record_response,)
filter_functions.extend(before_record_response)
@@ -241,7 +244,7 @@ class VCR(object):
filter_functions.append(self._build_ignore_hosts(hosts_to_ignore))
if before_record_request:
if not isinstance(before_record_request, collections.Iterable):
if not isinstance(before_record_request, collections_abc.Iterable):
before_record_request = (before_record_request,)
filter_functions.extend(before_record_request)

View File

@@ -1,5 +1,32 @@
class CannotOverwriteExistingCassetteException(Exception):
pass
def __init__(self, *args, **kwargs):
self.cassette = kwargs["cassette"]
self.failed_request = kwargs["failed_request"]
message = self._get_message(kwargs["cassette"], kwargs["failed_request"])
super(CannotOverwriteExistingCassetteException, self).__init__(message)
def _get_message(self, cassette, failed_request):
"""Get the final message related to the exception"""
# Get the similar requests in the cassette that
# have match the most with the request.
best_matches = cassette.find_requests_with_most_matches(failed_request)
# Build a comprehensible message to put in the exception.
best_matches_msg = ""
for best_match in best_matches:
request, _, failed_matchers_assertion_msgs = best_match
best_matches_msg += "Similar request found : (%r).\n" % request
for failed_matcher, assertion_msg in failed_matchers_assertion_msgs:
best_matches_msg += "Matcher failed : %s\n" "%s\n" % (
failed_matcher,
assertion_msg,
)
return (
"Can't overwrite existing cassette (%r) in "
"your current record mode (%r).\n"
"No match for the request (%r) was found.\n"
"%s"
% (cassette._path, cassette.record_mode, failed_request, best_matches_msg)
)
class UnhandledHTTPRequestError(KeyError):

View File

@@ -8,35 +8,47 @@ log = logging.getLogger(__name__)
def method(r1, r2):
return r1.method == r2.method
assert r1.method == r2.method, "{} != {}".format(r1.method, r2.method)
def uri(r1, r2):
return r1.uri == r2.uri
assert r1.uri == r2.uri, "{} != {}".format(r1.uri, r2.uri)
def host(r1, r2):
return r1.host == r2.host
assert r1.host == r2.host, "{} != {}".format(r1.host, r2.host)
def scheme(r1, r2):
return r1.scheme == r2.scheme
assert r1.scheme == r2.scheme, "{} != {}".format(r1.scheme, r2.scheme)
def port(r1, r2):
return r1.port == r2.port
assert r1.port == r2.port, "{} != {}".format(r1.port, r2.port)
def path(r1, r2):
return r1.path == r2.path
assert r1.path == r2.path, "{} != {}".format(r1.path, r2.path)
def query(r1, r2):
return r1.query == r2.query
assert r1.query == r2.query, "{} != {}".format(r1.query, r2.query)
def raw_body(r1, r2):
return read_body(r1) == read_body(r2)
assert read_body(r1) == read_body(r2)
def body(r1, r2):
transformer = _get_transformer(r1)
r2_transformer = _get_transformer(r2)
if transformer != r2_transformer:
transformer = _identity
assert transformer(read_body(r1)) == transformer(read_body(r2))
def headers(r1, r2):
assert r1.headers == r2.headers, "{} != {}".format(r1.headers, r2.headers)
def _header_checker(value, header='Content-Type'):
@@ -77,28 +89,67 @@ def _get_transformer(request):
return _identity
def body(r1, r2):
transformer = _get_transformer(r1)
r2_transformer = _get_transformer(r2)
if transformer != r2_transformer:
transformer = _identity
return transformer(read_body(r1)) == transformer(read_body(r2))
def headers(r1, r2):
return r1.headers == r2.headers
def _log_matches(r1, r2, matches):
differences = [m for m in matches if not m[0]]
if differences:
log.debug(
"Requests {} and {} differ according to "
"the following matchers: {}".format(r1, r2, differences)
)
def requests_match(r1, r2, matchers):
matches = [(m(r1, r2), m) for m in matchers]
_log_matches(r1, r2, matches)
return all(m[0] for m in matches)
successes, failures = get_matchers_results(r1, r2, matchers)
if failures:
log.debug(
"Requests {} and {} differ.\n"
"Failure details:\n"
"{}".format(r1, r2, failures)
)
return len(failures) == 0
def _evaluate_matcher(matcher_function, *args):
"""
Evaluate the result of a given matcher as a boolean with an assertion error message if any.
It handles two types of matcher :
- a matcher returning a boolean value.
- a matcher that only makes an assert, returning None or raises an assertion error.
"""
assertion_message = None
try:
match = matcher_function(*args)
match = True if match is None else match
except AssertionError as e:
match = False
assertion_message = str(e)
return match, assertion_message
def get_matchers_results(r1, r2, matchers):
"""
Get the comparison results of two requests as two list.
The first returned list represents the matchers names that passed.
The second list is the failed matchers as a string with failed assertion details if any.
"""
matches_success, matches_fails = [], []
for m in matchers:
matcher_name = m.__name__
match, assertion_message = _evaluate_matcher(m, r1, r2)
if match:
matches_success.append(matcher_name)
else:
assertion_message = get_assertion_message(assertion_message)
matches_fails.append((matcher_name, assertion_message))
return matches_success, matches_fails
def get_assertion_message(assertion_details, **format_options):
"""
Get a detailed message about the failing matcher.
"""
msg = ""
if assertion_details:
separator = format_options.get("separator", "-")
title = format_options.get("title", " DETAILS ")
nb_separator = format_options.get("nb_separator", 40)
first_title_line = (
separator * ((nb_separator - len(title)) // 2)
+ title
+ separator * ((nb_separator - len(title)) // 2)
)
msg += "{}\n{}\n{}\n".format(
first_title_line, str(assertion_details), separator * nb_separator
)
return msg

View File

@@ -68,7 +68,7 @@ def _migrate(data):
for item in data:
req = item['request']
res = item['response']
uri = dict((k, req.pop(k)) for k in PARTS)
uri = {k: req.pop(k) for k in PARTS}
req['uri'] = build_uri(**uri)
# convert headers to dict of lists
headers = req['headers']
@@ -100,7 +100,7 @@ def migrate_json(in_fp, out_fp):
def _list_of_tuples_to_dict(fs):
return dict((k, v) for k, v in fs[0])
return {k: v for k, v in fs[0]}
def _already_migrated(data):
@@ -159,9 +159,9 @@ def main():
for (root, dirs, files) in os.walk(path)
for name in files)
for file_path in files:
migrated = try_migrate(file_path)
status = 'OK' if migrated else 'FAIL'
sys.stderr.write("[{}] {}\n".format(status, file_path))
migrated = try_migrate(file_path)
status = 'OK' if migrated else 'FAIL'
sys.stderr.write("[{}] {}\n".format(status, file_path))
sys.stderr.write("Done.\n")

View File

@@ -58,7 +58,10 @@ class Request(object):
parse_uri = urlparse(self.uri)
port = parse_uri.port
if port is None:
port = {'https': 443, 'http': 80}[parse_uri.scheme]
try:
port = {'https': 443, 'http': 80}[parse_uri.scheme]
except KeyError:
pass
return port
@property
@@ -91,7 +94,7 @@ class Request(object):
'method': self.method,
'uri': self.uri,
'body': self.body,
'headers': dict(((k, [v]) for k, v in self.headers.items())),
'headers': {k: [v] for k, v in self.headers.items()},
}
@classmethod
@@ -112,7 +115,7 @@ class HeadersDict(CaseInsensitiveDict):
In addition, some servers sometimes send the same header more than once,
and httplib *can* deal with this situation.
Futhermore, I wanted to keep the request and response cassette format as
Furthermore, I wanted to keep the request and response cassette format as
similar as possible.
For this reason, in cassettes I keep a dict with lists as keys, but once

View File

@@ -25,5 +25,5 @@ def serialize(cassette_dict):
original.end,
original.args[-1] + error_message
)
except TypeError as original: # py3
except TypeError: # py3
raise TypeError(error_message)

View File

@@ -18,7 +18,7 @@ log = logging.getLogger(__name__)
class VCRFakeSocket(object):
"""
A socket that doesn't do anything!
Used when playing back casssettes, when there
Used when playing back cassettes, when there
is no actual open socket.
"""
@@ -60,9 +60,10 @@ def serialize_headers(response):
class VCRHTTPResponse(HTTPResponse):
"""
Stub reponse class that gets returned instead of a HTTPResponse
Stub response class that gets returned instead of a HTTPResponse
"""
def __init__(self, recorded_response):
self.fp = None
self.recorded_response = recorded_response
self.reason = recorded_response['status']['message']
self.status = self.code = recorded_response['status']['code']
@@ -93,9 +94,30 @@ class VCRHTTPResponse(HTTPResponse):
def read(self, *args, **kwargs):
return self._content.read(*args, **kwargs)
def readall(self):
return self._content.readall()
def readinto(self, *args, **kwargs):
return self._content.readinto(*args, **kwargs)
def readline(self, *args, **kwargs):
return self._content.readline(*args, **kwargs)
def readlines(self, *args, **kwargs):
return self._content.readlines(*args, **kwargs)
def seekable(self):
return self._content.seekable()
def tell(self):
return self._content.tell()
def isatty(self):
return self._content.isatty()
def seek(self, *args, **kwargs):
return self._content.seek(*args, **kwargs)
def close(self):
self._closed = True
return True
@@ -121,6 +143,9 @@ class VCRHTTPResponse(HTTPResponse):
else:
return default
def readable(self):
return self._content.readable()
class VCRConnection(object):
# A reference to the cassette that's currently being patched in
@@ -136,6 +161,9 @@ class VCRConnection(object):
def _uri(self, url):
"""Returns request absolute URI"""
if url and not url.startswith('/'):
# Then this must be a proxy request.
return url
uri = "{}://{}{}{}".format(
self._protocol,
self.real_connection.host,
@@ -168,6 +196,8 @@ class VCRConnection(object):
# allows me to compare the entire length of the response to see if it
# exists in the cassette.
self._sock = VCRFakeSocket()
def putrequest(self, method, url, *args, **kwargs):
"""
httplib gives you more than one way to do it. This is a way
@@ -225,11 +255,8 @@ class VCRConnection(object):
self._vcr_request
):
raise CannotOverwriteExistingCassetteException(
"No match for the request (%r) was found. "
"Can't overwrite existing cassette (%r) in "
"your current record mode (%r)."
% (self._vcr_request, self.cassette._path,
self.cassette.record_mode)
cassette=self.cassette,
failed_request=self._vcr_request
)
# Otherwise, we should send the request, then get the response
@@ -291,11 +318,13 @@ class VCRConnection(object):
with force_reset():
return self.real_connection.connect(*args, **kwargs)
self._sock = VCRFakeSocket()
@property
def sock(self):
if self.real_connection.sock:
return self.real_connection.sock
return VCRFakeSocket()
return self._sock
@sock.setter
def sock(self, value):
@@ -313,6 +342,8 @@ class VCRConnection(object):
with force_reset():
self.real_connection = self._baseclass(*args, **kwargs)
self._sock = None
def __setattr__(self, name, value):
"""
We need to define this because any attributes that are set on the

View File

@@ -5,12 +5,16 @@ import asyncio
import functools
import json
from aiohttp import ClientResponse
from aiohttp import ClientResponse, streams
from yarl import URL
from vcr.request import Request
class MockStream(asyncio.StreamReader, streams.AsyncStreamReaderMixin):
pass
class MockClientResponse(ClientResponse):
def __init__(self, method, url):
super().__init__(
@@ -25,28 +29,33 @@ class MockClientResponse(ClientResponse):
session=None,
)
# TODO: get encoding from header
@asyncio.coroutine
def json(self, *, encoding='utf-8', loads=json.loads, **kwargs): # NOQA: E999
return loads(self._body.decode(encoding))
async def json(self, *, encoding='utf-8', loads=json.loads, **kwargs): # NOQA: E999
stripped = self._body.strip()
if not stripped:
return None
@asyncio.coroutine
def text(self, encoding='utf-8'):
return self._body.decode(encoding)
return loads(stripped.decode(encoding))
@asyncio.coroutine
def read(self):
async def text(self, encoding='utf-8', errors='strict'):
return self._body.decode(encoding, errors=errors)
async def read(self):
return self._body
@asyncio.coroutine
def release(self):
pass
@property
def content(self):
s = MockStream()
s.feed_data(self._body)
s.feed_eof()
return s
def vcr_request(cassette, real_request):
@functools.wraps(real_request)
@asyncio.coroutine
def new_request(self, method, url, **kwargs):
async def new_request(self, method, url, **kwargs):
headers = kwargs.get('headers')
headers = self._prepare_headers(headers)
data = kwargs.get('data')
@@ -82,7 +91,7 @@ def vcr_request(cassette, real_request):
response.close()
return response
response = yield from real_request(self, method, url, **kwargs) # NOQA: E999
response = await real_request(self, method, url, **kwargs) # NOQA: E999
vcr_response = {
'status': {
@@ -90,7 +99,7 @@ def vcr_request(cassette, real_request):
'message': response.reason,
},
'headers': dict(response.headers),
'body': {'string': (yield from response.read())}, # NOQA: E999
'body': {'string': (await response.read())}, # NOQA: E999
'url': response.url,
}
cassette.append(vcr_request, vcr_response)

View File

@@ -40,6 +40,7 @@ class VCRHTTPSConnectionWithTimeout(VCRHTTPSConnection,
'timeout',
'source_address',
'ca_certs',
'disable_ssl_certificate_validation',
}
unknown_keys = set(kwargs.keys()) - safe_keys
safe_kwargs = kwargs.copy()

View File

@@ -75,10 +75,8 @@ def vcr_fetch_impl(cassette, real_fetch_impl):
request,
599,
error=CannotOverwriteExistingCassetteException(
"No match for the request (%r) was found. "
"Can't overwrite existing cassette (%r) in "
"your current record mode (%r)."
% (vcr_request, cassette._path, cassette.record_mode)
cassette=cassette,
failed_request=vcr_request
),
request_time=self.io_loop.time() - request.start_time,
)

View File

@@ -1,13 +1,17 @@
import collections
import types
try:
from collections.abc import Mapping, MutableMapping
except ImportError:
from collections import Mapping, MutableMapping
# Shamelessly stolen from https://github.com/kennethreitz/requests/blob/master/requests/structures.py
class CaseInsensitiveDict(collections.MutableMapping):
class CaseInsensitiveDict(MutableMapping):
"""
A case-insensitive ``dict``-like object.
Implements all methods and operations of
``collections.MutableMapping`` as well as dict's ``copy``. Also
``collections.abc.MutableMapping`` as well as dict's ``copy``. Also
provides ``lower_items``.
All keys are expected to be strings. The structure remembers the
case of the last key to be set, and ``iter(instance)``,
@@ -57,7 +61,7 @@ class CaseInsensitiveDict(collections.MutableMapping):
)
def __eq__(self, other):
if isinstance(other, collections.Mapping):
if isinstance(other, Mapping):
other = CaseInsensitiveDict(other)
else:
return NotImplemented
@@ -114,10 +118,10 @@ def auto_decorate(
)
def __new__(cls, name, bases, attributes_dict):
new_attributes_dict = dict(
(attribute, maybe_decorate(attribute, value))
new_attributes_dict = {
attribute: maybe_decorate(attribute, value)
for attribute, value in attributes_dict.items()
)
}
return super(DecorateAll, cls).__new__(
cls, name, bases, new_attributes_dict
)