From c4a33d1cffceaa1faed1fe9db48e3bc3a7ecf864 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Wed, 12 Aug 2015 10:42:00 -0700 Subject: [PATCH] For Tornado AsyncHTTPClient, replace the methods instead of the class. This makes it so patching works even if the user has a reference to, or an instance of the original unpatched AsyncHTTPClient class. Fixes #183. --- tests/integration/test_tornado.py | 28 +++++++++++++ vcr/patch.py | 40 +++++++++++------- vcr/stubs/tornado_stubs.py | 68 +++++-------------------------- 3 files changed, 64 insertions(+), 72 deletions(-) diff --git a/tests/integration/test_tornado.py b/tests/integration/test_tornado.py index 7fd2574..6f88d67 100644 --- a/tests/integration/test_tornado.py +++ b/tests/integration/test_tornado.py @@ -298,3 +298,31 @@ def test_tornado_exception_can_be_caught(get_client): yield get(get_client(), 'http://httpbin.org/status/404') except http.HTTPError as e: assert e.code == 404 + + +@pytest.mark.gen_test +def test_existing_references_get_patched(tmpdir): + from tornado.httpclient import AsyncHTTPClient + + with vcr.use_cassette(str(tmpdir.join('data.yaml'))): + client = AsyncHTTPClient() + yield get(client, 'http://httpbin.org/get') + + with vcr.use_cassette(str(tmpdir.join('data.yaml'))) as cass: + yield get(client, 'http://httpbin.org/get') + assert cass.play_count == 1 + + +@pytest.mark.gen_test +def test_existing_instances_get_patched(get_client, tmpdir): + '''Ensure that existing instances of AsyncHTTPClient get patched upon + entering VCR context.''' + + client = get_client() + + with vcr.use_cassette(str(tmpdir.join('data.yaml'))): + yield get(client, 'http://httpbin.org/get') + + with vcr.use_cassette(str(tmpdir.join('data.yaml'))) as cass: + yield get(client, 'http://httpbin.org/get') + assert cass.play_count == 1 diff --git a/vcr/patch.py b/vcr/patch.py index f34b3fe..8400c68 100644 --- a/vcr/patch.py +++ b/vcr/patch.py @@ -54,13 +54,12 @@ else: # Try to save the original types for Tornado try: - import tornado.httpclient import tornado.simple_httpclient except ImportError: # pragma: no cover pass else: - _AsyncHTTPClient = tornado.httpclient.AsyncHTTPClient - _SimpleAsyncHTTPClient = tornado.simple_httpclient.SimpleAsyncHTTPClient + _SimpleAsyncHTTPClient_fetch_impl = \ + tornado.simple_httpclient.SimpleAsyncHTTPClient.fetch_impl try: @@ -68,7 +67,8 @@ try: except ImportError: # pragma: no cover pass else: - _CurlAsyncHTTPClient = tornado.curl_httpclient.CurlAsyncHTTPClient + _CurlAsyncHTTPClient_fetch_impl = \ + tornado.curl_httpclient.CurlAsyncHTTPClient.fetch_impl class CassettePatcherBuilder(object): @@ -228,23 +228,27 @@ class CassettePatcherBuilder(object): @_build_patchers_from_mock_triples_decorator def _tornado(self): try: - import tornado.httpclient as http import tornado.simple_httpclient as simple except ImportError: # pragma: no cover pass else: - from .stubs.tornado_stubs import VCRAsyncHTTPClient - from .stubs.tornado_stubs import VCRSimpleAsyncHTTPClient + from .stubs.tornado_stubs import vcr_fetch_impl - yield http, 'AsyncHTTPClient', VCRAsyncHTTPClient - yield simple, 'SimpleAsyncHTTPClient', VCRSimpleAsyncHTTPClient + new_fetch_impl = vcr_fetch_impl( + self._cassette, _SimpleAsyncHTTPClient_fetch_impl + ) + yield simple.SimpleAsyncHTTPClient, 'fetch_impl', new_fetch_impl try: import tornado.curl_httpclient as curl except ImportError: # pragma: no cover pass else: - from .stubs.tornado_stubs import VCRCurlAsyncHTTPClient - yield curl, 'CurlAsyncHTTPClient', VCRCurlAsyncHTTPClient + from .stubs.tornado_stubs import vcr_fetch_impl + + new_fetch_impl = vcr_fetch_impl( + self._cassette, _CurlAsyncHTTPClient_fetch_impl + ) + yield curl.CurlAsyncHTTPClient, 'fetch_impl', new_fetch_impl def _urllib3_patchers(self, cpool, stubs): http_connection_remover = ConnectionRemover( @@ -362,19 +366,25 @@ def reset_patchers(): _CertValidatingHTTPSConnection) try: - import tornado.httpclient as http import tornado.simple_httpclient as simple except ImportError: # pragma: no cover pass else: - yield mock.patch.object(http, 'AsyncHTTPClient', _AsyncHTTPClient) - yield mock.patch.object(simple, 'SimpleAsyncHTTPClient', _SimpleAsyncHTTPClient) + yield mock.patch.object( + simple.SimpleAsyncHTTPClient, + 'fetch_impl', + _SimpleAsyncHTTPClient_fetch_impl, + ) try: import tornado.curl_httpclient as curl except ImportError: # pragma: no cover pass else: - yield mock.patch.object(curl, 'CurlAsyncHTTPClient', _CurlAsyncHTTPClient) + yield mock.patch.object( + curl.CurlAsyncHTTPClient, + 'fetch_impl', + _CurlAsyncHTTPClient_fetch_impl, + ) @contextlib.contextmanager diff --git a/vcr/stubs/tornado_stubs.py b/vcr/stubs/tornado_stubs.py index 7a16cd8..e5d93c5 100644 --- a/vcr/stubs/tornado_stubs.py +++ b/vcr/stubs/tornado_stubs.py @@ -1,48 +1,20 @@ '''Stubs for tornado HTTP clients''' from __future__ import absolute_import +import functools from six import BytesIO from tornado import httputil -from tornado.httpclient import AsyncHTTPClient from tornado.httpclient import HTTPResponse -from tornado.simple_httpclient import SimpleAsyncHTTPClient from vcr.errors import CannotOverwriteExistingCassetteException from vcr.request import Request -class _VCRAsyncClient(object): - cassette = None +def vcr_fetch_impl(cassette, real_fetch_impl): - def __new__(cls, *args, **kwargs): - from vcr.patch import force_reset - with force_reset(): - return super(_VCRAsyncClient, cls).__new__(cls, *args, **kwargs) - - def initialize(self, *args, **kwargs): - from vcr.patch import force_reset - with force_reset(): - self.real_client = self._baseclass(*args, **kwargs) - - @property - def io_loop(self): - return self.real_client.io_loop - - @property - def _closed(self): - return self.real_client._closed - - @property - def defaults(self): - return self.real_client.defaults - - def close(self): - from vcr.patch import force_reset - with force_reset(): - self.real_client.close() - - def fetch_impl(self, request, callback): + @functools.wraps(real_fetch_impl) + def new_fetch_impl(self, request, callback): headers = dict(request.headers) if request.user_agent: headers.setdefault('User-Agent', request.user_agent) @@ -74,8 +46,8 @@ class _VCRAsyncClient(object): headers, ) - if self.cassette.can_play_response_for(vcr_request): - vcr_response = self.cassette.play_response(vcr_request) + if cassette.can_play_response_for(vcr_request): + vcr_response = cassette.play_response(vcr_request) headers = httputil.HTTPHeaders() recorded_headers = vcr_response['headers'] @@ -93,7 +65,7 @@ class _VCRAsyncClient(object): ) return callback(response) else: - if self.cassette.write_protected and self.cassette.filter_request( + if cassette.write_protected and cassette.filter_request( vcr_request ): response = HTTPResponse( @@ -103,8 +75,7 @@ class _VCRAsyncClient(object): "No match for the request (%r) was found. " "Can't overwrite existing cassette (%r) in " "your current record mode (%r)." - % (vcr_request, self.cassette._path, - self.cassette.record_mode) + % (vcr_request, cassette._path, cassette.record_mode) ), ) return callback(response) @@ -123,26 +94,9 @@ class _VCRAsyncClient(object): 'headers': headers, 'body': {'string': response.body}, } - self.cassette.append(vcr_request, vcr_response) + cassette.append(vcr_request, vcr_response) return callback(response) - from vcr.patch import force_reset - with force_reset(): - self.real_client.fetch_impl(request, new_callback) + real_fetch_impl(self, request, new_callback) - -class VCRAsyncHTTPClient(_VCRAsyncClient, AsyncHTTPClient): - _baseclass = AsyncHTTPClient - - -class VCRSimpleAsyncHTTPClient(_VCRAsyncClient, SimpleAsyncHTTPClient): - _baseclass = SimpleAsyncHTTPClient - - -try: - from tornado.curl_httpclient import CurlAsyncHTTPClient -except ImportError: # pragma: no cover - VCRCurlAsyncHTTPClient = None -else: - class VCRCurlAsyncHTTPClient(_VCRAsyncClient, CurlAsyncHTTPClient): - _baseclass = CurlAsyncHTTPClient + return new_fetch_impl