1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-09 01:03:24 +00:00

Merge pull request #801 from graingert/fix-resource-warning-2

This commit is contained in:
Thomas Grainger
2024-01-23 12:39:35 +00:00
committed by GitHub
11 changed files with 107 additions and 63 deletions

View File

@@ -5,7 +5,19 @@ ignore-regex = "\\\\[fnrstv]"
# ignore-words-list = '' # ignore-words-list = ''
[tool.pytest.ini_options] [tool.pytest.ini_options]
addopts = [
"--strict-config",
"--strict-markers",
]
markers = ["online"] markers = ["online"]
filterwarnings = [
"error",
'''ignore:datetime\.datetime\.utcfromtimestamp\(\) is deprecated and scheduled for removal in a future version.*:DeprecationWarning''',
'''ignore:There is no current event loop:DeprecationWarning''',
'''ignore:make_current is deprecated; start the event loop first:DeprecationWarning''',
'''ignore:clear_current is deprecated:DeprecationWarning''',
'''ignore:the \(type, exc, tb\) signature of throw\(\) is deprecated, use the single-arg signature instead.:DeprecationWarning''',
]
[tool.ruff] [tool.ruff]
select = [ select = [

View File

@@ -67,6 +67,7 @@ extras_require = {
"pytest-asyncio", "pytest-asyncio",
"pytest-cov", "pytest-cov",
"pytest-httpbin", "pytest-httpbin",
"pytest-tornado",
"pytest", "pytest",
"requests>=2.22.0", "requests>=2.22.0",
"tornado", "tornado",

View File

@@ -5,24 +5,24 @@ import aiohttp
async def aiohttp_request(loop, method, url, output="text", encoding="utf-8", content_type=None, **kwargs): async def aiohttp_request(loop, method, url, output="text", encoding="utf-8", content_type=None, **kwargs):
session = aiohttp.ClientSession(loop=loop) async with aiohttp.ClientSession(loop=loop) as session:
response_ctx = session.request(method, url, **kwargs) response_ctx = session.request(method, url, **kwargs)
response = await response_ctx.__aenter__() response = await response_ctx.__aenter__()
if output == "text": if output == "text":
content = await response.text() content = await response.text()
elif output == "json": elif output == "json":
content_type = content_type or "application/json" content_type = content_type or "application/json"
content = await response.json(encoding=encoding, content_type=content_type) content = await response.json(encoding=encoding, content_type=content_type)
elif output == "raw": elif output == "raw":
content = await response.read() content = await response.read()
elif output == "stream": elif output == "stream":
content = await response.content.read() content = await response.content.read()
response_ctx._resp.close() response_ctx._resp.close()
await session.close() await session.close()
return response, content return response, content
def aiohttp_app(): def aiohttp_app():

View File

@@ -1,4 +1,3 @@
import contextlib
import logging import logging
import urllib.parse import urllib.parse
@@ -14,10 +13,10 @@ from .aiohttp_utils import aiohttp_app, aiohttp_request # noqa: E402
def run_in_loop(fn): def run_in_loop(fn):
with contextlib.closing(asyncio.new_event_loop()) as loop: async def wrapper():
asyncio.set_event_loop(loop) return await fn(asyncio.get_running_loop())
task = loop.create_task(fn(loop))
return loop.run_until_complete(task) return asyncio.run(wrapper())
def request(method, url, output="text", **kwargs): def request(method, url, output="text", **kwargs):
@@ -260,6 +259,12 @@ def test_aiohttp_test_client_json(aiohttp_client, tmpdir):
assert cassette.play_count == 1 assert cassette.play_count == 1
def test_cleanup_from_pytest_asyncio():
# work around https://github.com/pytest-dev/pytest-asyncio/issues/724
asyncio.get_event_loop().close()
asyncio.set_event_loop(None)
@pytest.mark.online @pytest.mark.online
def test_redirect(tmpdir, httpbin): def test_redirect(tmpdir, httpbin):
url = httpbin.url + "/redirect/2" url = httpbin.url + "/redirect/2"

View File

@@ -32,25 +32,37 @@ class DoSyncRequest(BaseDoRequest):
_client_class = httpx.Client _client_class = httpx.Client
def __enter__(self): def __enter__(self):
self._client = self._make_client()
return self return self
def __exit__(self, *args): def __exit__(self, *args):
pass self._client.close()
del self._client
@property @property
def client(self): def client(self):
try: try:
return self._client return self._client
except AttributeError: except AttributeError as e:
self._client = self._make_client() raise ValueError('To access sync client, use "with do_request() as client"') from e
return self._client
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.client.request(*args, timeout=60, **kwargs) if hasattr(self, "_client"):
return self.client.request(*args, timeout=60, **kwargs)
# Use one-time context and dispose of the client afterwards
with self:
return self.client.request(*args, timeout=60, **kwargs)
def stream(self, *args, **kwargs): def stream(self, *args, **kwargs):
with self.client.stream(*args, **kwargs) as response: if hasattr(self, "_client"):
return b"".join(response.iter_bytes()) with self.client.stream(*args, **kwargs) as response:
return b"".join(response.iter_bytes())
# Use one-time context and dispose of the client afterwards
with self:
with self.client.stream(*args, **kwargs) as response:
return b"".join(response.iter_bytes())
class DoAsyncRequest(BaseDoRequest): class DoAsyncRequest(BaseDoRequest):

View File

@@ -39,12 +39,12 @@ class Proxy(http.server.SimpleHTTPRequestHandler):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def proxy_server(): def proxy_server():
httpd = socketserver.ThreadingTCPServer(("", 0), Proxy) with socketserver.ThreadingTCPServer(("", 0), Proxy) as httpd:
proxy_process = threading.Thread(target=httpd.serve_forever) proxy_process = threading.Thread(target=httpd.serve_forever)
proxy_process.start() proxy_process.start()
yield "http://{}:{}".format(*httpd.server_address) yield "http://{}:{}".format(*httpd.server_address)
httpd.shutdown() httpd.shutdown()
proxy_process.join() proxy_process.join()
def test_use_proxy(tmpdir, httpbin, proxy_server): def test_use_proxy(tmpdir, httpbin, proxy_server):

View File

@@ -15,6 +15,13 @@ http = pytest.importorskip("tornado.httpclient")
# whether the current version of Tornado supports the raise_error argument for # whether the current version of Tornado supports the raise_error argument for
# fetch(). # fetch().
supports_raise_error = tornado.version_info >= (4,) supports_raise_error = tornado.version_info >= (4,)
raise_error_for_response_code_only = tornado.version_info >= (6,)
@pytest.fixture(params=["https", "http"])
def scheme(request):
"""Fixture that returns both http and https."""
return request.param
@pytest.fixture(params=["simple", "curl", "default"]) @pytest.fixture(params=["simple", "curl", "default"])
@@ -44,6 +51,7 @@ def post(client, url, data=None, **kwargs):
return client.fetch(http.HTTPRequest(url, method="POST", **kwargs)) return client.fetch(http.HTTPRequest(url, method="POST", **kwargs))
@pytest.mark.online
@pytest.mark.gen_test @pytest.mark.gen_test
def test_status_code(get_client, scheme, tmpdir): def test_status_code(get_client, scheme, tmpdir):
"""Ensure that we can read the status code""" """Ensure that we can read the status code"""
@@ -56,6 +64,7 @@ def test_status_code(get_client, scheme, tmpdir):
assert 1 == cass.play_count assert 1 == cass.play_count
@pytest.mark.online
@pytest.mark.gen_test @pytest.mark.gen_test
def test_headers(get_client, scheme, tmpdir): def test_headers(get_client, scheme, tmpdir):
"""Ensure that we can read the headers back""" """Ensure that we can read the headers back"""
@@ -68,6 +77,7 @@ def test_headers(get_client, scheme, tmpdir):
assert 1 == cass.play_count assert 1 == cass.play_count
@pytest.mark.online
@pytest.mark.gen_test @pytest.mark.gen_test
def test_body(get_client, tmpdir, scheme): def test_body(get_client, tmpdir, scheme):
"""Ensure the responses are all identical enough""" """Ensure the responses are all identical enough"""
@@ -94,6 +104,7 @@ def test_effective_url(get_client, tmpdir, httpbin):
assert 1 == cass.play_count assert 1 == cass.play_count
@pytest.mark.online
@pytest.mark.gen_test @pytest.mark.gen_test
def test_auth(get_client, tmpdir, scheme): def test_auth(get_client, tmpdir, scheme):
"""Ensure that we can handle basic auth""" """Ensure that we can handle basic auth"""
@@ -109,6 +120,7 @@ def test_auth(get_client, tmpdir, scheme):
assert 1 == cass.play_count assert 1 == cass.play_count
@pytest.mark.online
@pytest.mark.gen_test @pytest.mark.gen_test
def test_auth_failed(get_client, tmpdir, scheme): def test_auth_failed(get_client, tmpdir, scheme):
"""Ensure that we can save failed auth statuses""" """Ensure that we can save failed auth statuses"""
@@ -132,6 +144,7 @@ def test_auth_failed(get_client, tmpdir, scheme):
assert 1 == cass.play_count assert 1 == cass.play_count
@pytest.mark.online
@pytest.mark.gen_test @pytest.mark.gen_test
def test_post(get_client, tmpdir, scheme): def test_post(get_client, tmpdir, scheme):
"""Ensure that we can post and cache the results""" """Ensure that we can post and cache the results"""
@@ -148,9 +161,9 @@ def test_post(get_client, tmpdir, scheme):
@pytest.mark.gen_test @pytest.mark.gen_test
def test_redirects(get_client, tmpdir, scheme): def test_redirects(get_client, tmpdir, httpbin):
"""Ensure that we can handle redirects""" """Ensure that we can handle redirects"""
url = scheme + "://mockbin.org/redirect/301?url=bytes/1024" url = httpbin + "/redirect-to?url=bytes/1024&status_code=301"
with vcr.use_cassette(str(tmpdir.join("requests.yaml"))): with vcr.use_cassette(str(tmpdir.join("requests.yaml"))):
content = (yield get(get_client(), url)).body content = (yield get(get_client(), url)).body
@@ -159,6 +172,7 @@ def test_redirects(get_client, tmpdir, scheme):
assert cass.play_count == 1 assert cass.play_count == 1
@pytest.mark.online
@pytest.mark.gen_test @pytest.mark.gen_test
def test_cross_scheme(get_client, tmpdir, scheme): def test_cross_scheme(get_client, tmpdir, scheme):
"""Ensure that requests between schemes are treated separately""" """Ensure that requests between schemes are treated separately"""
@@ -178,6 +192,7 @@ def test_cross_scheme(get_client, tmpdir, scheme):
assert cass.play_count == 2 assert cass.play_count == 2
@pytest.mark.online
@pytest.mark.gen_test @pytest.mark.gen_test
def test_gzip(get_client, tmpdir, scheme): def test_gzip(get_client, tmpdir, scheme):
""" """
@@ -203,6 +218,7 @@ def test_gzip(get_client, tmpdir, scheme):
assert 1 == cass.play_count assert 1 == cass.play_count
@pytest.mark.online
@pytest.mark.gen_test @pytest.mark.gen_test
def test_https_with_cert_validation_disabled(get_client, tmpdir): def test_https_with_cert_validation_disabled(get_client, tmpdir):
cass_path = str(tmpdir.join("cert_validation_disabled.yaml")) cass_path = str(tmpdir.join("cert_validation_disabled.yaml"))
@@ -233,6 +249,10 @@ def test_unsupported_features_raises_in_future(get_client, tmpdir):
@pytest.mark.skipif(not supports_raise_error, reason="raise_error unavailable in tornado <= 3") @pytest.mark.skipif(not supports_raise_error, reason="raise_error unavailable in tornado <= 3")
@pytest.mark.skipif(
raise_error_for_response_code_only,
reason="raise_error only ignores HTTPErrors due to response code",
)
@pytest.mark.gen_test @pytest.mark.gen_test
def test_unsupported_features_raise_error_disabled(get_client, tmpdir): def test_unsupported_features_raise_error_disabled(get_client, tmpdir):
"""Ensure that the exception for an AsyncHTTPClient feature not being """Ensure that the exception for an AsyncHTTPClient feature not being
@@ -252,6 +272,7 @@ def test_unsupported_features_raise_error_disabled(get_client, tmpdir):
assert "not yet supported by VCR" in str(response.error) assert "not yet supported by VCR" in str(response.error)
@pytest.mark.online
@pytest.mark.gen_test @pytest.mark.gen_test
def test_cannot_overwrite_cassette_raises_in_future(get_client, tmpdir): def test_cannot_overwrite_cassette_raises_in_future(get_client, tmpdir):
"""Ensure that CannotOverwriteExistingCassetteException is raised inside """Ensure that CannotOverwriteExistingCassetteException is raised inside
@@ -268,6 +289,10 @@ def test_cannot_overwrite_cassette_raises_in_future(get_client, tmpdir):
@pytest.mark.skipif(not supports_raise_error, reason="raise_error unavailable in tornado <= 3") @pytest.mark.skipif(not supports_raise_error, reason="raise_error unavailable in tornado <= 3")
@pytest.mark.skipif(
raise_error_for_response_code_only,
reason="raise_error only ignores HTTPErrors due to response code",
)
@pytest.mark.gen_test @pytest.mark.gen_test
def test_cannot_overwrite_cassette_raise_error_disabled(get_client, tmpdir): def test_cannot_overwrite_cassette_raise_error_disabled(get_client, tmpdir):
"""Ensure that CannotOverwriteExistingCassetteException is not raised if """Ensure that CannotOverwriteExistingCassetteException is not raised if
@@ -303,6 +328,7 @@ def test_tornado_exception_can_be_caught(get_client):
assert e.code == 404 assert e.code == 404
@pytest.mark.online
@pytest.mark.gen_test @pytest.mark.gen_test
def test_existing_references_get_patched(tmpdir): def test_existing_references_get_patched(tmpdir):
from tornado.httpclient import AsyncHTTPClient from tornado.httpclient import AsyncHTTPClient
@@ -316,6 +342,7 @@ def test_existing_references_get_patched(tmpdir):
assert cass.play_count == 1 assert cass.play_count == 1
@pytest.mark.online
@pytest.mark.gen_test @pytest.mark.gen_test
def test_existing_instances_get_patched(get_client, tmpdir): def test_existing_instances_get_patched(get_client, tmpdir):
"""Ensure that existing instances of AsyncHTTPClient get patched upon """Ensure that existing instances of AsyncHTTPClient get patched upon
@@ -331,6 +358,7 @@ def test_existing_instances_get_patched(get_client, tmpdir):
assert cass.play_count == 1 assert cass.play_count == 1
@pytest.mark.online
@pytest.mark.gen_test @pytest.mark.gen_test
def test_request_time_is_set(get_client, tmpdir): def test_request_time_is_set(get_client, tmpdir):
"""Ensures that the request_time on HTTPResponses is set.""" """Ensures that the request_time on HTTPResponses is set."""

View File

@@ -63,12 +63,12 @@ def test_flickr_should_respond_with_200(tmpdir):
def test_cookies(tmpdir, httpbin): def test_cookies(tmpdir, httpbin):
testfile = str(tmpdir.join("cookies.yml")) testfile = str(tmpdir.join("cookies.yml"))
with vcr.use_cassette(testfile): with vcr.use_cassette(testfile):
s = requests.Session() with requests.Session() as s:
s.get(httpbin.url + "/cookies/set?k1=v1&k2=v2") s.get(httpbin.url + "/cookies/set?k1=v1&k2=v2")
assert s.cookies.keys() == ["k1", "k2"] assert s.cookies.keys() == ["k1", "k2"]
r2 = s.get(httpbin.url + "/cookies") r2 = s.get(httpbin.url + "/cookies")
assert sorted(r2.json()["cookies"].keys()) == ["k1", "k2"] assert sorted(r2.json()["cookies"].keys()) == ["k1", "k2"]
@pytest.mark.online @pytest.mark.online

View File

@@ -1,3 +1,4 @@
import contextlib
from unittest import mock from unittest import mock
from pytest import mark from pytest import mark
@@ -16,7 +17,7 @@ class TestVCRConnection:
@mark.online @mark.online
@mock.patch("vcr.cassette.Cassette.can_play_response_for", return_value=False) @mock.patch("vcr.cassette.Cassette.can_play_response_for", return_value=False)
def testing_connect(*args): def testing_connect(*args):
vcr_connection = VCRHTTPSConnection("www.google.com") with contextlib.closing(VCRHTTPSConnection("www.google.com")) as vcr_connection:
vcr_connection.cassette = Cassette("test", record_mode=mode.ALL) vcr_connection.cassette = Cassette("test", record_mode=mode.ALL)
vcr_connection.real_connection.connect() vcr_connection.real_connection.connect()
assert vcr_connection.real_connection.sock is not None assert vcr_connection.real_connection.sock is not None

View File

@@ -3,7 +3,6 @@ import contextlib
import copy import copy
import inspect import inspect
import logging import logging
import sys
from asyncio import iscoroutinefunction from asyncio import iscoroutinefunction
import wrapt import wrapt
@@ -126,20 +125,7 @@ class CassetteContextDecorator:
duration of the generator. duration of the generator.
""" """
with self as cassette: with self as cassette:
coroutine = fn(cassette) yield from fn(cassette)
# We don't need to catch StopIteration. The caller (Tornado's
# gen.coroutine, for example) will handle that.
to_yield = next(coroutine)
while True:
try:
to_send = yield to_yield
except Exception:
to_yield = coroutine.throw(*sys.exc_info())
else:
try:
to_yield = coroutine.send(to_send)
except StopIteration:
break
def _handle_function(self, fn): def _handle_function(self, fn):
with self as cassette: with self as cassette:

View File

@@ -372,10 +372,6 @@ class ConnectionRemover:
if isinstance(connection, self._connection_class): if isinstance(connection, self._connection_class):
self._connection_pool_to_connections.setdefault(pool, set()).add(connection) self._connection_pool_to_connections.setdefault(pool, set()).add(connection)
def remove_connection_to_pool_entry(self, pool, connection):
if isinstance(connection, self._connection_class):
self._connection_pool_to_connections[self._connection_class].remove(connection)
def __enter__(self): def __enter__(self):
return self return self
@@ -386,10 +382,13 @@ class ConnectionRemover:
connection = pool.pool.get() connection = pool.pool.get()
if isinstance(connection, self._connection_class): if isinstance(connection, self._connection_class):
connections.remove(connection) connections.remove(connection)
connection.close()
else: else:
readd_connections.append(connection) readd_connections.append(connection)
for connection in readd_connections: for connection in readd_connections:
pool._put_conn(connection) pool._put_conn(connection)
for connection in connections:
connection.close()
def reset_patchers(): def reset_patchers():