From 10188678387bf385bcc89c207f55b1d5fcd4d537 Mon Sep 17 00:00:00 2001 From: Ivan Malison Date: Fri, 19 Sep 2014 14:32:21 -0700 Subject: [PATCH] Revert "Fixed issue in test_nested_context_managers_with_session_created_before_first_nesting. by using a single class and patching cassette on that class. Not a great solution :\" This reverts commit 2bf23b2cdf8542f5bf137e386f3bc278bbc2cc83. --- vcr/patch.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/vcr/patch.py b/vcr/patch.py index ab0036b..dc4e72e 100644 --- a/vcr/patch.py +++ b/vcr/patch.py @@ -57,22 +57,41 @@ class PatcherBuilder(object): def __init__(self, cassette): self._cassette = cassette + self._class_to_cassette_subclass = {} def build_patchers(self): patcher_args = itertools.chain(self._httplib(), self._requests(), self._urllib3(), self._httplib2(), self._boto()) - for obj, patched_attribute, replacement_class in patcher_args: - patcher = self._build_patcher(obj, patched_attribute, replacement_class) + for args in patcher_args: + patcher = self._build_patcher(*args) if patcher: yield patcher - if hasattr(replacement_class, 'cassette'): - yield mock.patch.object(replacement_class, 'cassette', self._cassette) def _build_patcher(self, obj, patched_attribute, replacement_class): if not hasattr(obj, patched_attribute): return + + if isinstance(replacement_class, dict): + for key in replacement_class: + replacement_class[key] = self._get_cassette_subclass(replacement_class[key]) + else: + replacement_class = self._get_cassette_subclass(replacement_class) return mock.patch.object(obj, patched_attribute, replacement_class) + def _get_cassette_subclass(self, klass): + if klass.cassette is not None: + return klass + if klass not in self._class_to_cassette_subclass: + self._class_to_cassette_subclass[klass] = self._build_cassette_subclass(klass) + return self._class_to_cassette_subclass[klass] + + def _build_cassette_subclass(self, base_class): + bases = (base_class,) + if not issubclass(base_class, object): # Check for old style class + bases += (object,) + return type('{0}{1}'.format(base_class.__name__, self._cassette._path), + bases, dict(cassette=self._cassette)) + def _httplib(self): yield httplib, 'HTTPConnection', VCRHTTPConnection yield httplib, 'HTTPSConnection', VCRHTTPSConnection