diff --git a/tests/integration/test_basic.py b/tests/integration/test_basic.py index 31fb6c6..9ddb669 100644 --- a/tests/integration/test_basic.py +++ b/tests/integration/test_basic.py @@ -37,3 +37,9 @@ def test_basic_use(tmpdir): with vcr.use_cassette('fixtures/vcr_cassettes/synopsis.yaml'): response = urllib2.urlopen('http://www.iana.org/domains/reserved').read() assert 'Example domains' in response + +def test_basic_json_use(tmpdir): + '''Ensure you can load a json serialized cassette''' + with vcr.use_cassette('fixtures/vcr_cassettes/synopsis.json', serializer='json'): + response = urllib2.urlopen('http://www.iana.org/domains/reserved').read() + assert 'Example domains' in response diff --git a/vcr/__init__.py b/vcr/__init__.py index 34c261f..efa5e94 100644 --- a/vcr/__init__.py +++ b/vcr/__init__.py @@ -2,5 +2,5 @@ from .cassette import Cassette # Also, make a 'load' function available -def use_cassette(path): - return Cassette.load(path) +def use_cassette(path, **kwargs): + return Cassette.load(path, **kwargs) diff --git a/vcr/cassette.py b/vcr/cassette.py index 3a6ffca..03db125 100644 --- a/vcr/cassette.py +++ b/vcr/cassette.py @@ -15,19 +15,15 @@ from .persist import load_cassette, save_cassette class Cassette(object): '''A container for recorded requests and responses''' @classmethod - def load(cls, path): + def load(cls, path, **kwargs): '''Load in the cassette stored at the provided path''' - new_cassette = cls(path) - try: - requests, responses = load_cassette(path) - for request, response in zip(requests, responses): - new_cassette.append(request, response) - except IOError: - pass + new_cassette = cls(path, **kwargs) + new_cassette._load() return new_cassette - def __init__(self, path): + def __init__(self, path, serializer="yaml"): self._path = path + self._serializer = serializer self.data = OrderedDict() self.play_counts = Counter() @@ -58,7 +54,15 @@ class Cassette(object): return self.data[request] def _save(self): - save_cassette(self._path, self.requests, self.responses) + save_cassette(self._path, self.requests, self.responses, serializer=self._serializer) + + def _load(self): + try: + requests, responses = load_cassette(self._path, serializer=self._serializer) + for request, response in zip(requests, responses): + self.append(request, response) + except IOError: + pass def __str__(self): return "".format(len(self)) diff --git a/vcr/persist.py b/vcr/persist.py index 2112dd5..e7f8c70 100644 --- a/vcr/persist.py +++ b/vcr/persist.py @@ -1,44 +1,21 @@ -import tempfile -import os -import yaml +from .persisters.filesystem import FilesystemPersister +from .serializers.yamlserializer import YamlSerializer +from .serializers.jsonserializer import JSONSerializer -# Use the libYAML versions if possible -try: - from yaml import CLoader as Loader, CDumper as Dumper -except ImportError: - from yaml import Loader, Dumper +def _get_serializer_cls(serializer): + serializer_cls = { + 'yaml': YamlSerializer, + 'json': JSONSerializer, + }.get(serializer) + if not serializer_cls: + raise ImportError('Invalid serializer %s' % serializer) + return serializer_cls -def _serialize_cassette(requests, responses): - '''Return a serializable version of the cassette''' - return ([{ - 'request': request, - 'response': response, - } for request, response in zip(requests, responses)]) - -def _deserialize_cassette(data): - requests = [r['request'] for r in data] - responses = [r['response'] for r in data] - return requests, responses - -def _secure_write(path, contents): - """ - We'll overwrite the old version securely by writing out a temporary - version and then moving it to replace the old version - """ - dirname, filename = os.path.split(path) - fd, name = tempfile.mkstemp(dir=dirname, prefix=filename) - with os.fdopen(fd, 'w') as fout: - fout.write(contents) - os.rename(name, path) - -def load_cassette(cassette_path): - data = yaml.load(open(cassette_path), Loader=Loader) - return _deserialize_cassette(data) - -def save_cassette(cassette_path, requests, responses): - dirname, filename = os.path.split(cassette_path) - if dirname and not os.path.exists(dirname): - os.makedirs(dirname) - data = _serialize_cassette(requests, responses) - data = yaml.dump(data, Dumper=Dumper) - _secure_write(cassette_path, data) +def load_cassette(cassette_path, serializer): + serializer_cls = _get_serializer_cls(serializer) + return serializer_cls.load(cassette_path) + +def save_cassette(cassette_path, requests, responses, serializer): + serializer_cls = _get_serializer_cls(serializer) + data = serializer_cls.dumps(requests, responses) + FilesystemPersister.write(cassette_path, data) diff --git a/vcr/persisters/__init__.py b/vcr/persisters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vcr/persisters/filesystem.py b/vcr/persisters/filesystem.py new file mode 100644 index 0000000..8d01c52 --- /dev/null +++ b/vcr/persisters/filesystem.py @@ -0,0 +1,22 @@ +import tempfile +import os + +class FilesystemPersister(object): + @classmethod + def _secure_write(cls, path, contents): + """ + We'll overwrite the old version securely by writing out a temporary + version and then moving it to replace the old version + """ + dirname, filename = os.path.split(path) + fd, name = tempfile.mkstemp(dir=dirname, prefix=filename) + with os.fdopen(fd, 'w') as fout: + fout.write(contents) + os.rename(name, path) + + @classmethod + def write(cls, cassette_path, data): + dirname, filename = os.path.split(cassette_path) + if dirname and not os.path.exists(dirname): + os.makedirs(dirname) + cls._secure_write(cassette_path, data) diff --git a/vcr/request.py b/vcr/request.py index 86e63e3..1e3ab8c 100644 --- a/vcr/request.py +++ b/vcr/request.py @@ -7,7 +7,7 @@ class Request(object): self.method = method self.path = path self.body = body - # make haders a frozenset so it will be hashable + # make headers a frozenset so it will be hashable self.headers = frozenset(headers.items()) @property @@ -28,3 +28,18 @@ class Request(object): def __repr__(self): return self.__str__() + + def _to_dict(self): + return { + 'protocol': self.protocol, + 'host': self.host, + 'port': self.port, + 'method': self.method, + 'path': self.path, + 'body': self.body, + 'headers': self.headers, + } + + @classmethod + def _from_dict(cls, dct): + return Request(**dct) diff --git a/vcr/serializers/__init__.py b/vcr/serializers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vcr/serializers/jsonserializer.py b/vcr/serializers/jsonserializer.py new file mode 100644 index 0000000..d5dca43 --- /dev/null +++ b/vcr/serializers/jsonserializer.py @@ -0,0 +1,31 @@ +from vcr.request import Request +try: + import simplejson as json +except ImportError: + import json + +def _json_default(obj): + if isinstance(obj, frozenset): + return dict(obj) + return obj + +def _fix_response_unicode(d): + d['body']['string'] = d['body']['string'].encode('utf-8') + return d + +class JSONSerializer(object): + @classmethod + def load(cls, cassette_path): + with open(cassette_path) as fh: + data = json.load(fh) + requests = [Request._from_dict(r['request']) for r in data] + responses = [_fix_response_unicode(r['response']) for r in data] + return requests, responses + + @classmethod + def dumps(cls, requests, responses): + data = ([{ + 'request': request._to_dict(), + 'response': response, + } for request, response in zip(requests, responses)]) + return json.dumps(data, indent=4, default=_json_default) diff --git a/vcr/serializers/yamlserializer.py b/vcr/serializers/yamlserializer.py new file mode 100644 index 0000000..2c0a5af --- /dev/null +++ b/vcr/serializers/yamlserializer.py @@ -0,0 +1,24 @@ +import yaml + +# Use the libYAML versions if possible +try: + from yaml import CLoader as Loader, CDumper as Dumper +except ImportError: + from yaml import Loader, Dumper + + +class YamlSerializer(object): + @classmethod + def load(cls, cassette_path): + data = yaml.load(open(cassette_path), Loader=Loader) + requests = [r['request'] for r in data] + responses = [r['response'] for r in data] + return requests, responses + + @classmethod + def dumps(cls, requests, responses): + data = ([{ + 'request': request, + 'response': response, + } for request, response in zip(requests, responses)]) + return yaml.dump(data, Dumper=Dumper)