diff --git a/tests/integration/test_persist_callbacks.py b/tests/integration/test_persist_callbacks.py new file mode 100644 index 0000000..772b7e0 --- /dev/null +++ b/tests/integration/test_persist_callbacks.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +'''Tests for cassettes with overriden persistence''' + +# External imports +import os +from six.moves.urllib.request import urlopen + +# Internal imports +import vcr +from vcr.persisters.filesystem import FilesystemPersister + + +def test_overriding_save_cassette_with_callback(tmpdir, httpbin): + '''Ensure you can save a cassette using save_callback''' + + def save_callback(cassette_path, data): + FilesystemPersister.write(cassette_path, data) + + # Check to make sure directory doesnt exist + assert not os.path.exists(str(tmpdir.join('nonexistent'))) + + # Run VCR to create dir and cassette file using new save_cassette callback + with vcr.use_cassette( + str(tmpdir.join('nonexistent', 'cassette.yml')), + save_callback=save_callback + ): + urlopen(httpbin.url).read() + + # Callback should have made the file and the directory + assert os.path.exists(str(tmpdir.join('nonexistent', 'cassette.yml'))) + + +def test_overriding_load_cassette_with_callback(tmpdir, httpbin): + ''' + Ensure you can load a cassette using load_callback + ''' + test_fixture = str(tmpdir.join('synopsis.json')) + + def load_callback(cassette_path): + with open(cassette_path) as f: + cassette_content = f.read() + return cassette_content + + with vcr.use_cassette( + test_fixture, + serializer='json', + load_callback=load_callback + ): + response = urlopen(httpbin.url).read() + assert b'difficult sometimes' in response diff --git a/vcr/cassette.py b/vcr/cassette.py index 2f5c293..e725d94 100644 --- a/vcr/cassette.py +++ b/vcr/cassette.py @@ -163,13 +163,16 @@ class Cassette(object): def use(cls, **kwargs): return CassetteContextDecorator.from_args(cls, **kwargs) - def __init__(self, path, serializer=yamlserializer, record_mode='once', + def __init__(self, path, serializer=yamlserializer, save_callback=None, + load_callback=None, record_mode='once', match_on=(uri, method), before_record_request=None, before_record_response=None, custom_patches=(), inject=False): self._path = path self._serializer = serializer + self._save_callback = save_callback + self._load_callback = load_callback self._match_on = match_on self._before_record_request = before_record_request or (lambda x: x) self._before_record_response = before_record_response or (lambda x: x) @@ -274,7 +277,8 @@ class Cassette(object): save_cassette( self._path, self._as_dict(), - serializer=self._serializer + serializer=self._serializer, + save_callback=self._save_callback ) self.dirty = False @@ -282,7 +286,8 @@ class Cassette(object): try: requests, responses = load_cassette( self._path, - serializer=self._serializer + serializer=self._serializer, + load_callback=self._load_callback ) for request, response in zip(requests, responses): self.append(request, response) diff --git a/vcr/config.py b/vcr/config.py index ae99346..b1255ce 100644 --- a/vcr/config.py +++ b/vcr/config.py @@ -36,7 +36,8 @@ class VCR(object): match_on=('method', 'scheme', 'host', 'port', 'path', 'query'), before_record=None, inject_cassette=False, serializer='yaml', cassette_library_dir=None, func_path_generator=None, - decode_compressed_response=False): + decode_compressed_response=False, save_callback=None, + load_callback=None): self.serializer = serializer self.match_on = match_on self.cassette_library_dir = cassette_library_dir @@ -69,6 +70,8 @@ class VCR(object): self.path_transformer = path_transformer self.func_path_generator = func_path_generator self.decode_compressed_response = decode_compressed_response + self.save_callback = save_callback + self.load_callback = load_callback self._custom_patches = tuple(custom_patches) def _get_serializer(self, serializer_name): @@ -147,6 +150,8 @@ class VCR(object): tuple(matcher_names) + tuple(additional_matchers) ), 'record_mode': kwargs.get('record_mode', self.record_mode), + 'save_callback': kwargs.get('save_callback', self.save_callback), + 'load_callback': kwargs.get('load_callback', self.load_callback), 'before_record_request': self._build_before_record_request(kwargs), 'before_record_response': self._build_before_record_response(kwargs), 'custom_patches': self._custom_patches + kwargs.get( diff --git a/vcr/persist.py b/vcr/persist.py index f8b8d2a..a1b65b4 100644 --- a/vcr/persist.py +++ b/vcr/persist.py @@ -2,13 +2,25 @@ from .persisters.filesystem import FilesystemPersister from .serialize import serialize, deserialize -def load_cassette(cassette_path, serializer): - with open(cassette_path) as f: - cassette_content = f.read() - cassette = deserialize(cassette_content, serializer) - return cassette +def load_cassette(cassette_path, serializer, load_callback=None): + # Injected `load_callback` must return a cassette or raise IOError + if load_callback is None: + with open(cassette_path) as f: + cassette_content = f.read() + else: + cassette_content = load_callback(cassette_path) + cassette = deserialize(cassette_content, serializer) + return cassette -def save_cassette(cassette_path, cassette_dict, serializer): +def save_cassette( + cassette_path, + cassette_dict, + serializer, + save_callback=None +): data = serialize(cassette_dict, serializer) - FilesystemPersister.write(cassette_path, data) + if save_callback is None: + FilesystemPersister.write(cassette_path, data) + else: + save_callback(cassette_path, data)