"""Stubs for aiohttp HTTP clients""" import asyncio import functools import logging import json from aiohttp import ClientConnectionError, ClientResponse, RequestInfo, streams from multidict import CIMultiDict, CIMultiDictProxy from yarl import URL from vcr.request import Request log = logging.getLogger(__name__) class MockStream(asyncio.StreamReader, streams.AsyncStreamReaderMixin): pass class MockClientResponse(ClientResponse): def __init__(self, method, url, request_info=None): super().__init__( method=method, url=url, writer=None, continue100=None, timer=None, request_info=request_info, traces=None, loop=asyncio.get_event_loop(), session=None, ) async def json(self, *, encoding="utf-8", loads=json.loads, **kwargs): # NOQA: E999 stripped = self._body.strip() if not stripped: return None return loads(stripped.decode(encoding)) async def text(self, encoding="utf-8", errors="strict"): return self._body.decode(encoding, errors=errors) async def read(self): return self._body def release(self): pass @property def content(self): s = MockStream() s.feed_data(self._body) s.feed_eof() return s def build_response(vcr_request, vcr_response, history): request_info = RequestInfo( url=URL(vcr_request.url), method=vcr_request.method, headers=CIMultiDictProxy(CIMultiDict(vcr_request.headers)), real_url=URL(vcr_request.url), ) response = MockClientResponse(vcr_request.method, URL(vcr_response.get("url")), request_info=request_info) response.status = vcr_response["status"]["code"] response._body = vcr_response["body"].get("string", b"") response.reason = vcr_response["status"]["message"] response._headers = CIMultiDictProxy(CIMultiDict(vcr_response["headers"])) response._history = tuple(history) response.close() return response def _serialize_headers(headers): """Serialize CIMultiDictProxy to a pickle-able dict because proxy objects forbid pickling: https://github.com/aio-libs/multidict/issues/340 """ # Mark strings as keys so 'istr' types don't show up in # the cassettes as comments. return {str(k): v for k, v in headers.items()} def play_responses(cassette, vcr_request): history = [] vcr_response = cassette.play_response(vcr_request) response = build_response(vcr_request, vcr_response, history) # If we're following redirects, continue playing until we reach # our final destination. while 300 <= response.status <= 399: next_url = URL(response.url).with_path(response.headers["location"]) # Make a stub VCR request that we can then use to look up the recorded # VCR request saved to the cassette. This feels a little hacky and # may have edge cases based on the headers we're providing (e.g. if # there's a matcher that is used to filter by headers). vcr_request = Request("GET", str(next_url), None, _serialize_headers(response.request_info.headers)) vcr_request = cassette.find_requests_with_most_matches(vcr_request)[0][0] # Tack on the response we saw from the redirect into the history # list that is added on to the final response. history.append(response) vcr_response = cassette.play_response(vcr_request) response = build_response(vcr_request, vcr_response, history) return response async def record_response(cassette, vcr_request, response): """Record a VCR request-response chain to the cassette.""" try: body = {"string": (await response.read())} # aiohttp raises a ClientConnectionError on reads when # there is no body. We can use this to know to not write one. except ClientConnectionError: body = {} vcr_response = { "status": {"code": response.status, "message": response.reason}, "headers": _serialize_headers(response.headers), "body": body, # NOQA: E999 "url": str(response.url), } cassette.append(vcr_request, vcr_response) async def record_responses(cassette, vcr_request, response): """Because aiohttp follows redirects by default, we must support them by default. This method is used to write individual request-response chains that were implicitly followed to get to the final destination. """ for past_response in response.history: aiohttp_request = past_response.request_info # No data because it's following a redirect. past_request = Request( aiohttp_request.method, str(aiohttp_request.url), None, _serialize_headers(aiohttp_request.headers), ) await record_response(cassette, past_request, past_response) # If we're following redirects, then the last request-response # we record is the one attached to the `response`. if response.history: aiohttp_request = response.request_info vcr_request = Request( aiohttp_request.method, str(aiohttp_request.url), None, _serialize_headers(aiohttp_request.headers), ) await record_response(cassette, vcr_request, response) def vcr_request(cassette, real_request): @functools.wraps(real_request) async def new_request(self, method, url, **kwargs): headers = kwargs.get("headers") auth = kwargs.get("auth") headers = self._prepare_headers(headers) data = kwargs.get("data", kwargs.get("json")) params = kwargs.get("params") if auth is not None: headers["AUTHORIZATION"] = auth.encode() request_url = URL(url) if params: for k, v in params.items(): params[k] = str(v) request_url = URL(url).with_query(params) vcr_request = Request(method, str(request_url), data, headers) if cassette.can_play_response_for(vcr_request): return play_responses(cassette, vcr_request) if cassette.write_protected and cassette.filter_request(vcr_request): response = MockClientResponse(method, URL(url)) response.status = 599 msg = ( "No match for the request {!r} was found. Can't overwrite " "existing cassette {!r} in your current record mode {!r}." ) msg = msg.format(vcr_request, cassette._path, cassette.record_mode) response._body = msg.encode() response.close() return response log.info("%s not in cassette, sending to real server", vcr_request) response = await real_request(self, method, url, **kwargs) # NOQA: E999 await record_responses(cassette, vcr_request, response) return response return new_request