diff --git a/tests/integration/test_persist_callbacks.py b/tests/integration/test_persist_callbacks.py index 772b7e0..3b212c9 100644 --- a/tests/integration/test_persist_callbacks.py +++ b/tests/integration/test_persist_callbacks.py @@ -10,41 +10,31 @@ import vcr from vcr.persisters.filesystem import FilesystemPersister -def test_overriding_save_cassette_with_callback(tmpdir, httpbin): +def test_save_cassette_with_custom_persister(tmpdir, httpbin): '''Ensure you can save a cassette using save_callback''' - - def save_callback(cassette_path, data): - FilesystemPersister.write(cassette_path, data) + my_vcr = vcr.VCR() + my_vcr.register_persister(FilesystemPersister) # Check to make sure directory doesnt exist assert not os.path.exists(str(tmpdir.join('nonexistent'))) # Run VCR to create dir and cassette file using new save_cassette callback - with vcr.use_cassette( - str(tmpdir.join('nonexistent', 'cassette.yml')), - save_callback=save_callback - ): + with my_vcr.use_cassette(str(tmpdir.join('nonexistent', 'cassette.yml'))): urlopen(httpbin.url).read() # Callback should have made the file and the directory assert os.path.exists(str(tmpdir.join('nonexistent', 'cassette.yml'))) -def test_overriding_load_cassette_with_callback(tmpdir, httpbin): +def test_load_cassette_with_custom_persister(tmpdir, httpbin): ''' Ensure you can load a cassette using load_callback ''' + my_vcr = vcr.VCR() + my_vcr.register_persister(FilesystemPersister) + test_fixture = str(tmpdir.join('synopsis.json')) - def load_callback(cassette_path): - with open(cassette_path) as f: - cassette_content = f.read() - return cassette_content - - with vcr.use_cassette( - test_fixture, - serializer='json', - load_callback=load_callback - ): + with my_vcr.use_cassette(test_fixture, serializer='json'): response = urlopen(httpbin.url).read() assert b'difficult sometimes' in response diff --git a/tests/unit/test_cassettes.py b/tests/unit/test_cassettes.py index ef9e808..34088c3 100644 --- a/tests/unit/test_cassettes.py +++ b/tests/unit/test_cassettes.py @@ -1,6 +1,7 @@ import copy import inspect import os +import sys from six.moves import http_client as httplib import pytest @@ -8,6 +9,7 @@ import yaml from vcr.compat import mock, contextlib from vcr.cassette import Cassette +from vcr.persisters.filesystem import FilesystemPersister from vcr.errors import UnhandledHTTPRequestError from vcr.patch import force_reset from vcr.stubs import VCRHTTPSConnection @@ -83,7 +85,8 @@ def make_get_request(): @mock.patch('vcr.cassette.requests_match', return_value=True) -@mock.patch('vcr.cassette.load_cassette', lambda *args, **kwargs: (('foo',), (mock.MagicMock(),))) +@mock.patch('vcr.cassette.FilesystemPersister.load_cassette', + lambda *args, **kwargs: (('foo',), (mock.MagicMock(),))) @mock.patch('vcr.cassette.Cassette.can_play_response_for', return_value=True) @mock.patch('vcr.stubs.VCRHTTPResponse') def test_function_decorated_with_use_cassette_can_be_invoked_multiple_times(*args): diff --git a/tests/unit/test_persist.py b/tests/unit/test_persist.py index e9b30be..18ac4a1 100644 --- a/tests/unit/test_persist.py +++ b/tests/unit/test_persist.py @@ -1,6 +1,6 @@ import pytest -import vcr.persist +from vcr.persisters.filesystem import FilesystemPersister from vcr.serializers import jsonserializer, yamlserializer @@ -10,7 +10,7 @@ from vcr.serializers import jsonserializer, yamlserializer ]) def test_load_cassette_with_old_cassettes(cassette_path, serializer): with pytest.raises(ValueError) as excinfo: - vcr.persist.load_cassette(cassette_path, serializer) + FilesystemPersister.load_cassette(cassette_path, serializer) assert "run the migration script" in excinfo.exconly() @@ -20,5 +20,5 @@ def test_load_cassette_with_old_cassettes(cassette_path, serializer): ]) def test_load_cassette_with_invalid_cassettes(cassette_path, serializer): with pytest.raises(Exception) as excinfo: - vcr.persist.load_cassette(cassette_path, serializer) + FilesystemPersister.load_cassette(cassette_path, serializer) assert "run the migration script" not in excinfo.exconly() diff --git a/tests/unit/test_vcr.py b/tests/unit/test_vcr.py index decb366..3592ef0 100644 --- a/tests/unit/test_vcr.py +++ b/tests/unit/test_vcr.py @@ -94,7 +94,7 @@ def test_vcr_before_record_response_iterable(): response = object() # just can't be None # Prevent actually saving the cassette - with mock.patch('vcr.cassette.save_cassette'): + with mock.patch('vcr.cassette.FilesystemPersister.save_cassette'): # Baseline: non-iterable before_record_response should work mock_filter = mock.Mock() @@ -118,7 +118,7 @@ def test_before_record_response_as_filter(): response = object() # just can't be None # Prevent actually saving the cassette - with mock.patch('vcr.cassette.save_cassette'): + with mock.patch('vcr.cassette.FilesystemPersister.save_cassette'): filter_all = mock.Mock(return_value=None) vcr = VCR(before_record_response=filter_all) @@ -132,7 +132,7 @@ def test_vcr_path_transformer(): # Regression test for #199 # Prevent actually saving the cassette - with mock.patch('vcr.cassette.save_cassette'): + with mock.patch('vcr.cassette.FilesystemPersister.save_cassette'): # Baseline: path should be unchanged vcr = VCR() diff --git a/vcr/cassette.py b/vcr/cassette.py index e725d94..31a4835 100644 --- a/vcr/cassette.py +++ b/vcr/cassette.py @@ -8,8 +8,8 @@ from .compat import contextlib, collections from .errors import UnhandledHTTPRequestError from .matchers import requests_match, uri, method from .patch import CassettePatcherBuilder -from .persist import load_cassette, save_cassette from .serializers import yamlserializer +from .persisters.filesystem import FilesystemPersister from .util import partition_dict @@ -163,16 +163,13 @@ class Cassette(object): def use(cls, **kwargs): return CassetteContextDecorator.from_args(cls, **kwargs) - def __init__(self, path, serializer=yamlserializer, save_callback=None, - load_callback=None, record_mode='once', + def __init__(self, path, serializer=yamlserializer, persister=FilesystemPersister, record_mode='once', match_on=(uri, method), before_record_request=None, before_record_response=None, custom_patches=(), inject=False): - + self._persister = persister self._path = path self._serializer = serializer - self._save_callback = save_callback - self._load_callback = load_callback self._match_on = match_on self._before_record_request = before_record_request or (lambda x: x) self._before_record_response = before_record_response or (lambda x: x) @@ -274,20 +271,18 @@ class Cassette(object): def _save(self, force=False): if force or self.dirty: - save_cassette( + self._persister.save_cassette( self._path, self._as_dict(), serializer=self._serializer, - save_callback=self._save_callback ) self.dirty = False def _load(self): try: - requests, responses = load_cassette( + requests, responses = self._persister.load_cassette( self._path, serializer=self._serializer, - load_callback=self._load_callback ) for request, response in zip(requests, responses): self.append(request, response) diff --git a/vcr/config.py b/vcr/config.py index b1255ce..15ae459 100644 --- a/vcr/config.py +++ b/vcr/config.py @@ -9,6 +9,7 @@ import six from .compat import collections from .cassette import Cassette from .serializers import yamlserializer, jsonserializer +from .persisters.filesystem import FilesystemPersister from .util import compose, auto_decorate from . import matchers from . import filters @@ -36,8 +37,7 @@ class VCR(object): match_on=('method', 'scheme', 'host', 'port', 'path', 'query'), before_record=None, inject_cassette=False, serializer='yaml', cassette_library_dir=None, func_path_generator=None, - decode_compressed_response=False, save_callback=None, - load_callback=None): + decode_compressed_response=False): self.serializer = serializer self.match_on = match_on self.cassette_library_dir = cassette_library_dir @@ -58,6 +58,7 @@ class VCR(object): 'raw_body': matchers.raw_body, 'body': matchers.body, } + self.persister = FilesystemPersister self.record_mode = record_mode self.filter_headers = filter_headers self.filter_query_parameters = filter_query_parameters @@ -70,8 +71,6 @@ class VCR(object): self.path_transformer = path_transformer self.func_path_generator = func_path_generator self.decode_compressed_response = decode_compressed_response - self.save_callback = save_callback - self.load_callback = load_callback self._custom_patches = tuple(custom_patches) def _get_serializer(self, serializer_name): @@ -150,8 +149,6 @@ class VCR(object): tuple(matcher_names) + tuple(additional_matchers) ), 'record_mode': kwargs.get('record_mode', self.record_mode), - 'save_callback': kwargs.get('save_callback', self.save_callback), - 'load_callback': kwargs.get('load_callback', self.load_callback), 'before_record_request': self._build_before_record_request(kwargs), 'before_record_response': self._build_before_record_response(kwargs), 'custom_patches': self._custom_patches + kwargs.get( @@ -275,6 +272,10 @@ class VCR(object): def register_matcher(self, name, matcher): self.matchers[name] = matcher + def register_persister(self, persister): + # Singleton, no name required + self.persister = persister + def test_case(self, predicate=None): predicate = predicate or self.is_test_method return six.with_metaclass(auto_decorate(self.use_cassette, predicate)) diff --git a/vcr/persist.py b/vcr/persist.py deleted file mode 100644 index a1b65b4..0000000 --- a/vcr/persist.py +++ /dev/null @@ -1,26 +0,0 @@ -from .persisters.filesystem import FilesystemPersister -from .serialize import serialize, deserialize - - -def load_cassette(cassette_path, serializer, load_callback=None): - # Injected `load_callback` must return a cassette or raise IOError - if load_callback is None: - with open(cassette_path) as f: - cassette_content = f.read() - else: - cassette_content = load_callback(cassette_path) - cassette = deserialize(cassette_content, serializer) - return cassette - - -def save_cassette( - cassette_path, - cassette_dict, - serializer, - save_callback=None -): - data = serialize(cassette_dict, serializer) - if save_callback is None: - FilesystemPersister.write(cassette_path, data) - else: - save_callback(cassette_path, data) diff --git a/vcr/persisters/filesystem.py b/vcr/persisters/filesystem.py index 884d891..f64c654 100644 --- a/vcr/persisters/filesystem.py +++ b/vcr/persisters/filesystem.py @@ -1,9 +1,18 @@ import os +from ..serialize import serialize, deserialize class FilesystemPersister(object): @classmethod - def write(cls, cassette_path, data): + def load_cassette(cls, cassette_path, serializer): + with open(cassette_path) as f: + cassette_content = f.read() + cassette = deserialize(cassette_content, serializer) + return cassette + + @staticmethod + def save_cassette(cassette_path, cassette_dict, serializer): + data = serialize(cassette_dict, serializer) dirname, filename = os.path.split(cassette_path) if dirname and not os.path.exists(dirname): os.makedirs(dirname)