From 2f94d06e9b3f7270e2405f1f0557e7b48c5936c8 Mon Sep 17 00:00:00 2001 From: Hernan Ezequiel Di Giorgi Date: Fri, 24 Apr 2020 20:45:35 -0300 Subject: [PATCH] add httpx support --- tests/integration/test_httpx.py | 141 ++++++++++++++++++++++++++++++++ tox.ini | 9 +- vcr/patch.py | 22 +++++ vcr/stubs/httpx_stubs.py | 105 ++++++++++++++++++++++++ 4 files changed, 273 insertions(+), 4 deletions(-) create mode 100644 tests/integration/test_httpx.py create mode 100644 vcr/stubs/httpx_stubs.py diff --git a/tests/integration/test_httpx.py b/tests/integration/test_httpx.py new file mode 100644 index 0000000..264ae31 --- /dev/null +++ b/tests/integration/test_httpx.py @@ -0,0 +1,141 @@ +import pytest +import contextlib +import vcr # noqa: E402 + +asyncio = pytest.importorskip("asyncio") +httpx = pytest.importorskip("httpx") + + +class BaseDoRequest: + _client_class = None + + def __init__(self, *args, **kwargs): + self._client = self._client_class(*args, **kwargs) + + +class DoSyncRequest(BaseDoRequest): + _client_class = httpx.Client + + def __call__(self, *args, **kwargs): + return self._client.request(*args, timeout=60, **kwargs) + + +class DoAsyncRequest(BaseDoRequest): + _client_class = httpx.AsyncClient + + @staticmethod + def run_in_loop(coroutine): + with contextlib.closing(asyncio.new_event_loop()) as loop: + asyncio.set_event_loop(loop) + task = loop.create_task(coroutine) + return loop.run_until_complete(task) + + def __call__(self, *args, **kwargs): + async def _request(): + async with self._client as c: + return await c.request(*args, **kwargs) + + return DoAsyncRequest.run_in_loop(_request()) + + +def pytest_generate_tests(metafunc): + if "do_request" in metafunc.fixturenames: + metafunc.parametrize("do_request", [DoAsyncRequest, DoSyncRequest]) + if "scheme" in metafunc.fixturenames: + metafunc.parametrize("scheme", ["http", "https"]) + + +@pytest.fixture +def yml(tmpdir, request): + return str(tmpdir.join(request.function.__name__ + ".yaml")) + + +def test_status(tmpdir, scheme, do_request): + url = scheme + "://httpbin.org" + with vcr.use_cassette(str(tmpdir.join("status.yaml"))): + response = do_request()("GET", url) + + with vcr.use_cassette(str(tmpdir.join("status.yaml"))) as cassette: + cassette_response = do_request()("GET", url) + assert cassette_response.status_code == response.status_code + assert cassette.play_count == 1 + + +def test_case_insensitive_headers(tmpdir, scheme, do_request): + url = scheme + "://httpbin.org" + with vcr.use_cassette(str(tmpdir.join("whatever.yaml"))): + do_request()("GET", url) + + with vcr.use_cassette(str(tmpdir.join("whatever.yaml"))) as cassette: + cassette_response = do_request()("GET", url) + assert "Content-Type" in cassette_response.headers + assert "content-type" in cassette_response.headers + assert cassette.play_count == 1 + + +def test_content(tmpdir, scheme, do_request): + url = scheme + "://httpbin.org" + with vcr.use_cassette(str(tmpdir.join("cointent.yaml"))): + response = do_request()("GET", url) + + with vcr.use_cassette(str(tmpdir.join("cointent.yaml"))) as cassette: + cassette_response = do_request()("GET", url) + assert cassette_response.content == response.content + assert cassette.play_count == 1 + + +def test_json(tmpdir, scheme, do_request): + url = scheme + "://httpbin.org/get" + headers = {"Content-Type": "application/json"} + + with vcr.use_cassette(str(tmpdir.join("json.yaml"))): + response = do_request(headers=headers)("GET", url) + + with vcr.use_cassette(str(tmpdir.join("json.yaml"))) as cassette: + cassette_response = do_request(headers=headers)("GET", url) + assert cassette_response.json() == response.json() + assert cassette.play_count == 1 + + +def test_params_same_url_distinct_params(tmpdir, scheme, do_request): + url = scheme + "://httpbin.org/get" + headers = {"Content-Type": "application/json"} + params = {"a": 1, "b": False, "c": "c"} + + with vcr.use_cassette(str(tmpdir.join("get.yaml"))) as cassette: + response = do_request()("GET", url, params=params, headers=headers) + + with vcr.use_cassette(str(tmpdir.join("get.yaml"))) as cassette: + cassette_response = do_request()("GET", url, params=params, headers=headers) + assert cassette_response.request.url == response.request.url + assert cassette_response.json() == response.json() + assert cassette.play_count == 1 + + params = {"other": "params"} + with vcr.use_cassette(str(tmpdir.join("get.yaml"))) as cassette: + with pytest.raises(vcr.errors.CannotOverwriteExistingCassetteException): + do_request()("GET", url, params=params, headers=headers) + + +def test_redirect(tmpdir, do_request, yml): + url = "https://httpbin.org/redirect/2" + + response = do_request()("GET", url) + with vcr.use_cassette(yml): + response = do_request()("GET", url) + + with vcr.use_cassette(yml) as cassette: + cassette_response = do_request()("GET", url) + + assert cassette_response.status_code == response.status_code + assert len(cassette_response.history) == len(response.history) + assert len(cassette) == 3 + assert cassette.play_count == 3 + + # Assert that the real response and the cassette response have a similar + # looking request_info. + assert cassette_response.request.url == response.request.url + assert cassette_response.request.method == response.request.method + assert {k: v for k, v in cassette_response.request.headers.items()} == { + k: v for k, v in response.request.headers.items() + } diff --git a/tox.ini b/tox.ini index 20c4b73..8c41988 100644 --- a/tox.ini +++ b/tox.ini @@ -3,8 +3,8 @@ skip_missing_interpreters=true envlist = cov-clean, lint, - {py35,py36,py37,py38}-{requests,httplib2,urllib3,tornado4,boto3,aiohttp}, - {pypy3}-{requests,httplib2,urllib3,tornado4,boto3}, + {py35,py36,py37,py38}-{requests,httplib2,urllib3,tornado4,boto3,aiohttp,httpx}, + {pypy3}-{requests,httplib2,urllib3,tornado4,boto3,httpx}, cov-report @@ -79,8 +79,9 @@ deps = aiohttp: aiohttp aiohttp: pytest-asyncio aiohttp: pytest-aiohttp -depends = - lint,{py35,py36,py37,py38,pypy3}-{requests,httplib2,urllib3,tornado4,boto3},{py35,py36,py37,py38}-{aiohttp}: cov-clean + httpx: httpx +depends = + lint,{py35,py36,py37,py38,pypy3}-{requests,httplib2,urllib3,tornado4,boto3,httpx},{py35,py36,py37,py38}-{aiohttp}: cov-clean cov-report: lint,{py35,py36,py37,py38,pypy3}-{requests,httplib2,urllib3,tornado4,boto3},{py35,py36,py37,py38}-{aiohttp} passenv = AWS_ACCESS_KEY_ID diff --git a/vcr/patch.py b/vcr/patch.py index 3b30407..db61ba6 100644 --- a/vcr/patch.py +++ b/vcr/patch.py @@ -94,6 +94,15 @@ else: _AiohttpClientSessionRequest = aiohttp.client.ClientSession._request +try: + import httpx +except ImportError: # pragma: no cover + pass +else: + _HttpxClient_send = httpx.Client.send + _HttpxAsyncClient_send = httpx.AsyncClient.send + + class CassettePatcherBuilder: def _build_patchers_from_mock_triples_decorator(function): @functools.wraps(function) @@ -116,6 +125,7 @@ class CassettePatcherBuilder: self._boto(), self._tornado(), self._aiohttp(), + self._httpx(), self._build_patchers_from_mock_triples(self._cassette.custom_patches), ) @@ -313,6 +323,18 @@ class CassettePatcherBuilder: new_request = vcr_request(self._cassette, _AiohttpClientSessionRequest) yield client.ClientSession, "_request", new_request + @_build_patchers_from_mock_triples_decorator + def _httpx(self): + try: + import httpx + except ImportError: # pragma: no cover + return + else: + from .stubs.httpx_stubs import async_vcr_send + + new_async_client_send = async_vcr_send(self._cassette, _HttpxAsyncClient_send) + yield httpx.AsyncClient, "send", new_async_client_send + def _urllib3_patchers(self, cpool, stubs): http_connection_remover = ConnectionRemover( self._get_cassette_subclass(stubs.VCRRequestsHTTPConnection) diff --git a/vcr/stubs/httpx_stubs.py b/vcr/stubs/httpx_stubs.py new file mode 100644 index 0000000..07f677c --- /dev/null +++ b/vcr/stubs/httpx_stubs.py @@ -0,0 +1,105 @@ +import functools +import logging +from unittest.mock import patch, MagicMock + +from yarl import URL +import httpx +from vcr.request import Request as VcrRequest +from vcr.errors import CannotOverwriteExistingCassetteException + + +_logger = logging.getLogger(__name__) + + +def _transform_headers(httpx_reponse): + return { + key.decode("utf-8"): var.decode("utf-8") for (key, var) in dict(httpx_reponse.headers.raw).items() + } + + +def _to_serialized_response(httpx_reponse): + return { + "status_code": httpx_reponse.status_code, + "http_version": httpx_reponse.http_version, + "headers": _transform_headers(httpx_reponse), + "content": httpx_reponse.content.decode("utf-8"), + } + + +@patch("httpx.Response.close", MagicMock()) +def _from_serialized_response(request, serialized_response, history=None): + return httpx.Response( + status_code=serialized_response.get("status_code"), + request=request, + http_version=serialized_response.get("http_version"), + headers=serialized_response.get("headers"), + content=serialized_response.get("content"), + history=history or [], + ) + + +def _make_vcr_request(httpx_request, **kwargs): + body = httpx_request.read().decode("utf-8") + uri = str(httpx_request.url) + headers = dict(httpx_request.headers) + return VcrRequest(httpx_request.method, uri, body, headers) + + +def _shared_vcr_send(cassette, real_send, *args, **kwargs): + real_request = args[1] + vcr_request = _make_vcr_request(real_request, **kwargs) + + if cassette.can_play_response_for(vcr_request): + return vcr_request, _play_responses(cassette, real_request, vcr_request) + + if cassette.write_protected and cassette.filter_request(vcr_request): + raise CannotOverwriteExistingCassetteException(cassette=cassette, failed_request=vcr_request) + + _logger.info("%s not in cassette, sending to real server", vcr_request) + return vcr_request, None + + +def _record_responses(cassette, vcr_request, real_response): + for past_real_response in real_response.history: + past_vcr_request = _make_vcr_request(past_real_response.request) + cassette.append(past_vcr_request, _to_serialized_response(past_real_response)) + + vcr_request = _make_vcr_request(real_response.request) + cassette.append(vcr_request, _to_serialized_response(real_response)) + return real_response + + +def _play_responses(cassette, request, vcr_request): + history = [] + vcr_response = cassette.play_response(vcr_request) + response = _from_serialized_response(request, vcr_response) + + while 300 <= response.status_code <= 399: + location = response.headers["location"] + next_url = URL(str(response.url)).with_path(location) + + vcr_request = VcrRequest("GET", str(next_url), None, dict(response.headers)) + vcr_request = cassette.find_requests_with_most_matches(vcr_request)[0][0] + + history.append(response) + vcr_response = cassette.play_response(vcr_request) + response = _from_serialized_response(vcr_request, vcr_response, history) + + return response + + +async def _async_vcr_send(cassette, real_send, *args, **kwargs): + vcr_request, response = _shared_vcr_send(cassette, real_send, *args, **kwargs) + if response: + return response + + real_response = await real_send(*args, **kwargs) + return _record_responses(cassette, vcr_request, real_response) + + +def async_vcr_send(cassette, real_send): + @functools.wraps(real_send) + def _inner_send(*args, **kwargs): + return _async_vcr_send(cassette, real_send, *args, **kwargs) + + return _inner_send