diff --git a/tests/integration/test_httpx.py b/tests/integration/test_httpx.py index 507628c..fee9eeb 100644 --- a/tests/integration/test_httpx.py +++ b/tests/integration/test_httpx.py @@ -1,7 +1,11 @@ +from unittest.mock import MagicMock import pytest import contextlib import vcr # noqa: E402 +from vcr.stubs.httpx_stubs import _get_next_url + + asyncio = pytest.importorskip("asyncio") httpx = pytest.importorskip("httpx") @@ -151,3 +155,34 @@ def test_work_with_gzipped_data(tmpdir, do_request, yml): assert "gzip" in cassette_response.json()["headers"]["Accept-Encoding"] assert cassette_response.read() assert cassette.play_count == 1 + + +@pytest.mark.parametrize("url", [f"http://github.com/kevin1024/vcrpy/issues/{i}" for i in range(3, 6)]) +def test_simple_fetching(tmpdir, do_request, yml, url): + with vcr.use_cassette(yml): + response = do_request()("GET", url) + + with vcr.use_cassette(yml) as cassette: + cassette_response = do_request()("GET", url) + cassette_response.request.url == url + assert cassette.play_count == 1 + + +class TestGetNextUrl: + def test_relative_location(self): + response = MagicMock() + response.url = "http://github.com/" + response.headers = {"location": "relative"} + assert str(_get_next_url(response)) == "http://github.com/relative" + + def test_absolute_location(self): + response = MagicMock() + response.url = "http://github.com/" + response.headers = {"location": "http://google.com"} + assert str(_get_next_url(response)) == "http://google.com" + + def test_no_location(self): + response = MagicMock() + response.url = "http://github.com/" + response.headers = {} + assert _get_next_url(response) is None diff --git a/vcr/stubs/httpx_stubs.py b/vcr/stubs/httpx_stubs.py index 072adfc..ff5011b 100644 --- a/vcr/stubs/httpx_stubs.py +++ b/vcr/stubs/httpx_stubs.py @@ -73,14 +73,27 @@ def _record_responses(cassette, vcr_request, real_response): return real_response +def _get_next_url(response): + location_str = response.headers.get("location") + if not location_str: + return None + + next_url = URL(location_str) + if not next_url.is_absolute(): + next_url = URL(str(response.url)).with_path(str(next_url)) + + return next_url + + def _play_responses(cassette, request, vcr_request): history = [] vcr_response = cassette.play_response(vcr_request) response = _from_serialized_response(request, vcr_response) while 300 <= response.status_code <= 399: - location = response.headers["location"] - next_url = URL(str(response.url)).with_path(location) + next_url = _get_next_url(response) + if not next_url: + break vcr_request = VcrRequest("GET", str(next_url), None, dict(response.headers)) vcr_request = cassette.find_requests_with_most_matches(vcr_request)[0][0]