mirror of
https://github.com/kevin1024/vcrpy.git
synced 2025-12-08 16:53:23 +00:00
Merge pull request #851 from sathieu/consume-body-once2
Ensure body is consumed only once (alternative to #847)
This commit is contained in:
@@ -1,9 +1,12 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
|
import http.client as httplib
|
||||||
|
from io import BytesIO
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from pytest import mark
|
from pytest import mark
|
||||||
|
|
||||||
from vcr import mode
|
from vcr import mode, use_cassette
|
||||||
from vcr.cassette import Cassette
|
from vcr.cassette import Cassette
|
||||||
from vcr.stubs import VCRHTTPSConnection
|
from vcr.stubs import VCRHTTPSConnection
|
||||||
|
|
||||||
@@ -21,3 +24,52 @@ class TestVCRConnection:
|
|||||||
vcr_connection.cassette = Cassette("test", record_mode=mode.ALL)
|
vcr_connection.cassette = Cassette("test", record_mode=mode.ALL)
|
||||||
vcr_connection.real_connection.connect()
|
vcr_connection.real_connection.connect()
|
||||||
assert vcr_connection.real_connection.sock is not None
|
assert vcr_connection.real_connection.sock is not None
|
||||||
|
|
||||||
|
def test_body_consumed_once_stream(self, tmpdir, httpbin):
|
||||||
|
self._test_body_consumed_once(
|
||||||
|
tmpdir,
|
||||||
|
httpbin,
|
||||||
|
BytesIO(b"1234567890"),
|
||||||
|
BytesIO(b"9876543210"),
|
||||||
|
BytesIO(b"9876543210"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_body_consumed_once_iterator(self, tmpdir, httpbin):
|
||||||
|
self._test_body_consumed_once(
|
||||||
|
tmpdir,
|
||||||
|
httpbin,
|
||||||
|
iter([b"1234567890"]),
|
||||||
|
iter([b"9876543210"]),
|
||||||
|
iter([b"9876543210"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# data2 and data3 should serve the same data, potentially as iterators
|
||||||
|
def _test_body_consumed_once(
|
||||||
|
self,
|
||||||
|
tmpdir,
|
||||||
|
httpbin,
|
||||||
|
data1,
|
||||||
|
data2,
|
||||||
|
data3,
|
||||||
|
):
|
||||||
|
with NamedTemporaryFile(dir=tmpdir, suffix=".yml") as f:
|
||||||
|
testpath = f.name
|
||||||
|
# NOTE: ``use_cassette`` is not okay with the file existing
|
||||||
|
# already. So we using ``.close()`` to not only
|
||||||
|
# close but also delete the empty file, before we start.
|
||||||
|
f.close()
|
||||||
|
host, port = httpbin.host, httpbin.port
|
||||||
|
match_on = ["method", "uri", "body"]
|
||||||
|
with use_cassette(testpath, match_on=match_on):
|
||||||
|
conn1 = httplib.HTTPConnection(host, port)
|
||||||
|
conn1.request("POST", "/anything", body=data1)
|
||||||
|
conn1.getresponse()
|
||||||
|
conn2 = httplib.HTTPConnection(host, port)
|
||||||
|
conn2.request("POST", "/anything", body=data2)
|
||||||
|
conn2.getresponse()
|
||||||
|
with use_cassette(testpath, match_on=match_on) as cass:
|
||||||
|
conn3 = httplib.HTTPConnection(host, port)
|
||||||
|
conn3.request("POST", "/anything", body=data3)
|
||||||
|
conn3.getresponse()
|
||||||
|
assert cass.play_counts[0] == 0
|
||||||
|
assert cass.play_counts[1] == 1
|
||||||
|
|||||||
33
tests/unit/test_util.py
Normal file
33
tests/unit/test_util.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from io import BytesIO, StringIO
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vcr import request
|
||||||
|
from vcr.util import read_body
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_, expected_output",
|
||||||
|
[
|
||||||
|
(BytesIO(b"Stream"), b"Stream"),
|
||||||
|
(StringIO("Stream"), b"Stream"),
|
||||||
|
(iter(["StringIter"]), b"StringIter"),
|
||||||
|
(iter(["String", "Iter"]), b"StringIter"),
|
||||||
|
(iter([b"BytesIter"]), b"BytesIter"),
|
||||||
|
(iter([b"Bytes", b"Iter"]), b"BytesIter"),
|
||||||
|
(iter([70, 111, 111]), b"Foo"),
|
||||||
|
(iter([]), b""),
|
||||||
|
("String", b"String"),
|
||||||
|
(b"Bytes", b"Bytes"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_read_body(input_, expected_output):
|
||||||
|
r = request.Request("POST", "http://host.com/", input_, {})
|
||||||
|
assert read_body(r) == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_unsupported_read_body():
|
||||||
|
r = request.Request("POST", "http://host.com/", iter([[]]), {})
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
assert read_body(r)
|
||||||
|
assert excinfo.value.args == ("Body type <class 'list'> not supported",)
|
||||||
@@ -3,7 +3,7 @@ import warnings
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from urllib.parse import parse_qsl, urlparse
|
from urllib.parse import parse_qsl, urlparse
|
||||||
|
|
||||||
from .util import CaseInsensitiveDict
|
from .util import CaseInsensitiveDict, _is_nonsequence_iterator
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -17,8 +17,11 @@ class Request:
|
|||||||
self.method = method
|
self.method = method
|
||||||
self.uri = uri
|
self.uri = uri
|
||||||
self._was_file = hasattr(body, "read")
|
self._was_file = hasattr(body, "read")
|
||||||
|
self._was_iter = _is_nonsequence_iterator(body)
|
||||||
if self._was_file:
|
if self._was_file:
|
||||||
self.body = body.read()
|
self.body = body.read()
|
||||||
|
elif self._was_iter:
|
||||||
|
self.body = list(body)
|
||||||
else:
|
else:
|
||||||
self.body = body
|
self.body = body
|
||||||
self.headers = headers
|
self.headers = headers
|
||||||
@@ -36,7 +39,11 @@ class Request:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def body(self):
|
def body(self):
|
||||||
return BytesIO(self._body) if self._was_file else self._body
|
if self._was_file:
|
||||||
|
return BytesIO(self._body)
|
||||||
|
if self._was_iter:
|
||||||
|
return iter(self._body)
|
||||||
|
return self._body
|
||||||
|
|
||||||
@body.setter
|
@body.setter
|
||||||
def body(self, value):
|
def body(self, value):
|
||||||
|
|||||||
19
vcr/util.py
19
vcr/util.py
@@ -89,9 +89,28 @@ def compose(*functions):
|
|||||||
return composed
|
return composed
|
||||||
|
|
||||||
|
|
||||||
|
def _is_nonsequence_iterator(obj):
|
||||||
|
return hasattr(obj, "__iter__") and not isinstance(
|
||||||
|
obj,
|
||||||
|
(bytearray, bytes, dict, list, str),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def read_body(request):
|
def read_body(request):
|
||||||
if hasattr(request.body, "read"):
|
if hasattr(request.body, "read"):
|
||||||
return request.body.read()
|
return request.body.read()
|
||||||
|
if _is_nonsequence_iterator(request.body):
|
||||||
|
body = list(request.body)
|
||||||
|
if body:
|
||||||
|
if isinstance(body[0], str):
|
||||||
|
return "".join(body).encode("utf-8")
|
||||||
|
elif isinstance(body[0], (bytes, bytearray)):
|
||||||
|
return b"".join(body)
|
||||||
|
elif isinstance(body[0], int):
|
||||||
|
return bytes(body)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Body type {type(body[0])} not supported")
|
||||||
|
return b""
|
||||||
return request.body
|
return request.body
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user