diff --git a/vcr/stubs.py b/vcr/stubs.py index f22ee35..a1d50a7 100644 --- a/vcr/stubs.py +++ b/vcr/stubs.py @@ -5,6 +5,9 @@ from .cassette import Cassette class VCRHTTPResponse(object): + """ + Stub reponse class that gets returned instead of a HTTPResponse + """ def __init__(self, recorded_response): self.recorded_response = recorded_response self.reason = recorded_response['status']['message'] @@ -19,12 +22,7 @@ class VCRHTTPResponse(object): return self._content.read() -class VCRHTTPConnection(HTTPConnection): - - def __init__(self, *args, **kwargs): - self._cassette = Cassette() - HTTPConnection.__init__(self, *args, **kwargs) - +class VCRConnectionMixin: def _load_old_response(self): old_cassette = load_cassette(self._vcr_cassette_path) if old_cassette: @@ -49,7 +47,7 @@ class VCRHTTPConnection(HTTPConnection): body=body, headers=headers )) - HTTPConnection.request(self, method, url, body=body, headers=headers) + self._baseclass.request(self, method, url, body=body, headers=headers) def getresponse(self, buffering=False): old_response = self._load_old_response() @@ -65,14 +63,19 @@ class VCRHTTPConnection(HTTPConnection): return VCRHTTPResponse(old_response) -class VCRHTTPSConnection(HTTPSConnection): - """ - Note that this is pretty much a copy-and-paste of the - VCRHTTPConnection class. I couldn't figure out how to - do multiple inheritance to get this to work without - duplicating code. These are old-style classes which - I frankly don't understand. - """ +class VCRHTTPConnection(VCRConnectionMixin, HTTPConnection): + + # Can't use super since this is an old-style class + _baseclass = HTTPConnection + + def __init__(self, *args, **kwargs): + self._cassette = Cassette() + HTTPConnection.__init__(self, *args, **kwargs) + + +class VCRHTTPSConnection(VCRConnectionMixin, HTTPSConnection): + + _baseclass = HTTPSConnection def __init__(self, *args, **kwargs): """ @@ -85,42 +88,3 @@ class VCRHTTPSConnection(HTTPSConnection): self.key_file = kwargs.pop('key_file', None) self.cert_file = kwargs.pop('cert_file', None) self._cassette = Cassette() - - def _load_old_response(self): - old_cassette = load_cassette(self._vcr_cassette_path) - if old_cassette: - return old_cassette.get_response(self._vcr) - - def request(self, method, url, body=None, headers={}): - """ - Persist the request metadata in self._vcr - """ - self._vcr = { - 'method': method, - 'url': url, - 'body': body, - 'headers': headers, - } - old_cassette = load_cassette(self._vcr_cassette_path) - if old_cassette and old_cassette.get_request(self._vcr): - return - self._cassette.requests.append(dict( - method=method, - url=url, - body=body, - headers=headers - )) - HTTPSConnection.request(self, method, url, body=body, headers=headers) - - def getresponse(self, buffering=False): - old_response = self._load_old_response() - if not old_response: - response = HTTPConnection.getresponse(self) - self._cassette.responses.append({ - 'status': {'code': response.status, 'message': response.reason}, - 'headers': dict(response.getheaders()), - 'body': {'string': response.read()}, - }) - save_cassette(self._vcr_cassette_path, self._cassette) - old_response = self._load_old_response() - return VCRHTTPResponse(old_response)