From f98684e8aa4a2314d3429ac1af9a384fac48fbf8 Mon Sep 17 00:00:00 2001 From: Luiz Menezes Date: Thu, 4 Aug 2016 00:21:49 -0300 Subject: [PATCH] add support for aiohttp --- vcr/patch.py | 22 ++++++++++- vcr/stubs/aiohttp_stubs.py | 76 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 vcr/stubs/aiohttp_stubs.py diff --git a/vcr/patch.py b/vcr/patch.py index 0c54354..ac1127a 100644 --- a/vcr/patch.py +++ b/vcr/patch.py @@ -80,6 +80,13 @@ else: _CurlAsyncHTTPClient_fetch_impl = \ tornado.curl_httpclient.CurlAsyncHTTPClient.fetch_impl +try: + import aiohttp.client +except ImportError: # pragma: no cover + pass +else: + _AiohttpClientSessionRequest = aiohttp.client.ClientSession._request + class CassettePatcherBuilder(object): @@ -98,7 +105,7 @@ class CassettePatcherBuilder(object): def build(self): return itertools.chain( self._httplib(), self._requests(), self._boto3(), self._urllib3(), - self._httplib2(), self._boto(), self._tornado(), + self._httplib2(), self._boto(), self._tornado(), self._aiohttp(), self._build_patchers_from_mock_triples( self._cassette.custom_patches ), @@ -273,6 +280,19 @@ class CassettePatcherBuilder(object): ) yield curl.CurlAsyncHTTPClient, 'fetch_impl', new_fetch_impl + @_build_patchers_from_mock_triples_decorator + def _aiohttp(self): + try: + import aiohttp.client as client + except ImportError: # pragma: no cover + pass + else: + from .stubs.aiohttp_stubs import vcr_request + new_request = vcr_request( + self._cassette, _AiohttpClientSessionRequest + ) + yield client.ClientSession, '_request', new_request + def _urllib3_patchers(self, cpool, stubs): http_connection_remover = ConnectionRemover( self._get_cassette_subclass(stubs.VCRRequestsHTTPConnection) diff --git a/vcr/stubs/aiohttp_stubs.py b/vcr/stubs/aiohttp_stubs.py new file mode 100644 index 0000000..070a0f7 --- /dev/null +++ b/vcr/stubs/aiohttp_stubs.py @@ -0,0 +1,76 @@ +'''Stubs for aiohttp HTTP clients''' +from __future__ import absolute_import + +import functools +import json + +from aiohttp import ClientResponse +from multidict import CIMultiDictProxy + +from vcr.request import Request + + +class MockClientResponse(ClientResponse): + # TODO: get encoding from header + async def json(self, *, encoding='utf-8', loads=json.loads): + return loads(self.content.decode(encoding)) + + async def text(self, encoding='utf-8'): + return self.content.decode(encoding) + + async def release(self): + self.close() + + +def vcr_request(cassette, real_request): + + @functools.wraps(real_request) + async def new_request(self, method, url, **kwargs): + headers = kwargs.get('headers') + headers = self._prepare_headers(headers) + data = kwargs.get('data') + + vcr_request = Request(method, url, data, headers) + + if cassette.can_play_response_for(vcr_request): + vcr_response = cassette.play_response(vcr_request) + + response = MockClientResponse( + method, + vcr_response.get('url'), + ) + response.status = vcr_response['status']['code'] + response.content = vcr_response['body']['string'] + response.reason = vcr_response['status']['message'] + response.headers = CIMultiDictProxy(headers) + + return response + + if cassette.write_protected and cassette.filter_request(vcr_request): + response = MockClientResponse( + method, + vcr_response.get('url'), + ) + response.status = 599 + response.content = ("No match for the request (%r) was found. " + "Can't overwrite existing cassette (%r) in " + "your current record mode (%r).") + response.close() + return response + + response = await real_request(self, method, url, **kwargs) + + vcr_response = { + 'status': { + 'code': response.status, + 'message': response.reason, + }, + 'headers': response.headers, + 'body': {'string': await response.text()}, + 'url': response.url, + } + cassette.append(vcr_request, vcr_response) + + return response + + return new_request