mirror of
https://github.com/kevin1024/vcrpy.git
synced 2025-12-10 01:25:34 +00:00
Patch httpcore instead of httpx
This commit is contained in:
215
vcr/stubs/httpcore_stubs.py
Normal file
215
vcr/stubs/httpcore_stubs.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import AsyncIterable, Iterable
|
||||
|
||||
from httpcore import Response
|
||||
from httpcore._models import ByteStream
|
||||
|
||||
from vcr.errors import CannotOverwriteExistingCassetteException
|
||||
from vcr.filters import decode_response
|
||||
from vcr.request import Request as VcrRequest
|
||||
from vcr.serializers.compat import convert_body_to_bytes
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _convert_byte_stream(stream):
|
||||
if isinstance(stream, Iterable):
|
||||
return list(stream)
|
||||
|
||||
if isinstance(stream, AsyncIterable):
|
||||
return [part async for part in stream]
|
||||
|
||||
raise TypeError(
|
||||
f"_convert_byte_stream: stream must be Iterable or AsyncIterable, got {type(stream).__name__}",
|
||||
)
|
||||
|
||||
|
||||
def _serialize_headers(real_response):
|
||||
"""
|
||||
Some headers can appear multiple times, like "Set-Cookie".
|
||||
Therefore serialize every header key to a list of values.
|
||||
"""
|
||||
|
||||
headers = defaultdict(list)
|
||||
|
||||
for name, value in real_response.headers:
|
||||
headers[name.decode("ascii")].append(value.decode("ascii"))
|
||||
|
||||
return dict(headers)
|
||||
|
||||
|
||||
async def _serialize_response(real_response):
|
||||
# The reason_phrase may not exist
|
||||
try:
|
||||
reason_phrase = real_response.extensions["reason_phrase"].decode("ascii")
|
||||
except KeyError:
|
||||
reason_phrase = None
|
||||
|
||||
# Reading the response stream consumes the iterator, so we need to restore it afterwards
|
||||
content = b"".join(await _convert_byte_stream(real_response.stream))
|
||||
real_response.stream = ByteStream(content)
|
||||
|
||||
return {
|
||||
"status": {"code": real_response.status, "message": reason_phrase},
|
||||
"headers": _serialize_headers(real_response),
|
||||
"body": {"string": content},
|
||||
}
|
||||
|
||||
|
||||
def _deserialize_headers(headers):
|
||||
"""
|
||||
httpcore accepts headers as list of tuples of header key and value.
|
||||
"""
|
||||
|
||||
return [
|
||||
(name.encode("ascii"), value.encode("ascii")) for name, values in headers.items() for value in values
|
||||
]
|
||||
|
||||
|
||||
def _deserialize_response(vcr_response):
|
||||
# Cassette format generated for HTTPX requests by older versions of
|
||||
# vcrpy. We restructure the content to resemble what a regular
|
||||
# cassette looks like.
|
||||
if "status_code" in vcr_response:
|
||||
vcr_response = decode_response(
|
||||
convert_body_to_bytes(
|
||||
{
|
||||
"headers": vcr_response["headers"],
|
||||
"body": {"string": vcr_response["content"]},
|
||||
"status": {"code": vcr_response["status_code"]},
|
||||
},
|
||||
),
|
||||
)
|
||||
extensions = None
|
||||
else:
|
||||
extensions = (
|
||||
{"reason_phrase": vcr_response["status"]["message"].encode("ascii")}
|
||||
if vcr_response["status"]["message"]
|
||||
else None
|
||||
)
|
||||
|
||||
return Response(
|
||||
vcr_response["status"]["code"],
|
||||
headers=_deserialize_headers(vcr_response["headers"]),
|
||||
content=vcr_response["body"]["string"],
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
|
||||
async def _make_vcr_request(real_request):
|
||||
# Reading the request stream consumes the iterator, so we need to restore it afterwards
|
||||
body = b"".join(await _convert_byte_stream(real_request.stream))
|
||||
real_request.stream = ByteStream(body)
|
||||
|
||||
uri = bytes(real_request.url).decode("ascii")
|
||||
|
||||
# As per HTTPX: If there are multiple headers with the same key, then we concatenate them with commas
|
||||
headers = defaultdict(list)
|
||||
|
||||
for name, value in real_request.headers:
|
||||
headers[name.decode("ascii")].append(value.decode("ascii"))
|
||||
|
||||
headers = {name: ", ".join(values) for name, values in headers.items()}
|
||||
|
||||
return VcrRequest(real_request.method.decode("ascii"), uri, body, headers)
|
||||
|
||||
|
||||
async def _vcr_request(cassette, real_request):
|
||||
vcr_request = await _make_vcr_request(real_request)
|
||||
|
||||
if cassette.can_play_response_for(vcr_request):
|
||||
return vcr_request, _play_responses(cassette, vcr_request)
|
||||
|
||||
if cassette.write_protected and cassette.filter_request(vcr_request):
|
||||
raise CannotOverwriteExistingCassetteException(
|
||||
cassette=cassette,
|
||||
failed_request=vcr_request,
|
||||
)
|
||||
|
||||
_logger.info("%s not in cassette, sending to real server", vcr_request)
|
||||
|
||||
return vcr_request, None
|
||||
|
||||
|
||||
async def _record_responses(cassette, vcr_request, real_response):
|
||||
cassette.append(vcr_request, await _serialize_response(real_response))
|
||||
|
||||
|
||||
def _play_responses(cassette, vcr_request):
|
||||
vcr_response = cassette.play_response(vcr_request)
|
||||
real_response = _deserialize_response(vcr_response)
|
||||
|
||||
return real_response
|
||||
|
||||
|
||||
async def _vcr_handle_async_request(
|
||||
cassette,
|
||||
real_handle_async_request,
|
||||
self,
|
||||
real_request,
|
||||
):
|
||||
vcr_request, vcr_response = await _vcr_request(cassette, real_request)
|
||||
|
||||
if vcr_response:
|
||||
return vcr_response
|
||||
|
||||
real_response = await real_handle_async_request(self, real_request)
|
||||
await _record_responses(cassette, vcr_request, real_response)
|
||||
|
||||
return real_response
|
||||
|
||||
|
||||
def vcr_handle_async_request(cassette, real_handle_async_request):
|
||||
@functools.wraps(real_handle_async_request)
|
||||
def _inner_handle_async_request(self, real_request):
|
||||
return _vcr_handle_async_request(
|
||||
cassette,
|
||||
real_handle_async_request,
|
||||
self,
|
||||
real_request,
|
||||
)
|
||||
|
||||
return _inner_handle_async_request
|
||||
|
||||
|
||||
def _run_async_function(sync_func, *args, **kwargs):
|
||||
"""
|
||||
Safely run an asynchronous function from a synchronous context.
|
||||
Handles both cases:
|
||||
- An event loop is already running.
|
||||
- No event loop exists yet.
|
||||
"""
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(sync_func(*args, **kwargs))
|
||||
else:
|
||||
# If inside a running loop, create a task and wait for it
|
||||
return asyncio.ensure_future(sync_func(*args, **kwargs))
|
||||
|
||||
|
||||
def _vcr_handle_request(cassette, real_handle_request, self, real_request):
|
||||
vcr_request, vcr_response = _run_async_function(
|
||||
_vcr_request,
|
||||
cassette,
|
||||
real_request,
|
||||
)
|
||||
|
||||
if vcr_response:
|
||||
return vcr_response
|
||||
|
||||
real_response = real_handle_request(self, real_request)
|
||||
_run_async_function(_record_responses, cassette, vcr_request, real_response)
|
||||
|
||||
return real_response
|
||||
|
||||
|
||||
def vcr_handle_request(cassette, real_handle_request):
|
||||
@functools.wraps(real_handle_request)
|
||||
def _inner_handle_request(self, real_request):
|
||||
return _vcr_handle_request(cassette, real_handle_request, self, real_request)
|
||||
|
||||
return _inner_handle_request
|
||||
Reference in New Issue
Block a user