Mirror of https://github.com/kevin1024/vcrpy.git, synced 2025-12-09 01:03:24 +00:00
Merge pull request #801 from graingert/fix-resource-warning-2

@@ -5,7 +5,19 @@ ignore-regex = "\\\\[fnrstv]"
 # ignore-words-list = ''

 [tool.pytest.ini_options]
+addopts = [
+    "--strict-config",
+    "--strict-markers",
+]
 markers = ["online"]
+filterwarnings = [
+    "error",
+    '''ignore:datetime\.datetime\.utcfromtimestamp\(\) is deprecated and scheduled for removal in a future version.*:DeprecationWarning''',
+    '''ignore:There is no current event loop:DeprecationWarning''',
+    '''ignore:make_current is deprecated; start the event loop first:DeprecationWarning''',
+    '''ignore:clear_current is deprecated:DeprecationWarning''',
+    '''ignore:the \(type, exc, tb\) signature of throw\(\) is deprecated, use the single-arg signature instead.:DeprecationWarning''',
+]

 [tool.ruff]
 select = [
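
The filterwarnings entry above promotes every warning raised during the test run to an error, which is what makes an unclosed socket's ResourceWarning fail a test, while the ignore: patterns carve out known deprecation noise from the async event-loop APIs. A minimal, self-contained sketch of the same promotion using only the standard library (the function name and warning text are illustrative):

import warnings


def promote_warnings_to_errors():
    # Equivalent in spirit to pytest's filterwarnings = ["error"]: any warning,
    # including ResourceWarning from an unclosed socket, raises instead of printing.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        try:
            warnings.warn("unclosed <socket.socket fd=...>", ResourceWarning)
        except ResourceWarning as exc:
            return "warning became an error: {}".format(exc)


print(promote_warnings_to_errors())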

setup.py

@@ -67,6 +67,7 @@ extras_require = {
         "pytest-asyncio",
         "pytest-cov",
         "pytest-httpbin",
+        "pytest-tornado",
         "pytest",
         "requests>=2.22.0",
         "tornado",

@@ -5,24 +5,24 @@ import aiohttp


 async def aiohttp_request(loop, method, url, output="text", encoding="utf-8", content_type=None, **kwargs):
-    session = aiohttp.ClientSession(loop=loop)
-    response_ctx = session.request(method, url, **kwargs)
-
-    response = await response_ctx.__aenter__()
-    if output == "text":
-        content = await response.text()
-    elif output == "json":
-        content_type = content_type or "application/json"
-        content = await response.json(encoding=encoding, content_type=content_type)
-    elif output == "raw":
-        content = await response.read()
-    elif output == "stream":
-        content = await response.content.read()
-
-    response_ctx._resp.close()
-    await session.close()
-
-    return response, content
+    async with aiohttp.ClientSession(loop=loop) as session:
+        response_ctx = session.request(method, url, **kwargs)
+
+        response = await response_ctx.__aenter__()
+        if output == "text":
+            content = await response.text()
+        elif output == "json":
+            content_type = content_type or "application/json"
+            content = await response.json(encoding=encoding, content_type=content_type)
+        elif output == "raw":
+            content = await response.read()
+        elif output == "stream":
+            content = await response.content.read()
+
+        response_ctx._resp.close()
+        await session.close()
+
+        return response, content


 def aiohttp_app():
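
The change above opens the ClientSession with async with, so the session and its connector are closed deterministically instead of relying on garbage collection to warn about an unclosed client session. A reduced sketch of the same pattern outside the test helper, with an illustrative URL and no VCR involvement:

import asyncio

import aiohttp


async def fetch_text(url):
    # The session is closed when the block exits, even if the request raises.
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            return await response.text()


# asyncio.run(fetch_text("https://example.org"))  # requires network access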

@@ -1,4 +1,3 @@
-import contextlib
 import logging
 import urllib.parse


@@ -14,10 +13,10 @@ from .aiohttp_utils import aiohttp_app, aiohttp_request # noqa: E402


 def run_in_loop(fn):
-    with contextlib.closing(asyncio.new_event_loop()) as loop:
-        asyncio.set_event_loop(loop)
-        task = loop.create_task(fn(loop))
-        return loop.run_until_complete(task)
+    async def wrapper():
+        return await fn(asyncio.get_running_loop())
+
+    return asyncio.run(wrapper())


 def request(method, url, output="text", **kwargs):
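
asyncio.run() creates a fresh event loop, runs the coroutine, and closes the loop in a finally block, so the rewritten helper no longer needs contextlib.closing() and set_event_loop() bookkeeping. A standalone sketch of the new shape plus a trivial caller (the sample coroutine is illustrative):

import asyncio


def run_in_loop(fn):
    async def wrapper():
        # Pass the running loop through, mirroring the test helper's signature.
        return await fn(asyncio.get_running_loop())

    # asyncio.run() owns loop creation, execution and cleanup.
    return asyncio.run(wrapper())


async def sample(loop):
    return loop.is_running()


print(run_in_loop(sample))  # True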

@@ -260,6 +259,12 @@ def test_aiohttp_test_client_json(aiohttp_client, tmpdir):
     assert cassette.play_count == 1


+def test_cleanup_from_pytest_asyncio():
+    # work around https://github.com/pytest-dev/pytest-asyncio/issues/724
+    asyncio.get_event_loop().close()
+    asyncio.set_event_loop(None)
+
+
 @pytest.mark.online
 def test_redirect(tmpdir, httpbin):
     url = httpbin.url + "/redirect/2"

@@ -32,25 +32,37 @@ class DoSyncRequest(BaseDoRequest):
     _client_class = httpx.Client

     def __enter__(self):
+        self._client = self._make_client()
         return self

     def __exit__(self, *args):
-        pass
+        self._client.close()
+        del self._client

     @property
     def client(self):
         try:
             return self._client
-        except AttributeError:
-            self._client = self._make_client()
-        return self._client
+        except AttributeError as e:
+            raise ValueError('To access sync client, use "with do_request() as client"') from e

     def __call__(self, *args, **kwargs):
-        return self.client.request(*args, timeout=60, **kwargs)
+        if hasattr(self, "_client"):
+            return self.client.request(*args, timeout=60, **kwargs)
+
+        # Use one-time context and dispose of the client afterwards
+        with self:
+            return self.client.request(*args, timeout=60, **kwargs)

     def stream(self, *args, **kwargs):
-        with self.client.stream(*args, **kwargs) as response:
-            return b"".join(response.iter_bytes())
+        if hasattr(self, "_client"):
+            with self.client.stream(*args, **kwargs) as response:
+                return b"".join(response.iter_bytes())
+
+        # Use one-time context and dispose of the client afterwards
+        with self:
+            with self.client.stream(*args, **kwargs) as response:
+                return b"".join(response.iter_bytes())


 class DoAsyncRequest(BaseDoRequest):
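
With the change above, DoSyncRequest only keeps an httpx.Client alive inside a with block and otherwise builds a throwaway client per call, so every client gets closed explicitly. The same rule reduced to plain httpx (the fetch helper and URL are illustrative, not part of the test suite):

import httpx


def fetch(url, client=None):
    # Reuse a caller-managed client when one is provided...
    if client is not None:
        return client.get(url, timeout=60)

    # ...otherwise use a one-time client and dispose of it afterwards.
    with httpx.Client() as one_shot:
        return one_shot.get(url, timeout=60)


# with httpx.Client() as client:
#     response = fetch("https://example.org", client=client)  # requires network access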

@@ -39,12 +39,12 @@ class Proxy(http.server.SimpleHTTPRequestHandler):

 @pytest.fixture(scope="session")
 def proxy_server():
-    httpd = socketserver.ThreadingTCPServer(("", 0), Proxy)
-    proxy_process = threading.Thread(target=httpd.serve_forever)
-    proxy_process.start()
-    yield "http://{}:{}".format(*httpd.server_address)
-    httpd.shutdown()
-    proxy_process.join()
+    with socketserver.ThreadingTCPServer(("", 0), Proxy) as httpd:
+        proxy_process = threading.Thread(target=httpd.serve_forever)
+        proxy_process.start()
+        yield "http://{}:{}".format(*httpd.server_address)
+        httpd.shutdown()
+        proxy_process.join()


 def test_use_proxy(tmpdir, httpbin, proxy_server):
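
socketserver servers have been context managers since Python 3.6, so the with block guarantees server_close() runs when the fixture finishes, closing the listening socket the old fixture leaked. A minimal sketch of the same start/shutdown/close sequence outside pytest (the handler choice is illustrative):

import http.server
import socketserver
import threading


def serve_briefly():
    # server_close() runs automatically when the with-block exits.
    with socketserver.ThreadingTCPServer(("", 0), http.server.SimpleHTTPRequestHandler) as httpd:
        worker = threading.Thread(target=httpd.serve_forever)
        worker.start()
        address = "http://{}:{}".format(*httpd.server_address)
        httpd.shutdown()
        worker.join()
        return address


print(serve_briefly())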

@@ -15,6 +15,13 @@ http = pytest.importorskip("tornado.httpclient")
 # whether the current version of Tornado supports the raise_error argument for
 # fetch().
 supports_raise_error = tornado.version_info >= (4,)
+raise_error_for_response_code_only = tornado.version_info >= (6,)
+
+
+@pytest.fixture(params=["https", "http"])
+def scheme(request):
+    """Fixture that returns both http and https."""
+    return request.param


 @pytest.fixture(params=["simple", "curl", "default"])
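
The scheme fixture fans each networked Tornado test out over both protocols, and the @pytest.mark.online markers added throughout this file pair with the markers = ["online"] registration in the first hunk, so the whole group can be deselected with pytest -m "not online". A small illustration of how the marker and the parametrized fixture combine (the test body is illustrative):

import pytest


@pytest.fixture(params=["https", "http"])
def scheme(request):
    """Fixture that returns both http and https."""
    return request.param


@pytest.mark.online
def test_scheme_is_known(scheme):
    # Runs once per fixture param; skipped entirely under `-m "not online"`.
    assert scheme in ("https", "http")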

@@ -44,6 +51,7 @@ def post(client, url, data=None, **kwargs):
     return client.fetch(http.HTTPRequest(url, method="POST", **kwargs))


+@pytest.mark.online
 @pytest.mark.gen_test
 def test_status_code(get_client, scheme, tmpdir):
     """Ensure that we can read the status code"""

@@ -56,6 +64,7 @@ def test_status_code(get_client, scheme, tmpdir):
     assert 1 == cass.play_count


+@pytest.mark.online
 @pytest.mark.gen_test
 def test_headers(get_client, scheme, tmpdir):
     """Ensure that we can read the headers back"""

@@ -68,6 +77,7 @@ def test_headers(get_client, scheme, tmpdir):
     assert 1 == cass.play_count


+@pytest.mark.online
 @pytest.mark.gen_test
 def test_body(get_client, tmpdir, scheme):
     """Ensure the responses are all identical enough"""

@@ -94,6 +104,7 @@ def test_effective_url(get_client, tmpdir, httpbin):
     assert 1 == cass.play_count


+@pytest.mark.online
 @pytest.mark.gen_test
 def test_auth(get_client, tmpdir, scheme):
     """Ensure that we can handle basic auth"""

@@ -109,6 +120,7 @@ def test_auth(get_client, tmpdir, scheme):
     assert 1 == cass.play_count


+@pytest.mark.online
 @pytest.mark.gen_test
 def test_auth_failed(get_client, tmpdir, scheme):
     """Ensure that we can save failed auth statuses"""

@@ -132,6 +144,7 @@ def test_auth_failed(get_client, tmpdir, scheme):
     assert 1 == cass.play_count


+@pytest.mark.online
 @pytest.mark.gen_test
 def test_post(get_client, tmpdir, scheme):
     """Ensure that we can post and cache the results"""

@@ -148,9 +161,9 @@ def test_post(get_client, tmpdir, scheme):


 @pytest.mark.gen_test
-def test_redirects(get_client, tmpdir, scheme):
+def test_redirects(get_client, tmpdir, httpbin):
     """Ensure that we can handle redirects"""
-    url = scheme + "://mockbin.org/redirect/301?url=bytes/1024"
+    url = httpbin + "/redirect-to?url=bytes/1024&status_code=301"
     with vcr.use_cassette(str(tmpdir.join("requests.yaml"))):
         content = (yield get(get_client(), url)).body


@@ -159,6 +172,7 @@ def test_redirects(get_client, tmpdir, scheme):
     assert cass.play_count == 1


+@pytest.mark.online
 @pytest.mark.gen_test
 def test_cross_scheme(get_client, tmpdir, scheme):
     """Ensure that requests between schemes are treated separately"""

@@ -178,6 +192,7 @@ def test_cross_scheme(get_client, tmpdir, scheme):
     assert cass.play_count == 2


+@pytest.mark.online
 @pytest.mark.gen_test
 def test_gzip(get_client, tmpdir, scheme):
     """

@@ -203,6 +218,7 @@ def test_gzip(get_client, tmpdir, scheme):
     assert 1 == cass.play_count


+@pytest.mark.online
 @pytest.mark.gen_test
 def test_https_with_cert_validation_disabled(get_client, tmpdir):
     cass_path = str(tmpdir.join("cert_validation_disabled.yaml"))

@@ -233,6 +249,10 @@ def test_unsupported_features_raises_in_future(get_client, tmpdir):


 @pytest.mark.skipif(not supports_raise_error, reason="raise_error unavailable in tornado <= 3")
+@pytest.mark.skipif(
+    raise_error_for_response_code_only,
+    reason="raise_error only ignores HTTPErrors due to response code",
+)
 @pytest.mark.gen_test
 def test_unsupported_features_raise_error_disabled(get_client, tmpdir):
     """Ensure that the exception for an AsyncHTTPClient feature not being

@@ -252,6 +272,7 @@ def test_unsupported_features_raise_error_disabled(get_client, tmpdir):
     assert "not yet supported by VCR" in str(response.error)


+@pytest.mark.online
 @pytest.mark.gen_test
 def test_cannot_overwrite_cassette_raises_in_future(get_client, tmpdir):
     """Ensure that CannotOverwriteExistingCassetteException is raised inside

@@ -268,6 +289,10 @@ def test_cannot_overwrite_cassette_raises_in_future(get_client, tmpdir):


 @pytest.mark.skipif(not supports_raise_error, reason="raise_error unavailable in tornado <= 3")
+@pytest.mark.skipif(
+    raise_error_for_response_code_only,
+    reason="raise_error only ignores HTTPErrors due to response code",
+)
 @pytest.mark.gen_test
 def test_cannot_overwrite_cassette_raise_error_disabled(get_client, tmpdir):
     """Ensure that CannotOverwriteExistingCassetteException is not raised if

@@ -303,6 +328,7 @@ def test_tornado_exception_can_be_caught(get_client):
     assert e.code == 404


+@pytest.mark.online
 @pytest.mark.gen_test
 def test_existing_references_get_patched(tmpdir):
     from tornado.httpclient import AsyncHTTPClient

@@ -316,6 +342,7 @@ def test_existing_references_get_patched(tmpdir):
     assert cass.play_count == 1


+@pytest.mark.online
 @pytest.mark.gen_test
 def test_existing_instances_get_patched(get_client, tmpdir):
     """Ensure that existing instances of AsyncHTTPClient get patched upon

@@ -331,6 +358,7 @@ def test_existing_instances_get_patched(get_client, tmpdir):
     assert cass.play_count == 1


+@pytest.mark.online
 @pytest.mark.gen_test
 def test_request_time_is_set(get_client, tmpdir):
     """Ensures that the request_time on HTTPResponses is set."""

@@ -63,12 +63,12 @@ def test_flickr_should_respond_with_200(tmpdir):
 def test_cookies(tmpdir, httpbin):
     testfile = str(tmpdir.join("cookies.yml"))
     with vcr.use_cassette(testfile):
-        s = requests.Session()
-        s.get(httpbin.url + "/cookies/set?k1=v1&k2=v2")
-        assert s.cookies.keys() == ["k1", "k2"]
-
-        r2 = s.get(httpbin.url + "/cookies")
-        assert sorted(r2.json()["cookies"].keys()) == ["k1", "k2"]
+        with requests.Session() as s:
+            s.get(httpbin.url + "/cookies/set?k1=v1&k2=v2")
+            assert s.cookies.keys() == ["k1", "k2"]
+
+            r2 = s.get(httpbin.url + "/cookies")
+            assert sorted(r2.json()["cookies"].keys()) == ["k1", "k2"]


 @pytest.mark.online
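
requests.Session is itself a context manager; closing it releases the pooled keep-alive connections that would otherwise surface as ResourceWarnings under the new "error" filter. The same idiom reduced to its core (helper name and URL are illustrative):

import requests


def get_json(url):
    # Session.close() runs on exit, returning pooled connections.
    with requests.Session() as session:
        return session.get(url).json()


# get_json("https://httpbin.org/get")  # requires network access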

@@ -1,3 +1,4 @@
+import contextlib
 from unittest import mock

 from pytest import mark

@@ -16,7 +17,7 @@ class TestVCRConnection:
     @mark.online
     @mock.patch("vcr.cassette.Cassette.can_play_response_for", return_value=False)
     def testing_connect(*args):
-        vcr_connection = VCRHTTPSConnection("www.google.com")
-        vcr_connection.cassette = Cassette("test", record_mode=mode.ALL)
-        vcr_connection.real_connection.connect()
-        assert vcr_connection.real_connection.sock is not None
+        with contextlib.closing(VCRHTTPSConnection("www.google.com")) as vcr_connection:
+            vcr_connection.cassette = Cassette("test", record_mode=mode.ALL)
+            vcr_connection.real_connection.connect()
+            assert vcr_connection.real_connection.sock is not None
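
contextlib.closing() turns any object with a close() method into a context manager, which is how the test now guarantees the connection is torn down even if an assertion fails. A generic sketch with the stdlib HTTPSConnection standing in for the VCR stub (host and helper name are illustrative):

import contextlib
from http.client import HTTPSConnection


def head_status(host):
    # close() is called even if request()/getresponse() raises.
    with contextlib.closing(HTTPSConnection(host)) as connection:
        connection.request("HEAD", "/")
        return connection.getresponse().status


# head_status("www.google.com")  # requires network access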

@@ -3,7 +3,6 @@ import contextlib
 import copy
 import inspect
 import logging
-import sys
 from asyncio import iscoroutinefunction

 import wrapt

@@ -126,20 +125,7 @@ class CassetteContextDecorator:
         duration of the generator.
         """
         with self as cassette:
-            coroutine = fn(cassette)
-            # We don't need to catch StopIteration. The caller (Tornado's
-            # gen.coroutine, for example) will handle that.
-            to_yield = next(coroutine)
-            while True:
-                try:
-                    to_send = yield to_yield
-                except Exception:
-                    to_yield = coroutine.throw(*sys.exc_info())
-                else:
-                    try:
-                        to_yield = coroutine.send(to_send)
-                    except StopIteration:
-                        break
+            yield from fn(cassette)

     def _handle_function(self, fn):
         with self as cassette:
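
The hand-written delegation loop removed above, next()/send()/throw() driven by sys.exc_info(), is exactly what yield from does for generators, including forwarding exceptions thrown into the outer generator, so the one-line replacement preserves behaviour and lets the sys import go. A tiny demonstration of that equivalence:

def inner():
    received = yield "first"
    yield "got {}".format(received)


def outer():
    # Forwards next(), send() and throw() to inner() automatically.
    yield from inner()


gen = outer()
print(next(gen))       # first
print(gen.send("hi"))  # got hi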

@@ -372,10 +372,6 @@ class ConnectionRemover:
         if isinstance(connection, self._connection_class):
             self._connection_pool_to_connections.setdefault(pool, set()).add(connection)

-    def remove_connection_to_pool_entry(self, pool, connection):
-        if isinstance(connection, self._connection_class):
-            self._connection_pool_to_connections[self._connection_class].remove(connection)
-
     def __enter__(self):
         return self


@@ -386,10 +382,13 @@ class ConnectionRemover:
                 connection = pool.pool.get()
                 if isinstance(connection, self._connection_class):
                     connections.remove(connection)
+                    connection.close()
                 else:
                     readd_connections.append(connection)
             for connection in readd_connections:
                 pool._put_conn(connection)
+            for connection in connections:
+                connection.close()


 def reset_patchers():
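
ConnectionRemover.__exit__ now closes the VCR connections it drains from the urllib3 pool, and then closes any tracked connections that were never returned to the pool, rather than dropping the references and leaving the sockets to the garbage collector. A simplified sketch of that drain-and-close step with stand-in objects (queue.Queue plays the role of the pool; all names are illustrative):

import queue


class FakeConnection:
    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


pool = queue.Queue()
tracked = {FakeConnection(), FakeConnection()}
other = FakeConnection()
for connection in [*tracked, other]:
    pool.put(connection)

# Drain the pool: close tracked connections, remember the rest for re-adding.
readd = []
while not pool.empty():
    connection = pool.get()
    if connection in tracked:
        connection.close()
    else:
        readd.append(connection)
for connection in readd:
    pool.put(connection)

assert all(connection.closed for connection in tracked)
assert not other.closed and pool.qsize() == 1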