1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-08 16:53:23 +00:00

Merge pull request #851 from sathieu/consume-body-once2

Ensure body is consumed only once (alternative to #847)
This commit is contained in:
Sebastian Pipping
2024-08-02 19:16:38 +02:00
committed by GitHub
4 changed files with 114 additions and 3 deletions

View File

@@ -1,9 +1,12 @@
import contextlib import contextlib
import http.client as httplib
from io import BytesIO
from tempfile import NamedTemporaryFile
from unittest import mock from unittest import mock
from pytest import mark from pytest import mark
from vcr import mode from vcr import mode, use_cassette
from vcr.cassette import Cassette from vcr.cassette import Cassette
from vcr.stubs import VCRHTTPSConnection from vcr.stubs import VCRHTTPSConnection
@@ -21,3 +24,52 @@ class TestVCRConnection:
vcr_connection.cassette = Cassette("test", record_mode=mode.ALL) vcr_connection.cassette = Cassette("test", record_mode=mode.ALL)
vcr_connection.real_connection.connect() vcr_connection.real_connection.connect()
assert vcr_connection.real_connection.sock is not None assert vcr_connection.real_connection.sock is not None
def test_body_consumed_once_stream(self, tmpdir, httpbin):
self._test_body_consumed_once(
tmpdir,
httpbin,
BytesIO(b"1234567890"),
BytesIO(b"9876543210"),
BytesIO(b"9876543210"),
)
def test_body_consumed_once_iterator(self, tmpdir, httpbin):
self._test_body_consumed_once(
tmpdir,
httpbin,
iter([b"1234567890"]),
iter([b"9876543210"]),
iter([b"9876543210"]),
)
# data2 and data3 should serve the same data, potentially as iterators
def _test_body_consumed_once(
self,
tmpdir,
httpbin,
data1,
data2,
data3,
):
with NamedTemporaryFile(dir=tmpdir, suffix=".yml") as f:
testpath = f.name
# NOTE: ``use_cassette`` is not okay with the file existing
# already. So we using ``.close()`` to not only
# close but also delete the empty file, before we start.
f.close()
host, port = httpbin.host, httpbin.port
match_on = ["method", "uri", "body"]
with use_cassette(testpath, match_on=match_on):
conn1 = httplib.HTTPConnection(host, port)
conn1.request("POST", "/anything", body=data1)
conn1.getresponse()
conn2 = httplib.HTTPConnection(host, port)
conn2.request("POST", "/anything", body=data2)
conn2.getresponse()
with use_cassette(testpath, match_on=match_on) as cass:
conn3 = httplib.HTTPConnection(host, port)
conn3.request("POST", "/anything", body=data3)
conn3.getresponse()
assert cass.play_counts[0] == 0
assert cass.play_counts[1] == 1

33
tests/unit/test_util.py Normal file
View File

@@ -0,0 +1,33 @@
from io import BytesIO, StringIO
import pytest
from vcr import request
from vcr.util import read_body
@pytest.mark.parametrize(
"input_, expected_output",
[
(BytesIO(b"Stream"), b"Stream"),
(StringIO("Stream"), b"Stream"),
(iter(["StringIter"]), b"StringIter"),
(iter(["String", "Iter"]), b"StringIter"),
(iter([b"BytesIter"]), b"BytesIter"),
(iter([b"Bytes", b"Iter"]), b"BytesIter"),
(iter([70, 111, 111]), b"Foo"),
(iter([]), b""),
("String", b"String"),
(b"Bytes", b"Bytes"),
],
)
def test_read_body(input_, expected_output):
r = request.Request("POST", "http://host.com/", input_, {})
assert read_body(r) == expected_output
def test_unsupported_read_body():
r = request.Request("POST", "http://host.com/", iter([[]]), {})
with pytest.raises(ValueError) as excinfo:
assert read_body(r)
assert excinfo.value.args == ("Body type <class 'list'> not supported",)

View File

@@ -3,7 +3,7 @@ import warnings
from io import BytesIO from io import BytesIO
from urllib.parse import parse_qsl, urlparse from urllib.parse import parse_qsl, urlparse
from .util import CaseInsensitiveDict from .util import CaseInsensitiveDict, _is_nonsequence_iterator
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -17,8 +17,11 @@ class Request:
self.method = method self.method = method
self.uri = uri self.uri = uri
self._was_file = hasattr(body, "read") self._was_file = hasattr(body, "read")
self._was_iter = _is_nonsequence_iterator(body)
if self._was_file: if self._was_file:
self.body = body.read() self.body = body.read()
elif self._was_iter:
self.body = list(body)
else: else:
self.body = body self.body = body
self.headers = headers self.headers = headers
@@ -36,7 +39,11 @@ class Request:
@property @property
def body(self): def body(self):
return BytesIO(self._body) if self._was_file else self._body if self._was_file:
return BytesIO(self._body)
if self._was_iter:
return iter(self._body)
return self._body
@body.setter @body.setter
def body(self, value): def body(self, value):

View File

@@ -89,9 +89,28 @@ def compose(*functions):
return composed return composed
def _is_nonsequence_iterator(obj):
return hasattr(obj, "__iter__") and not isinstance(
obj,
(bytearray, bytes, dict, list, str),
)
def read_body(request): def read_body(request):
if hasattr(request.body, "read"): if hasattr(request.body, "read"):
return request.body.read() return request.body.read()
if _is_nonsequence_iterator(request.body):
body = list(request.body)
if body:
if isinstance(body[0], str):
return "".join(body).encode("utf-8")
elif isinstance(body[0], (bytes, bytearray)):
return b"".join(body)
elif isinstance(body[0], int):
return bytes(body)
else:
raise ValueError(f"Body type {type(body[0])} not supported")
return b""
return request.body return request.body