diff --git a/tests/integration/aiohttp_utils.py b/tests/integration/aiohttp_utils.py index 195e5eb..5b77fae 100644 --- a/tests/integration/aiohttp_utils.py +++ b/tests/integration/aiohttp_utils.py @@ -1,7 +1,13 @@ import asyncio +import aiohttp @asyncio.coroutine -def aiohttp_request(session, method, url, as_text, **kwargs): - response = yield from session.request(method, url, **kwargs) # NOQA: E999 - return response, (yield from response.text()) if as_text else (yield from response.json()) # NOQA: E999 +def aiohttp_request(loop, method, url, as_text, **kwargs): + with aiohttp.ClientSession(loop=loop) as session: + response = yield from session.request(method, url, **kwargs) # NOQA: E999 + if as_text: + content = yield from response.text() # NOQA: E999 + else: + content = yield from response.json() # NOQA: E999 + return response, content diff --git a/tests/integration/async_def.py b/tests/integration/async_def.py new file mode 100644 index 0000000..96cab1d --- /dev/null +++ b/tests/integration/async_def.py @@ -0,0 +1,13 @@ +import aiohttp +import pytest +import vcr + + +@vcr.use_cassette() +@pytest.mark.asyncio +async def test_http(): # noqa: E999 + async with aiohttp.ClientSession() as session: + url = 'https://httpbin.org/get' + params = {'ham': 'spam'} + resp = await session.get(url, params=params) # noqa: E999 + assert (await resp.json())['args'] == {'ham': 'spam'} # noqa: E999 diff --git a/tests/integration/test_aiohttp.py b/tests/integration/test_aiohttp.py index 82957ff..280be3a 100644 --- a/tests/integration/test_aiohttp.py +++ b/tests/integration/test_aiohttp.py @@ -1,28 +1,40 @@ import pytest aiohttp = pytest.importorskip("aiohttp") -import asyncio # NOQA -import sys # NOQA +import asyncio # noqa: E402 +import contextlib # noqa: E402 -import aiohttp # NOQA -import pytest # NOQA -import vcr # NOQA +import pytest # noqa: E402 +import vcr # noqa: E402 -from .aiohttp_utils import aiohttp_request # NOQA +from .aiohttp_utils import aiohttp_request # noqa: E402 + +try: + from .async_def import test_http # noqa: F401 +except SyntaxError: + pass + + +def run_in_loop(fn): + with contextlib.closing(asyncio.new_event_loop()) as loop: + asyncio.set_event_loop(loop) + task = loop.create_task(fn(loop)) + return loop.run_until_complete(task) + + +def request(method, url, as_text=True, **kwargs): + def run(loop): + return aiohttp_request(loop, method, url, as_text, **kwargs) + + return run_in_loop(run) def get(url, as_text=True, **kwargs): - loop = asyncio.get_event_loop() - with aiohttp.ClientSession() as session: - task = loop.create_task(aiohttp_request(session, 'GET', url, as_text, **kwargs)) - return loop.run_until_complete(task) + return request('GET', url, as_text, **kwargs) def post(url, as_text=True, **kwargs): - loop = asyncio.get_event_loop() - with aiohttp.ClientSession() as session: - task = loop.create_task(aiohttp_request(session, 'POST', url, as_text, **kwargs)) - return loop.run_until_complete(task) + return request('POST', url, as_text, **kwargs) @pytest.fixture(params=["https", "http"]) diff --git a/tests/integration/test_http b/tests/integration/test_http new file mode 100644 index 0000000..522363b --- /dev/null +++ b/tests/integration/test_http @@ -0,0 +1,22 @@ +interactions: +- request: + body: null + headers: {} + method: GET + uri: https://httpbin.org/get?ham=spam + response: + body: {string: "{\n \"args\": {\n \"ham\": \"spam\"\n }, \n \"headers\"\ + : {\n \"Accept\": \"*/*\", \n \"Accept-Encoding\": \"gzip, deflate\"\ + , \n \"Connection\": \"close\", \n \"Host\": \"httpbin.org\", \n \ + \ \"User-Agent\": \"Python/3.5 aiohttp/2.0.1\"\n }, \n \"origin\": \"213.86.221.35\"\ + , \n \"url\": \"https://httpbin.org/get?ham=spam\"\n}\n"} + headers: {Access-Control-Allow-Credentials: 'true', Access-Control-Allow-Origin: '*', + Connection: keep-alive, Content-Length: '299', Content-Type: application/json, + Date: 'Wed, 22 Mar 2017 20:08:29 GMT', Server: gunicorn/19.7.1, Via: 1.1 vegur} + status: {code: 200, message: OK} + url: !!python/object/new:yarl.URL + state: !!python/tuple + - !!python/object/new:urllib.parse.SplitResult [https, httpbin.org, /get, ham=spam, + ''] + - false +version: 1 diff --git a/tox.ini b/tox.ini index fa60b54..250ef18 100644 --- a/tox.ini +++ b/tox.ini @@ -39,6 +39,7 @@ deps = boto: boto boto3: boto3 aiohttp: aiohttp + aiohttp: pytest-asyncio [flake8] max_line_length = 110 diff --git a/vcr/_handle_coroutine.py b/vcr/_handle_coroutine.py new file mode 100644 index 0000000..0b20be6 --- /dev/null +++ b/vcr/_handle_coroutine.py @@ -0,0 +1,7 @@ +import asyncio + + +@asyncio.coroutine +def handle_coroutine(vcr, fn): + with vcr as cassette: + return (yield from fn(cassette)) # noqa: E999 diff --git a/vcr/cassette.py b/vcr/cassette.py index c508602..fc3f0fa 100644 --- a/vcr/cassette.py +++ b/vcr/cassette.py @@ -12,6 +12,16 @@ from .serializers import yamlserializer from .persisters.filesystem import FilesystemPersister from .util import partition_dict +try: + from asyncio import iscoroutinefunction + from ._handle_coroutine import handle_coroutine +except ImportError: + def iscoroutinefunction(*args, **kwargs): + return False + + def handle_coroutine(*args, **kwags): + raise NotImplementedError('Not implemented on Python 2') + log = logging.getLogger(__name__) @@ -96,18 +106,25 @@ class CassetteContextDecorator(object): ) def _execute_function(self, function, args, kwargs): - if inspect.isgeneratorfunction(function): - handler = self._handle_coroutine - else: - handler = self._handle_function - return handler(function, args, kwargs) + def handle_function(cassette): + if cassette.inject: + return function(cassette, *args, **kwargs) + else: + return function(*args, **kwargs) - def _handle_coroutine(self, function, args, kwargs): - """Wraps a coroutine so that we're inside the cassette context for the - duration of the coroutine. + if iscoroutinefunction(function): + return handle_coroutine(vcr=self, fn=handle_function) + if inspect.isgeneratorfunction(function): + return self._handle_generator(fn=handle_function) + + return self._handle_function(fn=handle_function) + + def _handle_generator(self, fn): + """Wraps a generator so that we're inside the cassette context for the + duration of the generator. """ with self as cassette: - coroutine = self.__handle_function(cassette, function, args, kwargs) + coroutine = fn(cassette) # We don't need to catch StopIteration. The caller (Tornado's # gen.coroutine, for example) will handle that. to_yield = next(coroutine) @@ -119,15 +136,9 @@ class CassetteContextDecorator(object): else: to_yield = coroutine.send(to_send) - def __handle_function(self, cassette, function, args, kwargs): - if cassette.inject: - return function(cassette, *args, **kwargs) - else: - return function(*args, **kwargs) - - def _handle_function(self, function, args, kwargs): + def _handle_function(self, fn): with self as cassette: - return self.__handle_function(cassette, function, args, kwargs) + return fn(cassette) @staticmethod def get_function_name(function):