mirror of
https://github.com/kevin1024/vcrpy.git
synced 2025-12-10 09:35:34 +00:00
Refactor again to try to give Cassette a cleaner interface
This commit is contained in:
@@ -10,16 +10,6 @@ def test_cassette_load(tmpdir):
|
|||||||
a_cassette = Cassette.load(str(a_file))
|
a_cassette = Cassette.load(str(a_file))
|
||||||
assert len(a_cassette) == 1
|
assert len(a_cassette) == 1
|
||||||
|
|
||||||
def test_cassette_serialize():
|
|
||||||
a = Cassette('test')
|
|
||||||
a.append('foo','bar')
|
|
||||||
assert a.serialize() == [{'request': 'foo', 'response': 'bar'}]
|
|
||||||
|
|
||||||
def test_cassette_deserialize():
|
|
||||||
a = Cassette('test')
|
|
||||||
a.deserialize([{'request': 'foo', 'response': 'bar'}])
|
|
||||||
assert a.requests == {'foo':'bar'}
|
|
||||||
|
|
||||||
def test_cassette_not_played():
|
def test_cassette_not_played():
|
||||||
a = Cassette('test')
|
a = Cassette('test')
|
||||||
assert not a.play_count
|
assert not a.play_count
|
||||||
@@ -40,7 +30,8 @@ def test_cassette_play_counter():
|
|||||||
def test_cassette_append():
|
def test_cassette_append():
|
||||||
a = Cassette('test')
|
a = Cassette('test')
|
||||||
a.append('foo', 'bar')
|
a.append('foo', 'bar')
|
||||||
assert a.requests == {'foo':'bar'}
|
assert a.requests == ['foo']
|
||||||
|
assert a.responses == ['bar']
|
||||||
|
|
||||||
def test_cassette_len():
|
def test_cassette_len():
|
||||||
a = Cassette('test')
|
a = Cassette('test')
|
||||||
|
|||||||
@@ -9,45 +9,39 @@ except ImportError:
|
|||||||
|
|
||||||
# Internal imports
|
# Internal imports
|
||||||
from .patch import install, reset
|
from .patch import install, reset
|
||||||
from .files import load_cassette, save_cassette
|
from .persist import load_cassette, save_cassette
|
||||||
|
|
||||||
class Cassette(object):
|
class Cassette(object):
|
||||||
'''A container for recorded requests and responses'''
|
'''A container for recorded requests and responses'''
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, path):
|
def load(cls, path):
|
||||||
'''Load in the cassette stored at the provided path'''
|
'''Load in the cassette stored at the provided path'''
|
||||||
|
new_cassette = cls(path)
|
||||||
try:
|
try:
|
||||||
return cls(path, load_cassette(path))
|
requests, responses = load_cassette(path)
|
||||||
|
for request, response in zip(requests, responses):
|
||||||
|
new_cassette.append(request, response)
|
||||||
except IOError:
|
except IOError:
|
||||||
return cls(path)
|
pass
|
||||||
|
return new_cassette
|
||||||
|
|
||||||
def __init__(self, path, data=None):
|
def __init__(self, path):
|
||||||
self._path = path
|
self._path = path
|
||||||
self.requests = {}
|
self.data = {}
|
||||||
self.play_counts = Counter()
|
self.play_counts = Counter()
|
||||||
if data:
|
|
||||||
self.deserialize(data)
|
|
||||||
|
|
||||||
def save(self, path):
|
|
||||||
'''Save this cassette to a path'''
|
|
||||||
save_cassette(path, self.serialize())
|
|
||||||
|
|
||||||
def serialize(self):
|
|
||||||
'''Return a serializable version of the cassette'''
|
|
||||||
return ([{
|
|
||||||
'request': req,
|
|
||||||
'response': res,
|
|
||||||
} for req, res in self.requests.iteritems()])
|
|
||||||
|
|
||||||
def deserialize(self, source):
|
|
||||||
'''Given a serialized version, load the requests'''
|
|
||||||
for r in source:
|
|
||||||
self.requests[r['request']] = r['response']
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def play_count(self):
|
def play_count(self):
|
||||||
return sum(self.play_counts.values())
|
return sum(self.play_counts.values())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requests(self):
|
||||||
|
return self.data.keys()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def responses(self):
|
||||||
|
return self.data.values()
|
||||||
|
|
||||||
def mark_played(self, request):
|
def mark_played(self, request):
|
||||||
'''
|
'''
|
||||||
Alert the cassette of a request that's been played
|
Alert the cassette of a request that's been played
|
||||||
@@ -56,22 +50,25 @@ class Cassette(object):
|
|||||||
|
|
||||||
def append(self, request, response):
|
def append(self, request, response):
|
||||||
'''Add a pair of request, response to this cassette'''
|
'''Add a pair of request, response to this cassette'''
|
||||||
self.requests[request] = response
|
self.data[request] = response
|
||||||
|
|
||||||
def response(self, request):
|
def response(self, request):
|
||||||
'''Find the response corresponding to a request'''
|
'''Find the response corresponding to a request'''
|
||||||
return self.requests[request]
|
return self.data[request]
|
||||||
|
|
||||||
|
def _save(self):
|
||||||
|
save_cassette(self._path, self.requests, self.responses)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "<Cassette containing {0} recorded response(s)>".format(len(self))
|
return "<Cassette containing {0} recorded response(s)>".format(len(self))
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
'''Return the number of request / response pairs stored in here'''
|
'''Return the number of request / response pairs stored in here'''
|
||||||
return len(self.requests)
|
return len(self.data)
|
||||||
|
|
||||||
def __contains__(self, request):
|
def __contains__(self, request):
|
||||||
'''Return whether or not a request has been stored'''
|
'''Return whether or not a request has been stored'''
|
||||||
return request in self.requests
|
return request in self.data
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
'''Patch the fetching libraries we know about'''
|
'''Patch the fetching libraries we know about'''
|
||||||
@@ -79,5 +76,5 @@ class Cassette(object):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, typ, value, traceback):
|
def __exit__(self, typ, value, traceback):
|
||||||
self.save(self._path)
|
self._save()
|
||||||
reset()
|
reset()
|
||||||
|
|||||||
@@ -8,6 +8,18 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from yaml import Loader, Dumper
|
from yaml import Loader, Dumper
|
||||||
|
|
||||||
|
def _serialize_cassette(requests, responses):
|
||||||
|
'''Return a serializable version of the cassette'''
|
||||||
|
return ([{
|
||||||
|
'request': request,
|
||||||
|
'response': response,
|
||||||
|
} for request, response in zip(requests, responses)])
|
||||||
|
|
||||||
|
def _deserialize_cassette(data):
|
||||||
|
requests = [r['request'] for r in data]
|
||||||
|
responses = [r['response'] for r in data]
|
||||||
|
return requests, responses
|
||||||
|
|
||||||
def _secure_write(path, contents):
|
def _secure_write(path, contents):
|
||||||
"""
|
"""
|
||||||
We'll overwrite the old version securely by writing out a temporary
|
We'll overwrite the old version securely by writing out a temporary
|
||||||
@@ -20,10 +32,13 @@ def _secure_write(path, contents):
|
|||||||
os.rename(name, path)
|
os.rename(name, path)
|
||||||
|
|
||||||
def load_cassette(cassette_path):
|
def load_cassette(cassette_path):
|
||||||
return yaml.load(open(cassette_path), Loader=Loader)
|
data = yaml.load(open(cassette_path), Loader=Loader)
|
||||||
|
return _deserialize_cassette(data)
|
||||||
|
|
||||||
def save_cassette(cassette_path, data):
|
def save_cassette(cassette_path, requests, responses):
|
||||||
dirname, filename = os.path.split(cassette_path)
|
dirname, filename = os.path.split(cassette_path)
|
||||||
if not os.path.exists(dirname):
|
if not os.path.exists(dirname):
|
||||||
os.makedirs(dirname)
|
os.makedirs(dirname)
|
||||||
_secure_write(cassette_path, yaml.dump(data, Dumper=Dumper))
|
data = _serialize_cassette(requests, responses)
|
||||||
|
data = yaml.dump(data, Dumper=Dumper)
|
||||||
|
_secure_write(cassette_path, data)
|
||||||
Reference in New Issue
Block a user