import copy
import functools
import inspect
import os
import types

import six

from .compat import collections
from .cassette import Cassette
from .serializers import yamlserializer, jsonserializer
from .persisters.filesystem import FilesystemPersister
from .util import compose, auto_decorate
from . import matchers
from . import filters


class VCR(object):
    """Hold default cassette options and build cassettes from them.

    Options given to ``__init__`` act as defaults. Keyword arguments given
    to :meth:`use_cassette` override those defaults for a single cassette;
    the merge itself happens in :meth:`get_merged_config`.
    """

    @staticmethod
    def is_test_method(method_name, function):
        """Return True if *function* is a plain function named ``test*``.

        This is the default predicate used by :meth:`test_case`.
        """
        return (method_name.startswith('test') and
                isinstance(function, types.FunctionType))

    @staticmethod
    def ensure_suffix(suffix):
        """Return a path transformer that appends *suffix* when it is absent.

        :param suffix: string the transformed path must end with.
        :returns: a one-argument callable ``path -> path``.
        """
        def ensure(path):
            if not path.endswith(suffix):
                return path + suffix
            return path
        return ensure

    def __init__(self, path_transformer=None, before_record_request=None,
                 custom_patches=(), filter_query_parameters=(),
                 ignore_hosts=(), record_mode="once", ignore_localhost=False,
                 filter_headers=(), before_record_response=None,
                 filter_post_data_parameters=(),
                 match_on=('method', 'scheme', 'host', 'port', 'path',
                           'query'),
                 before_record=None, inject_cassette=False, serializer='yaml',
                 cassette_library_dir=None, func_path_generator=None,
                 decode_compressed_response=False):
        """Store the default options used by every cassette this VCR builds.

        Any of these options can be overridden per cassette by passing the
        same keyword to :meth:`use_cassette`. ``before_record`` is an
        alternative name for ``before_record_request``; it is only used when
        ``before_record_request`` is falsy.
        """
        self.serializer = serializer
        self.match_on = match_on
        self.cassette_library_dir = cassette_library_dir
        # Serializer registry keyed by the name accepted in ``serializer=``.
        self.serializers = {
            'yaml': yamlserializer,
            'json': jsonserializer,
        }
        # Matcher registry keyed by the names accepted in ``match_on=``.
        self.matchers = {
            'method': matchers.method,
            'uri': matchers.uri,
            'url': matchers.uri,  # matcher for backwards compatibility
            'scheme': matchers.scheme,
            'host': matchers.host,
            'port': matchers.port,
            'path': matchers.path,
            'query': matchers.query,
            'headers': matchers.headers,
            'raw_body': matchers.raw_body,
            'body': matchers.body,
        }
        self.persister = FilesystemPersister
        self.record_mode = record_mode
        self.filter_headers = filter_headers
        self.filter_query_parameters = filter_query_parameters
        self.filter_post_data_parameters = filter_post_data_parameters
        self.before_record_request = before_record_request or before_record
        self.before_record_response = before_record_response
        self.ignore_hosts = ignore_hosts
        self.ignore_localhost = ignore_localhost
        self.inject_cassette = inject_cassette
        self.path_transformer = path_transformer
        self.func_path_generator = func_path_generator
        self.decode_compressed_response = decode_compressed_response
        # Normalised to a tuple so it can be concatenated with per-cassette
        # patches in get_merged_config.
        self._custom_patches = tuple(custom_patches)

    def _get_serializer(self, serializer_name):
        """Look up a registered serializer by name.

        :raises KeyError: if *serializer_name* is not registered.
        """
        try:
            serializer = self.serializers[serializer_name]
        except KeyError:
            raise KeyError(
                "Serializer {0} doesn't exist or isn't registered".format(
                    serializer_name
                )
            )
        return serializer

    def _get_matchers(self, matcher_names):
        """Translate matcher names into matcher callables, preserving order.

        :raises KeyError: naming the first unregistered matcher encountered.
        """
        # Named so it does not shadow the imported ``matchers`` module.
        matcher_functions = []
        try:
            for m in matcher_names:
                matcher_functions.append(self.matchers[m])
        except KeyError:
            raise KeyError(
                "Matcher {0} doesn't exist or isn't registered".format(m)
            )
        return matcher_functions

    def use_cassette(self, path=None, **kwargs):
        """Return a cassette context manager / decorator.

        This can be used either as ``use_cassette('name', ...)`` or as a bare
        decorator, ``@use_cassette``. In the bare-decorator case, *path* is
        the decorated function rather than a string.
        """
        if path is not None and not isinstance(path, six.string_types):
            function = path
            # Assume this is an attempt to decorate a function
            return self._use_cassette(**kwargs)(function)
        return self._use_cassette(path=path, **kwargs)

    def _use_cassette(self, with_current_defaults=False, **kwargs):
        """Build the cassette object for :meth:`use_cassette`.

        With ``with_current_defaults=True`` the configuration is merged
        once, now. Otherwise it is merged lazily each time the cassette is
        entered.
        """
        if with_current_defaults:
            config = self.get_merged_config(**kwargs)
            return Cassette.use(**config)
        # This is made a function that evaluates every time a cassette
        # is made so that changes that are made to this VCR instance
        # that occur AFTER the `use_cassette` decorator is applied
        # still affect subsequent calls to the decorated function.
        args_getter = functools.partial(self.get_merged_config, **kwargs)
        return Cassette.use_arg_getter(args_getter)

    def get_merged_config(self, **kwargs):
        """Merge per-cassette *kwargs* over this instance's defaults.

        :returns: a dict with keys ``serializer``, ``match_on``,
            ``record_mode``, ``before_record_request``,
            ``before_record_response``, ``custom_patches``, ``inject``,
            ``path_transformer`` and ``func_path_generator``. A ``path``
            key is added only when a truthy ``path`` was given.
        :raises KeyError: for an unregistered serializer or matcher name.
        """
        serializer_name = kwargs.get('serializer', self.serializer)
        matcher_names = kwargs.get('match_on', self.match_on)
        path_transformer = kwargs.get(
            'path_transformer',
            self.path_transformer
        )
        func_path_generator = kwargs.get(
            'func_path_generator',
            self.func_path_generator
        )
        cassette_library_dir = kwargs.get(
            'cassette_library_dir',
            self.cassette_library_dir
        )
        additional_matchers = kwargs.get('additional_matchers', ())

        if cassette_library_dir:
            def add_cassette_library_dir(path):
                # Leave paths that already start with the library dir alone.
                if not path.startswith(cassette_library_dir):
                    return os.path.join(cassette_library_dir, path)
                return path
            path_transformer = compose(
                add_cassette_library_dir, path_transformer
            )
        elif not func_path_generator:
            # If we don't have a library dir, use the functions
            # location to build a full path for cassettes.
            func_path_generator = self._build_path_from_func_using_module

        merged_config = {
            'serializer': self._get_serializer(serializer_name),
            'match_on': self._get_matchers(
                tuple(matcher_names) + tuple(additional_matchers)
            ),
            'record_mode': kwargs.get('record_mode', self.record_mode),
            'before_record_request': self._build_before_record_request(kwargs),
            'before_record_response': self._build_before_record_response(
                kwargs
            ),
            # tuple() so per-cassette patches may be any iterable (e.g. a
            # list), matching how __init__ normalises its custom_patches;
            # tuple + list would otherwise raise TypeError.
            'custom_patches': self._custom_patches + tuple(
                kwargs.get('custom_patches', ())
            ),
            'inject': kwargs.get('inject_cassette', self.inject_cassette),
            'path_transformer': path_transformer,
            'func_path_generator': func_path_generator
        }
        path = kwargs.get('path')
        if path:
            merged_config['path'] = path
        return merged_config

    def _build_before_record_response(self, options):
        """Compose the response filters into a single callable.

        The resulting pipeline runs, in order:

        * ``filters.decode_response``, if ``decode_compressed_response`` is
          truthy;
        * each user ``before_record_response`` hook. A single callable is
          accepted as well as an iterable of them.

        The pipeline stops, and returns None, as soon as any step returns
        None.
        """
        user_hooks = options.get(
            'before_record_response', self.before_record_response
        )
        decode_compressed_response = options.get(
            'decode_compressed_response', self.decode_compressed_response
        )
        filter_functions = []
        if decode_compressed_response:
            filter_functions.append(filters.decode_response)
        if user_hooks:
            # NOTE(review): ``collections`` comes from .compat; presumably
            # it exposes ``Iterable`` on every supported Python version.
            if not isinstance(user_hooks, collections.Iterable):
                user_hooks = (user_hooks,)
            filter_functions.extend(user_hooks)

        def before_record_response(response):
            for function in filter_functions:
                if response is None:
                    break
                response = function(response)
            return response
        return before_record_response

    def _build_before_record_request(self, options):
        """Compose the request filters into a single callable.

        The resulting pipeline runs, in order:

        * header replacement;
        * query-parameter replacement;
        * POST-data-parameter replacement;
        * the ignored-hosts filter;
        * the user ``before_record_request`` hooks (or ``before_record``).

        Bare names in the ``filter_*`` options are paired with ``None``
        before being passed to the ``filters.replace_*`` helpers. The
        request is shallow-copied (``copy.copy``) before filtering, and the
        pipeline stops, returning None, once any step returns None.
        """
        filter_functions = []
        filter_headers = options.get(
            'filter_headers', self.filter_headers
        )
        filter_query_parameters = options.get(
            'filter_query_parameters', self.filter_query_parameters
        )
        filter_post_data_parameters = options.get(
            'filter_post_data_parameters', self.filter_post_data_parameters
        )
        user_hooks = options.get(
            "before_record_request",
            options.get("before_record", self.before_record_request)
        )
        ignore_hosts = options.get(
            'ignore_hosts', self.ignore_hosts
        )
        ignore_localhost = options.get(
            'ignore_localhost', self.ignore_localhost
        )
        if filter_headers:
            replacements = [h if isinstance(h, tuple) else (h, None)
                            for h in filter_headers]
            filter_functions.append(
                functools.partial(
                    filters.replace_headers,
                    replacements=replacements,
                )
            )
        if filter_query_parameters:
            replacements = [p if isinstance(p, tuple) else (p, None)
                            for p in filter_query_parameters]
            filter_functions.append(functools.partial(
                filters.replace_query_parameters,
                replacements=replacements,
            ))
        if filter_post_data_parameters:
            replacements = [p if isinstance(p, tuple) else (p, None)
                            for p in filter_post_data_parameters]
            filter_functions.append(
                functools.partial(
                    filters.replace_post_data_parameters,
                    replacements=replacements,
                )
            )

        hosts_to_ignore = set(ignore_hosts)
        if ignore_localhost:
            hosts_to_ignore.update(('localhost', '0.0.0.0', '127.0.0.1'))
        if hosts_to_ignore:
            filter_functions.append(self._build_ignore_hosts(hosts_to_ignore))

        if user_hooks:
            # NOTE(review): see _build_before_record_response about
            # ``collections.Iterable`` coming from .compat.
            if not isinstance(user_hooks, collections.Iterable):
                user_hooks = (user_hooks,)
            filter_functions.extend(user_hooks)

        def before_record_request(request):
            # Shallow copy so filters that reassign attributes do not
            # touch the caller's request object.
            request = copy.copy(request)
            for function in filter_functions:
                if request is None:
                    break
                request = function(request)
            return request
        return before_record_request

    @staticmethod
    def _build_ignore_hosts(hosts_to_ignore):
        """Return a filter that drops requests whose ``host`` is ignored.

        The filter returns None for a dropped request and returns the
        request unchanged otherwise. Objects without a ``host`` attribute
        are always kept.
        """
        def filter_ignored_hosts(request):
            if hasattr(request, 'host') and request.host in hosts_to_ignore:
                return None
            return request
        return filter_ignored_hosts

    @staticmethod
    def _build_path_from_func_using_module(function):
        """Return ``<dir of function's source file>/<function name>``."""
        return os.path.join(os.path.dirname(inspect.getfile(function)),
                            function.__name__)

    def register_serializer(self, name, serializer):
        """Make *serializer* available under ``serializer=name``."""
        self.serializers[name] = serializer

    def register_matcher(self, name, matcher):
        """Make *matcher* available under *name* in ``match_on``."""
        self.matchers[name] = matcher

    def register_persister(self, persister):
        """Replace the persister used by this VCR."""
        # Singleton, no name required
        self.persister = persister

    def test_case(self, predicate=None):
        """Return a base class whose matching methods use a cassette.

        :param predicate: ``(name, attr) -> bool`` choosing which methods to
            decorate. Defaults to :meth:`is_test_method`.
        """
        predicate = predicate or self.is_test_method
        return six.with_metaclass(auto_decorate(self.use_cassette, predicate))