diff --git a/docs/changelog.rst b/docs/changelog.rst index 0dbb350..f92b89f 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,5 +1,9 @@ Changelog --------- +- 3.0.0 (UNRELEASED) + - Fix multiple requests being replayed per single request in aiohttp stub (@nickdirienzo) + - Add support for `request_info` on mocked responses in aiohttp stub (@nickdirienzo) + - ... - 2.1.x (UNRELEASED) - .... - 2.1.1 diff --git a/tests/integration/test_aiohttp.py b/tests/integration/test_aiohttp.py index 8f170e9..6508c48 100644 --- a/tests/integration/test_aiohttp.py +++ b/tests/integration/test_aiohttp.py @@ -262,3 +262,43 @@ def test_redirect(aiohttp_client, tmpdir): assert len(cassette_response.history) == len(response.history) assert len(cassette) == 3 assert cassette.play_count == 3 + + # Assert that the real response and the cassette response have a similar + # looking request_info. + assert cassette_response.request_info.url == response.request_info.url + assert cassette_response.request_info.method == response.request_info.method + assert {k: v for k, v in cassette_response.request_info.headers.items()} == { + k: v for k, v in response.request_info.headers.items() + } + assert cassette_response.request_info.real_url == response.request_info.real_url + + +def test_double_requests(tmpdir): + """We should capture, record, and replay all requests and response chains, + even if there are duplicate ones. + + We should replay in the order we saw them. + """ + url = "https://httpbin.org/get" + + with vcr.use_cassette(str(tmpdir.join("text.yaml"))): + _, response_text1 = get(url, output="text") + _, response_text2 = get(url, output="text") + + with vcr.use_cassette(str(tmpdir.join("text.yaml"))) as cassette: + resp, cassette_response_text = get(url, output="text") + assert resp.status == 200 + assert cassette_response_text == response_text1 + + # We made only one request, so we should only play 1 recording. + assert cassette.play_count == 1 + + # Now make the second test to url + resp, cassette_response_text = get(url, output="text") + + assert resp.status == 200 + + assert cassette_response_text == response_text2 + + # Now that we made both requests, we should have played both. + assert cassette.play_count == 2 diff --git a/vcr/stubs/aiohttp_stubs/__init__.py b/vcr/stubs/aiohttp_stubs/__init__.py index 464de3e..2301334 100644 --- a/vcr/stubs/aiohttp_stubs/__init__.py +++ b/vcr/stubs/aiohttp_stubs/__init__.py @@ -6,7 +6,7 @@ import functools import logging import json -from aiohttp import ClientResponse, streams +from aiohttp import ClientConnectionError, ClientResponse, RequestInfo, streams from multidict import CIMultiDict, CIMultiDictProxy from yarl import URL @@ -20,14 +20,14 @@ class MockStream(asyncio.StreamReader, streams.AsyncStreamReaderMixin): class MockClientResponse(ClientResponse): - def __init__(self, method, url): + def __init__(self, method, url, request_info=None): super().__init__( method=method, url=url, writer=None, continue100=None, timer=None, - request_info=None, + request_info=request_info, traces=None, loop=asyncio.get_event_loop(), session=None, @@ -58,7 +58,13 @@ class MockClientResponse(ClientResponse): def build_response(vcr_request, vcr_response, history): - response = MockClientResponse(vcr_request.method, URL(vcr_response.get("url"))) + request_info = RequestInfo( + url=URL(vcr_request.url), + method=vcr_request.method, + headers=CIMultiDictProxy(CIMultiDict(vcr_request.headers)), + real_url=URL(vcr_request.url), + ) + response = MockClientResponse(vcr_request.method, URL(vcr_response.get("url")), request_info=request_info) response.status = vcr_response["status"]["code"] response._body = vcr_response["body"].get("string", b"") response.reason = vcr_response["status"]["message"] @@ -69,12 +75,36 @@ def build_response(vcr_request, vcr_response, history): return response +def _serialize_headers(headers): + """Serialize CIMultiDictProxy to a pickle-able dict because proxy + objects forbid pickling: + + https://github.com/aio-libs/multidict/issues/340 + """ + # Mark strings as keys so 'istr' types don't show up in + # the cassettes as comments. + return {str(k): v for k, v in headers.items()} + + def play_responses(cassette, vcr_request): history = [] vcr_response = cassette.play_response(vcr_request) response = build_response(vcr_request, vcr_response, history) - while cassette.can_play_response_for(vcr_request): + # If we're following redirects, continue playing until we reach + # our final destination. + while 300 <= response.status <= 399: + next_url = URL(response.url).with_path(response.headers["location"]) + + # Make a stub VCR request that we can then use to look up the recorded + # VCR request saved to the cassette. This feels a little hacky and + # may have edge cases based on the headers we're providing (e.g. if + # there's a matcher that is used to filter by headers). + vcr_request = Request("GET", str(next_url), None, _serialize_headers(response.request_info.headers)) + vcr_request = cassette.find_requests_with_most_matches(vcr_request)[0][0] + + # Tack on the response we saw from the redirect into the history + # list that is added on to the final response. history.append(response) vcr_response = cassette.play_response(vcr_request) response = build_response(vcr_request, vcr_response, history) @@ -82,22 +112,55 @@ def play_responses(cassette, vcr_request): return response -async def record_response(cassette, vcr_request, response, past=False): - body = {} if past else {"string": (await response.read())} - headers = {str(key): value for key, value in response.headers.items()} +async def record_response(cassette, vcr_request, response): + """Record a VCR request-response chain to the cassette.""" + + try: + body = {"string": (await response.read())} + # aiohttp raises a ClientConnectionError on reads when + # there is no body. We can use this to know to not write one. + except ClientConnectionError: + body = {} vcr_response = { "status": {"code": response.status, "message": response.reason}, - "headers": headers, + "headers": _serialize_headers(response.headers), "body": body, # NOQA: E999 "url": str(response.url), } + cassette.append(vcr_request, vcr_response) async def record_responses(cassette, vcr_request, response): + """Because aiohttp follows redirects by default, we must support + them by default. This method is used to write individual + request-response chains that were implicitly followed to get + to the final destination. + """ + for past_response in response.history: - await record_response(cassette, vcr_request, past_response, past=True) + aiohttp_request = past_response.request_info + + # No data because it's following a redirect. + past_request = Request( + aiohttp_request.method, + str(aiohttp_request.url), + None, + _serialize_headers(aiohttp_request.headers), + ) + await record_response(cassette, past_request, past_response) + + # If we're following redirects, then the last request-response + # we record is the one attached to the `response`. + if response.history: + aiohttp_request = response.request_info + vcr_request = Request( + aiohttp_request.method, + str(aiohttp_request.url), + None, + _serialize_headers(aiohttp_request.headers), + ) await record_response(cassette, vcr_request, response)