1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-09 01:03:24 +00:00
Files
vcrpy/vcr/stubs/httpx_stubs.py
Hernan Ezequiel Di Giorgi 936feb7748 fix redirection url
2020-07-24 15:34:57 -05:00

123 lines
3.9 KiB
Python

import functools
import logging
from unittest.mock import patch, MagicMock
from yarl import URL
import httpx
from vcr.request import Request as VcrRequest
from vcr.errors import CannotOverwriteExistingCassetteException
# Module-level logger; used below to report requests that are not found in the
# cassette and are therefore sent to the real server.
_logger = logging.getLogger(__name__)
def _transform_headers(httpx_reponse):
    """Return the response headers as ``{name: [value, ...]}``.

    Names and values are decoded from the raw header bytes as UTF-8. Values
    are collected per name in wire order, so repeated headers such as
    ``Set-Cookie`` survive the cassette round trip. A plain ``dict`` of the
    raw pairs would keep only the last occurrence.
    """
    headers = {}
    for key, value in httpx_reponse.headers.raw:
        headers.setdefault(key.decode("utf-8"), []).append(value.decode("utf-8"))
    return headers


def _from_serialized_headers(headers):
    """Turn serialized headers back into ``(name, value)`` pairs for httpx.

    Accepts the list-valued form written by ``_transform_headers``. It also
    accepts a plain ``{name: value}`` mapping, which is the single-string form
    found in cassettes recorded before headers were list-valued. A missing
    (``None``) or empty mapping yields an empty list.
    """
    if not headers:
        return []
    pairs = []
    for key, values in headers.items():
        if isinstance(values, (list, tuple)):
            pairs.extend((key, value) for value in values)
        else:
            pairs.append((key, values))
    return pairs


def _from_serialized_content(content):
    """Return the response body bytes for a cassette ``content`` value.

    Text is encoded as UTF-8, bytes are returned unchanged, and a missing
    value yields an empty body.
    """
    if content is None:
        return b""
    if isinstance(content, str):
        return content.encode("utf-8")
    return bytes(content)


def _to_serialized_response(httpx_reponse):
    """Serialize *httpx_reponse* into the dict stored in the cassette.

    Keys: ``status_code``, ``http_version``, ``headers`` (see
    ``_transform_headers``) and ``content``. ``content`` is stored as text
    when the body is valid UTF-8. Otherwise it is stored as the raw bytes, so
    binary bodies can be recorded instead of raising ``UnicodeDecodeError``.
    """
    content = httpx_reponse.content
    try:
        content = content.decode("utf-8")
    except UnicodeDecodeError:
        # Deliberate fallback: keep the raw bytes; _from_serialized_content
        # accepts both forms on playback.
        pass
    return {
        "status_code": httpx_reponse.status_code,
        "http_version": httpx_reponse.http_version,
        "headers": _transform_headers(httpx_reponse),
        "content": content,
    }


@patch("httpx.Response.close", MagicMock())
@patch("httpx.Response.read", MagicMock())
def _from_serialized_response(request, serialized_response, history=None):
    """Rebuild an ``httpx.Response`` from a cassette entry.

    Args:
        request: object attached as the response's ``request``.
        serialized_response: dict produced by ``_to_serialized_response``.
            Older entries with single-string header values are also accepted.
        history: optional list of earlier (redirect) responses.

    ``httpx.Response.read`` and ``close`` are replaced with mocks for the
    duration of this call.
    """
    content = _from_serialized_content(serialized_response.get("content"))
    response = httpx.Response(
        status_code=serialized_response.get("status_code"),
        request=request,
        http_version=serialized_response.get("http_version"),
        headers=_from_serialized_headers(serialized_response.get("headers")),
        content=content,
        history=history or [],
    )
    # Set the body explicitly. Presumably this is needed because ``read`` is
    # mocked out above, so httpx never loads the content itself.
    # NOTE(review): confirm against the pinned httpx version.
    response._content = content
    return response
def _make_vcr_request(httpx_request, **kwargs):
    """Build the VCR ``Request`` that represents *httpx_request* in a cassette.

    The body is passed through as the bytes returned by
    ``httpx_request.read()``. Decoding it as UTF-8 first raised
    ``UnicodeDecodeError`` for binary uploads.
    NOTE(review): ``VcrRequest`` is expected to store bodies as bytes itself
    (its body setter encodes ``str`` to UTF-8), so this should not change
    matching for text bodies — confirm against the pinned vcrpy.

    Extra keyword arguments are accepted for caller compatibility and ignored.
    """
    body = httpx_request.read()
    uri = str(httpx_request.url)
    headers = dict(httpx_request.headers)
    return VcrRequest(httpx_request.method, uri, body, headers)
def _shared_vcr_send(cassette, real_send, *args, **kwargs):
    """Look up the outgoing httpx request in *cassette*.

    ``args``/``kwargs`` are the arguments of the patched ``send`` call, and
    ``args[1]`` is the request being sent. ``real_send`` is not called here.

    Returns:
        A ``(vcr_request, response)`` tuple. ``response`` is the replayed
        ``httpx.Response`` when the cassette holds a match. Otherwise it is
        ``None`` and the caller must perform the real request.

    Raises:
        CannotOverwriteExistingCassetteException: the request is not in the
            cassette and the cassette is write-protected.
    """
    real_request = args[1]
    vcr_request = _make_vcr_request(real_request, **kwargs)

    if cassette.can_play_response_for(vcr_request):
        # ``allow_redirects`` presumably arrives as a keyword argument (httpx
        # declares it keyword-only on ``send`` - confirm for the pinned
        # version). Honour it so a replay with redirects disabled returns the
        # recorded 3xx, as the live client would.
        allow_redirects = kwargs.get("allow_redirects", True)
        response = _play_responses(cassette, real_request, vcr_request, allow_redirects=allow_redirects)
        return vcr_request, response

    if cassette.write_protected and cassette.filter_request(vcr_request):
        raise CannotOverwriteExistingCassetteException(cassette=cassette, failed_request=vcr_request)

    _logger.info("%s not in cassette, sending to real server", vcr_request)
    return vcr_request, None


def _record_responses(cassette, vcr_request, real_response):
    """Append *real_response* and every response in its redirect history to *cassette*.

    Each response is stored under the request that actually produced it
    (``response.request``). The ``vcr_request`` argument is not used.
    Returns *real_response* unchanged.
    """
    for past_real_response in real_response.history:
        past_vcr_request = _make_vcr_request(past_real_response.request)
        cassette.append(past_vcr_request, _to_serialized_response(past_real_response))

    final_vcr_request = _make_vcr_request(real_response.request)
    cassette.append(final_vcr_request, _to_serialized_response(real_response))
    return real_response


def _get_next_url(response):
    """Return the absolute redirect target of *response* as a ``yarl.URL``, or ``None``.

    The ``Location`` header is resolved against the response URL with
    ``URL.join`` (RFC 3986 reference resolution).
    - Query strings in a relative ``Location`` are kept.
    - Path-relative forms (``next``, ``../x``) resolve against the current path.
    - Scheme-relative ``//host/x`` inherits the scheme.
    - An absolute ``Location`` comes back effectively unchanged.
    """
    location = response.headers.get("location")
    if not location:
        return None
    return URL(str(response.url)).join(URL(location))


def _play_responses(cassette, request, vcr_request, allow_redirects=True):
    """Replay the recorded response for *vcr_request*, following recorded redirects.

    Args:
        cassette: cassette to play from.
        request: the original httpx request; attached to the first response.
        vcr_request: the cassette request matching *request*.
        allow_redirects: when false, return the first recorded response even
            if it is a 3xx.

    Returns:
        The final ``httpx.Response``. Earlier redirect responses are in its
        ``history``. If a redirect target cannot be found in the cassette, a
        warning is logged and the last redirect response is returned.
    """
    history = []
    response = _from_serialized_response(request, cassette.play_response(vcr_request))

    while allow_redirects and 300 <= response.status_code <= 399:
        next_url = _get_next_url(response)
        if next_url is None:
            break

        # Look the target up with GET and the redirect response's headers, and
        # replay the best-matching recorded request.
        # NOTE(review): httpx keeps the method for 307/308, so a GET lookup
        # may pick the wrong entry there - confirm.
        lookup = VcrRequest("GET", str(next_url), None, dict(response.headers))
        matches = cassette.find_requests_with_most_matches(lookup)
        if not matches:
            _logger.warning("redirect target %s not found in cassette; returning the redirect response", next_url)
            break
        vcr_request = matches[0][0]

        history.append(response)
        # Attach a real httpx.Request so response.request/.url have httpx types.
        # Headers are copied into a plain dict for httpx's Headers constructor.
        next_request = httpx.Request(vcr_request.method, vcr_request.uri, headers=dict(vcr_request.headers))
        # Give each response its own history list so later appends cannot leak
        # into responses that were already built.
        response = _from_serialized_response(next_request, cassette.play_response(vcr_request), list(history))

    return response
async def _async_vcr_send(cassette, real_send, *args, **kwargs):
    """Serve an async httpx ``send`` call from *cassette*, or perform and record it.

    Returns the replayed response when the cassette has one. Otherwise it
    awaits ``real_send(*args, **kwargs)`` and records the response, including
    any redirect history, before returning it.
    """
    vcr_request, response = _shared_vcr_send(cassette, real_send, *args, **kwargs)
    # ``None`` is the "not in cassette" sentinel. Compare against it
    # explicitly so the response object's truthiness never decides replay.
    if response is not None:
        return response

    real_response = await real_send(*args, **kwargs)
    return _record_responses(cassette, vcr_request, real_response)
def async_vcr_send(cassette, real_send):
    """Wrap the awaitable ``send`` callable *real_send* with cassette record/playback.

    The returned wrapper is declared ``async`` so that
    ``inspect.iscoroutinefunction`` still reports the patched method as a
    coroutine function. ``functools.wraps`` copies the name, docstring and
    ``__wrapped__``, but not that flag.
    """

    @functools.wraps(real_send)
    async def _inner_send(*args, **kwargs):
        return await _async_vcr_send(cassette, real_send, *args, **kwargs)

    return _inner_send