diff --git a/tests/integration/aiohttp_utils.py b/tests/integration/aiohttp_utils.py index 6f6f15e..c8c1b0b 100644 --- a/tests/integration/aiohttp_utils.py +++ b/tests/integration/aiohttp_utils.py @@ -17,6 +17,8 @@ async def aiohttp_request(loop, method, url, output='text', encoding='utf-8', co content = await response.json(encoding=encoding, content_type=content_type) elif output == 'raw': content = await response.read() + elif output == 'stream': + content = await response.content.read() response_ctx._resp.close() await session.close() diff --git a/tests/integration/test_aiohttp.py b/tests/integration/test_aiohttp.py index c20dc61..a02c982 100644 --- a/tests/integration/test_aiohttp.py +++ b/tests/integration/test_aiohttp.py @@ -93,6 +93,18 @@ def test_binary(tmpdir, scheme): assert cassette.play_count == 1 +def test_stream(tmpdir, scheme): + url = scheme + '://httpbin.org/get' + + with vcr.use_cassette(str(tmpdir.join('stream.yaml'))): + resp, body = get(url, output='raw') # Do not use stream here, as the stream is exhausted by vcr + + with vcr.use_cassette(str(tmpdir.join('stream.yaml'))) as cassette: + cassette_resp, cassette_body = get(url, output='stream') + assert cassette_body == body + assert cassette.play_count == 1 + + def test_post(tmpdir, scheme): data = {'key1': 'value1', 'key2': 'value2'} url = scheme + '://httpbin.org/post' diff --git a/vcr/stubs/aiohttp_stubs/__init__.py b/vcr/stubs/aiohttp_stubs/__init__.py index f6525c7..8e03907 100644 --- a/vcr/stubs/aiohttp_stubs/__init__.py +++ b/vcr/stubs/aiohttp_stubs/__init__.py @@ -5,12 +5,16 @@ import asyncio import functools import json -from aiohttp import ClientResponse +from aiohttp import ClientResponse, streams from yarl import URL from vcr.request import Request +class MockStream(asyncio.StreamReader, streams.AsyncStreamReaderMixin): + pass + + class MockClientResponse(ClientResponse): def __init__(self, method, url): super().__init__( @@ -37,6 +41,13 @@ class MockClientResponse(ClientResponse): def release(self): pass + @property + def content(self): + s = MockStream() + s.feed_data(self._body) + s.feed_eof() + return s + def vcr_request(cassette, real_request): @functools.wraps(real_request)