diff --git a/tests/integration/test_aiohttp.py b/tests/integration/test_aiohttp.py index 8f3b743..8c6810b 100644 --- a/tests/integration/test_aiohttp.py +++ b/tests/integration/test_aiohttp.py @@ -47,13 +47,17 @@ def test_status(tmpdir, scheme): assert cassette.play_count == 1 -def test_headers(tmpdir, scheme): +@pytest.mark.parametrize("auth", [None, aiohttp.BasicAuth("vcrpy", "test")]) +def test_headers(tmpdir, scheme, auth): url = scheme + '://httpbin.org' with vcr.use_cassette(str(tmpdir.join('headers.yaml'))): - response, _ = get(url) + response, _ = get(url, auth=auth) with vcr.use_cassette(str(tmpdir.join('headers.yaml'))) as cassette: - cassette_response, _ = get(url) + if auth is not None: + request = cassette.requests[0] + assert "AUTHORIZATION" in request.headers + cassette_response, _ = get(url, auth=auth) assert cassette_response.headers == response.headers assert cassette.play_count == 1 assert 'istr' not in cassette.data[0] @@ -107,14 +111,17 @@ def test_stream(tmpdir, scheme): assert cassette.play_count == 1 -def test_post(tmpdir, scheme): +@pytest.mark.parametrize('body', ['data', 'json']) +def test_post(tmpdir, scheme, body): data = {'key1': 'value1', 'key2': 'value2'} url = scheme + '://httpbin.org/post' with vcr.use_cassette(str(tmpdir.join('post.yaml'))): - _, response_json = post(url, data=data) + _, response_json = post(url, **{body: data}) with vcr.use_cassette(str(tmpdir.join('post.yaml'))) as cassette: - _, cassette_response_json = post(url, data=data) + request = cassette.requests[0] + assert request.body == data + _, cassette_response_json = post(url, **{body: data}) assert cassette_response_json == response_json assert cassette.play_count == 1 diff --git a/vcr/stubs/aiohttp_stubs/__init__.py b/vcr/stubs/aiohttp_stubs/__init__.py index 6de1c17..0512fa2 100644 --- a/vcr/stubs/aiohttp_stubs/__init__.py +++ b/vcr/stubs/aiohttp_stubs/__init__.py @@ -3,6 +3,7 @@ from __future__ import absolute_import import asyncio import functools +import logging import json from aiohttp import ClientResponse, streams @@ -10,6 +11,8 @@ from yarl import URL from vcr.request import Request +log = logging.getLogger(__name__) + class MockStream(asyncio.StreamReader, streams.AsyncStreamReaderMixin): pass @@ -57,10 +60,14 @@ def vcr_request(cassette, real_request): @functools.wraps(real_request) async def new_request(self, method, url, **kwargs): headers = kwargs.get('headers') + auth = kwargs.get('auth') headers = self._prepare_headers(headers) - data = kwargs.get('data') + data = kwargs.get('data', kwargs.get('json')) params = kwargs.get('params') + if auth is not None: + headers['AUTHORIZATION'] = auth.encode() + request_url = URL(url) if params: for k, v in params.items(): @@ -91,6 +98,8 @@ def vcr_request(cassette, real_request): response.close() return response + log.info("{} not in cassette, sending to real server", vcr_request) + response = await real_request(self, method, url, **kwargs) # NOQA: E999 vcr_response = {