mirror of
https://github.com/kevin1024/vcrpy.git
synced 2025-12-09 17:15:35 +00:00
This is required because the previous extraction code is now patched out by vcrpy. Also handle repeated headers with the same key in responses.
152 lines
4.8 KiB
Python
152 lines
4.8 KiB
Python
import functools
|
|
import logging
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
from yarl import URL
|
|
import httpx
|
|
from vcr.request import Request as VcrRequest
|
|
from vcr.errors import CannotOverwriteExistingCassetteException
|
|
|
|
|
|
# Module-level logger; used to report requests that are not found in the
# cassette and are therefore forwarded to the real server.
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _transform_headers(httpx_reponse):
    """Group the raw response headers by name, keeping every value.

    Some headers (for example ``Set-Cookie``) may legitimately appear more
    than once, so each UTF-8-decoded header name maps to the list of its
    UTF-8-decoded values, in the order they were received.
    """
    grouped = {}
    for raw_name, raw_value in httpx_reponse.headers.raw:
        grouped.setdefault(raw_name.decode("utf-8"), []).append(raw_value.decode("utf-8"))
    return grouped
|
|
|
|
|
|
def _to_serialized_response(httpx_reponse):
    """Serialize an httpx response into the dict stored in the cassette.

    The resulting keys, in order, are ``status_code``, ``http_version``,
    ``headers`` (header name -> list of values, from ``_transform_headers``)
    and ``content`` (the body decoded as UTF-8).

    NOTE(review): the decode is strict, so a response body that is not valid
    UTF-8 raises ``UnicodeDecodeError`` while recording.
    """
    serialized = {}
    serialized["status_code"] = httpx_reponse.status_code
    serialized["http_version"] = httpx_reponse.http_version
    serialized["headers"] = _transform_headers(httpx_reponse)
    serialized["content"] = httpx_reponse.content.decode("utf-8")
    return serialized
|
|
|
|
|
|
def _from_serialized_headers(headers):
    """Flatten a ``{name: [values, ...]}`` mapping into ``(name, value)`` pairs.

    httpx accepts headers as a list of ``(key, value)`` tuples, which lets a
    header with several values be passed back as one pair per value.
    """
    return [(name, value) for name, values in headers.items() for value in values]
|
|
|
|
|
|
@patch("httpx.Response.close", MagicMock())
@patch("httpx.Response.read", MagicMock())
def _from_serialized_response(request, serialized_response, history=None):
    """Rebuild an ``httpx.Response`` from a serialized cassette entry.

    While this function runs, ``httpx.Response.close`` and
    ``httpx.Response.read`` are replaced by ``MagicMock`` objects through the
    ``patch`` decorators. The encoded body is also assigned directly to the
    private ``_content`` attribute (presumably so ``.content`` is populated
    while ``read`` is mocked out -- confirm against the pinned httpx version).

    Args:
        request: object attached to the response as its ``request``.
        serialized_response: dict with ``status_code``, ``http_version``,
            ``headers`` (name -> list of values) and ``content`` (str).
        history: optional list of earlier responses; a falsy value becomes ``[]``.

    Returns:
        The reconstructed ``httpx.Response``.
    """
    body = serialized_response.get("content").encode()
    header_pairs = _from_serialized_headers(serialized_response.get("headers"))
    prior_responses = history if history else []

    response = httpx.Response(
        status_code=serialized_response.get("status_code"),
        request=request,
        http_version=serialized_response.get("http_version"),
        headers=header_pairs,
        content=body,
        history=prior_responses,
    )
    response._content = body
    return response
|
|
|
|
|
|
def _make_vcr_request(httpx_request, **kwargs):
    """Build a vcrpy ``Request`` mirroring an outgoing httpx request.

    Args:
        httpx_request: the httpx request; its body is read in full.
        **kwargs: accepted for call-site compatibility and ignored.

    Returns:
        A vcrpy ``Request`` with the same method, URL, body and headers.

    The body is decoded as UTF-8 when possible. Payloads that are not valid
    UTF-8 (binary uploads, compressed bodies, ...) are passed through as raw
    bytes instead of raising ``UnicodeDecodeError`` and aborting the request.
    """
    raw_body = httpx_request.read()
    try:
        body = raw_body.decode("utf-8")
    except UnicodeDecodeError:
        # NOTE(review): relies on vcr.request.Request accepting a bytes body
        # (it stores bodies as bytes) -- confirm for the pinned vcrpy version.
        body = raw_body
    uri = str(httpx_request.url)
    headers = dict(httpx_request.headers)
    return VcrRequest(httpx_request.method, uri, body, headers)
|
|
|
|
|
|
def _shared_vcr_send(cassette, real_send, *args, **kwargs):
    """Decide whether a request is replayed from the cassette or sent for real.

    ``args[0]`` is the httpx client and ``args[1]`` the httpx request.
    ``real_send`` is accepted but not used here.

    Returns:
        ``(vcr_request, response)``. ``response`` is the replayed response, or
        ``None`` when the request is not recorded and must go to the real server.

    Raises:
        CannotOverwriteExistingCassetteException: when the request is not
            recorded, the cassette is write-protected and
            ``cassette.filter_request(vcr_request)`` is truthy.
    """
    real_request = args[1]
    vcr_request = _make_vcr_request(real_request, **kwargs)

    if not cassette.can_play_response_for(vcr_request):
        if cassette.write_protected and cassette.filter_request(vcr_request):
            raise CannotOverwriteExistingCassetteException(cassette=cassette, failed_request=vcr_request)
        _logger.info("%s not in cassette, sending to real server", vcr_request)
        return vcr_request, None

    client = args[0]
    replayed = _play_responses(cassette, real_request, vcr_request, client)
    return vcr_request, replayed
|
|
|
|
|
|
def _record_responses(cassette, vcr_request, real_response):
    """Append a real response, and any redirect hops before it, to the cassette.

    Each response in ``real_response.history`` is recorded against its own
    request first. The final response is then recorded against ``vcr_request``,
    or, when redirects happened, against the request of the final hop, which
    holds the last redirect target.

    Returns:
        ``real_response`` unchanged.
    """
    redirects = real_response.history
    for earlier in redirects:
        cassette.append(_make_vcr_request(earlier.request), _to_serialized_response(earlier))

    final_request = _make_vcr_request(real_response.request) if redirects else vcr_request
    cassette.append(final_request, _to_serialized_response(real_response))
    return real_response
|
|
|
|
|
|
def _get_next_url(response):
    """Return the absolute redirect target of ``response``, or ``None``.

    Reads the ``Location`` header; ``None`` is returned when it is missing
    or empty. A location without a scheme (``/path?q=1``, ``next``, ``../up``,
    ``//host/path``) is resolved against ``response.url`` using RFC 3986
    reference resolution, so query strings and relative path segments are
    preserved.

    Returns:
        A ``yarl.URL`` or ``None``.
    """
    location_str = response.headers.get("location")
    if not location_str:
        return None

    next_url = URL(location_str)
    if not next_url.scheme:
        # ``with_path`` would percent-encode any "?query" into the path, drop
        # the base directory of relative paths and leave "//host/x" without a
        # scheme; ``join`` performs proper reference resolution instead.
        next_url = URL(str(response.url)).join(next_url)
    return next_url
|
|
|
|
|
|
def _play_responses(cassette, request, vcr_request, client):
    """Replay a recorded response from the cassette, following recorded redirects.

    Args:
        cassette: cassette to replay from.
        request: the original httpx request; attached to the first replayed response.
        vcr_request: vcrpy request used to look up the first recorded response.
        client: the httpx client; cookies from every intermediate (redirect)
            response are extracted into ``client.cookies``.

    Returns:
        The last replayed response. When redirects were followed, the
        intermediate 3xx responses are passed to it as ``history``.
    """
    history = []
    vcr_response = cassette.play_response(vcr_request)
    response = _from_serialized_response(request, vcr_response)

    # Keep following 3xx responses while they carry a usable Location header.
    # NOTE(review): there is no limit on the number of hops, so a redirect
    # cycle recorded in the cassette would loop here; consider a hop limit
    # (e.g. httpx's client.max_redirects).
    while 300 <= response.status_code <= 399:
        next_url = _get_next_url(response)
        if not next_url:
            break

        # Build a GET for the redirect target and replace it with the
        # best-matching request stored in the cassette.
        # NOTE(review): the previous response's headers are reused as request
        # headers, and [0][0] presumably fails with IndexError when nothing
        # matches -- confirm intent.
        vcr_request = VcrRequest("GET", str(next_url), None, dict(response.headers))
        vcr_request = cassette.find_requests_with_most_matches(vcr_request)[0][0]

        history.append(response)
        # add cookies from response to session cookie store
        client.cookies.extract_cookies(response)

        vcr_response = cassette.play_response(vcr_request)
        # NOTE(review): a vcrpy request (not an httpx.Request) becomes the
        # response's `request` here, unlike the first response above; the same
        # `history` list object is also passed on every hop and keeps growing.
        response = _from_serialized_response(vcr_request, vcr_response, history)

    return response
|
|
|
|
|
|
async def _async_vcr_send(cassette, real_send, *args, **kwargs):
    """Async send replacement: replay from the cassette or perform and record.

    When the cassette can replay the request, the replayed response's cookies
    are extracted into the client (``args[0]``) and that response is returned.
    Otherwise ``real_send`` is awaited and its response is recorded.
    """
    vcr_request, played = _shared_vcr_send(cassette, real_send, *args, **kwargs)
    if not played:
        real_response = await real_send(*args, **kwargs)
        return _record_responses(cassette, vcr_request, real_response)

    client = args[0]
    # add cookies from response to session cookie store
    client.cookies.extract_cookies(played)
    return played
|
|
|
|
|
|
def async_vcr_send(cassette, real_send):
    """Wrap ``real_send`` so that calls go through the cassette.

    The returned callable carries ``real_send``'s metadata (via
    ``functools.wraps``). Calling it returns the coroutine produced by
    ``_async_vcr_send(cassette, real_send, *args, **kwargs)``.
    """

    def _inner_send(*args, **kwargs):
        return _async_vcr_send(cassette, real_send, *args, **kwargs)

    return functools.wraps(real_send)(_inner_send)
|