diff --git a/tests/unit/test_vcr.py b/tests/unit/test_vcr.py index 278a924..9a9301c 100644 --- a/tests/unit/test_vcr.py +++ b/tests/unit/test_vcr.py @@ -1,5 +1,6 @@ import http.client as httplib import os +from pathlib import Path from unittest import mock import pytest @@ -95,7 +96,6 @@ def test_vcr_before_record_response_iterable(): # Prevent actually saving the cassette with mock.patch("vcr.cassette.FilesystemPersister.save_cassette"): - # Baseline: non-iterable before_record_response should work mock_filter = mock.Mock() vcr = VCR(before_record_response=mock_filter) @@ -119,7 +119,6 @@ def test_before_record_response_as_filter(): # Prevent actually saving the cassette with mock.patch("vcr.cassette.FilesystemPersister.save_cassette"): - filter_all = mock.Mock(return_value=None) vcr = VCR(before_record_response=filter_all) with vcr.use_cassette("test") as cassette: @@ -133,7 +132,6 @@ def test_vcr_path_transformer(): # Prevent actually saving the cassette with mock.patch("vcr.cassette.FilesystemPersister.save_cassette"): - # Baseline: path should be unchanged vcr = VCR() with vcr.use_cassette("test") as cassette: @@ -360,3 +358,11 @@ def test_dynamically_added(self): TestVCRClass.test_dynamically_added = test_dynamically_added del test_dynamically_added + + +def test_path_class_as_cassette(): + path = Path(__file__).parent.parent.joinpath( + "integration/cassettes/test_httpx_test_test_behind_proxy.yml" + ) + with use_cassette(path): + pass diff --git a/vcr/config.py b/vcr/config.py index 1ff3cac..45412e3 100644 --- a/vcr/config.py +++ b/vcr/config.py @@ -4,6 +4,7 @@ import inspect import os import types from collections import abc as collections_abc +from pathlib import Path import six @@ -98,7 +99,7 @@ class VCR: return matchers def use_cassette(self, path=None, **kwargs): - if path is not None and not isinstance(path, str): + if path is not None and not isinstance(path, (str, Path)): function = path # Assume this is an attempt to decorate a function return self._use_cassette(**kwargs)(function) diff --git a/vcr/persisters/filesystem.py b/vcr/persisters/filesystem.py index 9f50bb8..e971063 100644 --- a/vcr/persisters/filesystem.py +++ b/vcr/persisters/filesystem.py @@ -1,6 +1,6 @@ # .. _persister_example: -import os +from pathlib import Path from ..serialize import deserialize, serialize @@ -8,19 +8,25 @@ from ..serialize import deserialize, serialize class FilesystemPersister: @classmethod def load_cassette(cls, cassette_path, serializer): - try: - with open(cassette_path) as f: - cassette_content = f.read() - except OSError: + cassette_path = Path(cassette_path) # if cassette path is already Path this is no operation + if not cassette_path.is_file(): raise ValueError("Cassette not found.") - cassette = deserialize(cassette_content, serializer) - return cassette + try: + with cassette_path.open() as f: + data = f.read() + except UnicodeEncodeError as err: + raise ValueError("Can't read Cassette, Encoding is broken") from err + + return deserialize(data, serializer) @staticmethod def save_cassette(cassette_path, cassette_dict, serializer): data = serialize(cassette_dict, serializer) - dirname, filename = os.path.split(cassette_path) - if dirname and not os.path.exists(dirname): - os.makedirs(dirname) - with open(cassette_path, "w") as f: + cassette_path = Path(cassette_path) # if cassette path is already Path this is no operation + + cassette_folder = cassette_path.parent + if not cassette_folder.exists(): + cassette_folder.mkdir(parents=True) + + with cassette_path.open("w") as f: f.write(data)