1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-08 16:53:23 +00:00

aiohttp: fix multiple requests being replayed per request and add support for request_info on mocked responses (#495)

* Fix how redirects are handled so we can have requests with the same URL be used distinctly

* Add support for request_info. Remove `past` kwarg.

* Remove as e to make linter happy

* Add unreleased 3.0.0 to changelog.
This commit is contained in:
Nick DiRienzo
2019-12-12 20:01:17 -08:00
committed by Josh Peak
parent ffd2142d86
commit 9ec19dd966
3 changed files with 117 additions and 10 deletions

View File

@@ -1,5 +1,9 @@
Changelog
---------
- 3.0.0 (UNRELEASED)
- Fix multiple requests being replayed per single request in aiohttp stub (@nickdirienzo)
- Add support for `request_info` on mocked responses in aiohttp stub (@nickdirienzo)
- ...
- 2.1.x (UNRELEASED)
- ....
- 2.1.1

View File

@@ -262,3 +262,43 @@ def test_redirect(aiohttp_client, tmpdir):
assert len(cassette_response.history) == len(response.history)
assert len(cassette) == 3
assert cassette.play_count == 3
# Assert that the real response and the cassette response have a similar
# looking request_info.
assert cassette_response.request_info.url == response.request_info.url
assert cassette_response.request_info.method == response.request_info.method
assert {k: v for k, v in cassette_response.request_info.headers.items()} == {
k: v for k, v in response.request_info.headers.items()
}
assert cassette_response.request_info.real_url == response.request_info.real_url
def test_double_requests(tmpdir):
"""We should capture, record, and replay all requests and response chains,
even if there are duplicate ones.
We should replay in the order we saw them.
"""
url = "https://httpbin.org/get"
with vcr.use_cassette(str(tmpdir.join("text.yaml"))):
_, response_text1 = get(url, output="text")
_, response_text2 = get(url, output="text")
with vcr.use_cassette(str(tmpdir.join("text.yaml"))) as cassette:
resp, cassette_response_text = get(url, output="text")
assert resp.status == 200
assert cassette_response_text == response_text1
# We made only one request, so we should only play 1 recording.
assert cassette.play_count == 1
# Now make the second test to url
resp, cassette_response_text = get(url, output="text")
assert resp.status == 200
assert cassette_response_text == response_text2
# Now that we made both requests, we should have played both.
assert cassette.play_count == 2

View File

@@ -6,7 +6,7 @@ import functools
import logging
import json
from aiohttp import ClientResponse, streams
from aiohttp import ClientConnectionError, ClientResponse, RequestInfo, streams
from multidict import CIMultiDict, CIMultiDictProxy
from yarl import URL
@@ -20,14 +20,14 @@ class MockStream(asyncio.StreamReader, streams.AsyncStreamReaderMixin):
class MockClientResponse(ClientResponse):
def __init__(self, method, url):
def __init__(self, method, url, request_info=None):
super().__init__(
method=method,
url=url,
writer=None,
continue100=None,
timer=None,
request_info=None,
request_info=request_info,
traces=None,
loop=asyncio.get_event_loop(),
session=None,
@@ -58,7 +58,13 @@ class MockClientResponse(ClientResponse):
def build_response(vcr_request, vcr_response, history):
response = MockClientResponse(vcr_request.method, URL(vcr_response.get("url")))
request_info = RequestInfo(
url=URL(vcr_request.url),
method=vcr_request.method,
headers=CIMultiDictProxy(CIMultiDict(vcr_request.headers)),
real_url=URL(vcr_request.url),
)
response = MockClientResponse(vcr_request.method, URL(vcr_response.get("url")), request_info=request_info)
response.status = vcr_response["status"]["code"]
response._body = vcr_response["body"].get("string", b"")
response.reason = vcr_response["status"]["message"]
@@ -69,12 +75,36 @@ def build_response(vcr_request, vcr_response, history):
return response
def _serialize_headers(headers):
"""Serialize CIMultiDictProxy to a pickle-able dict because proxy
objects forbid pickling:
https://github.com/aio-libs/multidict/issues/340
"""
# Mark strings as keys so 'istr' types don't show up in
# the cassettes as comments.
return {str(k): v for k, v in headers.items()}
def play_responses(cassette, vcr_request):
history = []
vcr_response = cassette.play_response(vcr_request)
response = build_response(vcr_request, vcr_response, history)
while cassette.can_play_response_for(vcr_request):
# If we're following redirects, continue playing until we reach
# our final destination.
while 300 <= response.status <= 399:
next_url = URL(response.url).with_path(response.headers["location"])
# Make a stub VCR request that we can then use to look up the recorded
# VCR request saved to the cassette. This feels a little hacky and
# may have edge cases based on the headers we're providing (e.g. if
# there's a matcher that is used to filter by headers).
vcr_request = Request("GET", str(next_url), None, _serialize_headers(response.request_info.headers))
vcr_request = cassette.find_requests_with_most_matches(vcr_request)[0][0]
# Tack on the response we saw from the redirect into the history
# list that is added on to the final response.
history.append(response)
vcr_response = cassette.play_response(vcr_request)
response = build_response(vcr_request, vcr_response, history)
@@ -82,22 +112,55 @@ def play_responses(cassette, vcr_request):
return response
async def record_response(cassette, vcr_request, response, past=False):
body = {} if past else {"string": (await response.read())}
headers = {str(key): value for key, value in response.headers.items()}
async def record_response(cassette, vcr_request, response):
"""Record a VCR request-response chain to the cassette."""
try:
body = {"string": (await response.read())}
# aiohttp raises a ClientConnectionError on reads when
# there is no body. We can use this to know to not write one.
except ClientConnectionError:
body = {}
vcr_response = {
"status": {"code": response.status, "message": response.reason},
"headers": headers,
"headers": _serialize_headers(response.headers),
"body": body, # NOQA: E999
"url": str(response.url),
}
cassette.append(vcr_request, vcr_response)
async def record_responses(cassette, vcr_request, response):
"""Because aiohttp follows redirects by default, we must support
them by default. This method is used to write individual
request-response chains that were implicitly followed to get
to the final destination.
"""
for past_response in response.history:
await record_response(cassette, vcr_request, past_response, past=True)
aiohttp_request = past_response.request_info
# No data because it's following a redirect.
past_request = Request(
aiohttp_request.method,
str(aiohttp_request.url),
None,
_serialize_headers(aiohttp_request.headers),
)
await record_response(cassette, past_request, past_response)
# If we're following redirects, then the last request-response
# we record is the one attached to the `response`.
if response.history:
aiohttp_request = response.request_info
vcr_request = Request(
aiohttp_request.method,
str(aiohttp_request.url),
None,
_serialize_headers(aiohttp_request.headers),
)
await record_response(cassette, vcr_request, response)