From 4f0668471e5cb6f22ccae8bad12212529bb1377a Mon Sep 17 00:00:00 2001 From: Kevin McCarthy Date: Sat, 10 Aug 2013 12:51:24 -1000 Subject: [PATCH] Refactor again to try to give Cassette a cleaner interface --- tests/unit/test_cassettes.py | 13 ++------- vcr/cassette.py | 53 +++++++++++++++++------------------- vcr/{files.py => persist.py} | 21 ++++++++++++-- 3 files changed, 45 insertions(+), 42 deletions(-) rename vcr/{files.py => persist.py} (52%) diff --git a/tests/unit/test_cassettes.py b/tests/unit/test_cassettes.py index 8467ff4..d4ba511 100644 --- a/tests/unit/test_cassettes.py +++ b/tests/unit/test_cassettes.py @@ -10,16 +10,6 @@ def test_cassette_load(tmpdir): a_cassette = Cassette.load(str(a_file)) assert len(a_cassette) == 1 -def test_cassette_serialize(): - a = Cassette('test') - a.append('foo','bar') - assert a.serialize() == [{'request': 'foo', 'response': 'bar'}] - -def test_cassette_deserialize(): - a = Cassette('test') - a.deserialize([{'request': 'foo', 'response': 'bar'}]) - assert a.requests == {'foo':'bar'} - def test_cassette_not_played(): a = Cassette('test') assert not a.play_count @@ -40,7 +30,8 @@ def test_cassette_play_counter(): def test_cassette_append(): a = Cassette('test') a.append('foo', 'bar') - assert a.requests == {'foo':'bar'} + assert a.requests == ['foo'] + assert a.responses == ['bar'] def test_cassette_len(): a = Cassette('test') diff --git a/vcr/cassette.py b/vcr/cassette.py index 3443f91..5b158d6 100644 --- a/vcr/cassette.py +++ b/vcr/cassette.py @@ -9,45 +9,39 @@ except ImportError: # Internal imports from .patch import install, reset -from .files import load_cassette, save_cassette +from .persist import load_cassette, save_cassette class Cassette(object): '''A container for recorded requests and responses''' @classmethod def load(cls, path): '''Load in the cassette stored at the provided path''' + new_cassette = cls(path) try: - return cls(path, load_cassette(path)) + requests, responses = load_cassette(path) + for request, response in zip(requests, responses): + new_cassette.append(request, response) except IOError: - return cls(path) + pass + return new_cassette - def __init__(self, path, data=None): + def __init__(self, path): self._path = path - self.requests = {} + self.data = {} self.play_counts = Counter() - if data: - self.deserialize(data) - - def save(self, path): - '''Save this cassette to a path''' - save_cassette(path, self.serialize()) - - def serialize(self): - '''Return a serializable version of the cassette''' - return ([{ - 'request': req, - 'response': res, - } for req, res in self.requests.iteritems()]) - - def deserialize(self, source): - '''Given a serialized version, load the requests''' - for r in source: - self.requests[r['request']] = r['response'] @property def play_count(self): return sum(self.play_counts.values()) + @property + def requests(self): + return self.data.keys() + + @property + def responses(self): + return self.data.values() + def mark_played(self, request): ''' Alert the cassette of a request that's been played @@ -56,22 +50,25 @@ class Cassette(object): def append(self, request, response): '''Add a pair of request, response to this cassette''' - self.requests[request] = response + self.data[request] = response def response(self, request): '''Find the response corresponding to a request''' - return self.requests[request] + return self.data[request] + + def _save(self): + save_cassette(self._path, self.requests, self.responses) def __str__(self): return "".format(len(self)) def __len__(self): '''Return the number of request / response pairs stored in here''' - return len(self.requests) + return len(self.data) def __contains__(self, request): '''Return whether or not a request has been stored''' - return request in self.requests + return request in self.data def __enter__(self): '''Patch the fetching libraries we know about''' @@ -79,5 +76,5 @@ class Cassette(object): return self def __exit__(self, typ, value, traceback): - self.save(self._path) + self._save() reset() diff --git a/vcr/files.py b/vcr/persist.py similarity index 52% rename from vcr/files.py rename to vcr/persist.py index 6850dd1..5dfe314 100644 --- a/vcr/files.py +++ b/vcr/persist.py @@ -8,6 +8,18 @@ try: except ImportError: from yaml import Loader, Dumper +def _serialize_cassette(requests, responses): + '''Return a serializable version of the cassette''' + return ([{ + 'request': request, + 'response': response, + } for request, response in zip(requests, responses)]) + +def _deserialize_cassette(data): + requests = [r['request'] for r in data] + responses = [r['response'] for r in data] + return requests, responses + def _secure_write(path, contents): """ We'll overwrite the old version securely by writing out a temporary @@ -20,10 +32,13 @@ def _secure_write(path, contents): os.rename(name, path) def load_cassette(cassette_path): - return yaml.load(open(cassette_path), Loader=Loader) + data = yaml.load(open(cassette_path), Loader=Loader) + return _deserialize_cassette(data) -def save_cassette(cassette_path, data): +def save_cassette(cassette_path, requests, responses): dirname, filename = os.path.split(cassette_path) if not os.path.exists(dirname): os.makedirs(dirname) - _secure_write(cassette_path, yaml.dump(data, Dumper=Dumper)) + data = _serialize_cassette(requests, responses) + data = yaml.dump(data, Dumper=Dumper) + _secure_write(cassette_path, data)