1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-08 16:53:23 +00:00
Files
vcrpy/vcr/config.py
Arthur Hamon d682e7b19a Fix all warnings (#449)
* fix typo in pytest.mark.xskip

Replace xskip with the skipif marker, as xskip is an unknown pytest marker.

* fix FileModeWarning

This fixes the following warning:
FileModeWarning: Requests has determined the content-length for this request using the binary size of the file: however, the file has been opened in text mode (i.e. without the 'b' flag in the mode). This may lead to an incorrect content-length. In Requests 3.0, support will be removed for files in text mode.

* fix warning "calling yaml.load() without Loader=... is deprecated"
This fixes the following warning:
YAMLLoadWarning: calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details.

* fix collections.abc deprecation warning in python 3.7.

* update Flask dependency in order to get rid of the Request.is_xhr warning

This fixes the following warning:
DeprecationWarning: 'Request.is_xhr' is deprecated as of version 0.13 and will be removed in version 1.0. The 'X-Requested-With' header is not standard and is unreliable. You may be able to use 'accept_mimetypes' instead.
2019-07-01 17:45:01 +10:00

286 lines
11 KiB
Python

import copy
try:
from collections import abc as collections_abc # only works on python 3.3+
except ImportError:
import collections as collections_abc
import functools
import inspect
import os
import types
import six
from .cassette import Cassette
from .serializers import yamlserializer, jsonserializer
from .persisters.filesystem import FilesystemPersister
from .util import compose, auto_decorate
from . import matchers
from . import filters
class VCR(object):
    """Holds default cassette options and merges them with per-call overrides.

    An instance stores defaults (serializer name, matcher names, record
    mode, request/response filters, ...) plus registries mapping names to
    serializers and matchers.  ``use_cassette`` combines those defaults
    with keyword overrides and hands the result to ``Cassette``.
    """

    @staticmethod
    def is_test_method(method_name, function):
        """Return True if ``method_name`` starts with 'test' and
        ``function`` is a plain Python function."""
        return method_name.startswith('test') and \
            isinstance(function, types.FunctionType)

    @staticmethod
    def ensure_suffix(suffix):
        """Return a callable that appends ``suffix`` to a path unless the
        path already ends with it."""
        def ensure(path):
            if not path.endswith(suffix):
                return path + suffix
            return path
        return ensure

    def __init__(self, path_transformer=None, before_record_request=None,
                 custom_patches=(), filter_query_parameters=(),
                 ignore_hosts=(), record_mode="once", ignore_localhost=False,
                 filter_headers=(), before_record_response=None,
                 filter_post_data_parameters=(),
                 match_on=('method', 'scheme', 'host', 'port', 'path',
                           'query'),
                 before_record=None, inject_cassette=False, serializer='yaml',
                 cassette_library_dir=None, func_path_generator=None,
                 decode_compressed_response=False):
        """Store the defaults that later calls to ``get_merged_config`` use.

        ``before_record`` is only consulted when ``before_record_request``
        is falsy.  ``custom_patches`` may be any iterable; it is stored as
        a tuple.
        """
        self.serializer = serializer
        self.match_on = match_on
        self.cassette_library_dir = cassette_library_dir
        # Registries used to resolve the names given in ``serializer`` and
        # ``match_on``; extendable via register_serializer/register_matcher.
        self.serializers = {
            'yaml': yamlserializer,
            'json': jsonserializer,
        }
        self.matchers = {
            'method': matchers.method,
            'uri': matchers.uri,
            'url': matchers.uri,   # matcher for backwards compatibility
            'scheme': matchers.scheme,
            'host': matchers.host,
            'port': matchers.port,
            'path': matchers.path,
            'query': matchers.query,
            'headers': matchers.headers,
            'raw_body': matchers.raw_body,
            'body': matchers.body,
        }
        self.persister = FilesystemPersister
        self.record_mode = record_mode
        self.filter_headers = filter_headers
        self.filter_query_parameters = filter_query_parameters
        self.filter_post_data_parameters = filter_post_data_parameters
        self.before_record_request = before_record_request or before_record
        self.before_record_response = before_record_response
        self.ignore_hosts = ignore_hosts
        self.ignore_localhost = ignore_localhost
        self.inject_cassette = inject_cassette
        self.path_transformer = path_transformer
        self.func_path_generator = func_path_generator
        self.decode_compressed_response = decode_compressed_response
        self._custom_patches = tuple(custom_patches)

    def _get_serializer(self, serializer_name):
        """Look up a registered serializer by name.

        Raises KeyError with an explanatory message for unknown names.
        """
        try:
            serializer = self.serializers[serializer_name]
        except KeyError:
            raise KeyError(
                "Serializer {} doesn't exist or isn't registered".format(
                    serializer_name
                )
            )
        return serializer

    def _get_matchers(self, matcher_names):
        """Resolve matcher names into a list of registered matchers.

        Order is preserved.  Raises KeyError naming the first unknown
        matcher.
        """
        # Named ``resolved`` so the imported ``matchers`` module is not
        # shadowed inside this method.
        resolved = []
        for name in matcher_names:
            try:
                resolved.append(self.matchers[name])
            except KeyError:
                raise KeyError(
                    "Matcher {} doesn't exist or isn't registered".format(
                        name
                    )
                )
        return resolved

    def use_cassette(self, path=None, **kwargs):
        """Build a cassette context/decorator from the merged configuration.

        When ``path`` is given but is not a string, it is taken to be the
        function being decorated (bare ``@vcr.use_cassette`` usage) and the
        decorator is applied to it immediately.
        """
        if path is not None and not isinstance(path, six.string_types):
            function = path
            # Assume this is an attempt to decorate a function
            return self._use_cassette(**kwargs)(function)
        return self._use_cassette(path=path, **kwargs)

    def _use_cassette(self, with_current_defaults=False, **kwargs):
        """Delegate to ``Cassette`` with either eager or lazy configuration.

        With ``with_current_defaults`` the configuration is merged right
        now; otherwise merging is deferred until the cassette is created.
        """
        if with_current_defaults:
            config = self.get_merged_config(**kwargs)
            return Cassette.use(**config)
        # This is made a function that evaluates every time a cassette
        # is made so that changes that are made to this VCR instance
        # that occur AFTER the `use_cassette` decorator is applied
        # still affect subsequent calls to the decorated function.
        args_getter = functools.partial(self.get_merged_config, **kwargs)
        return Cassette.use_arg_getter(args_getter)

    def get_merged_config(self, **kwargs):
        """Merge per-call ``kwargs`` over this instance's defaults.

        Returns a dict with the keys ``serializer``, ``persister``,
        ``match_on``, ``record_mode``, ``before_record_request``,
        ``before_record_response``, ``custom_patches``, ``inject``,
        ``path_transformer`` and ``func_path_generator``, plus ``path``
        when a truthy ``path`` keyword was supplied.

        Raises KeyError when the serializer name or any matcher name is
        not registered.
        """
        serializer_name = kwargs.get('serializer', self.serializer)
        matcher_names = kwargs.get('match_on', self.match_on)
        path_transformer = kwargs.get(
            'path_transformer',
            self.path_transformer
        )
        func_path_generator = kwargs.get(
            'func_path_generator',
            self.func_path_generator
        )
        cassette_library_dir = kwargs.get(
            'cassette_library_dir',
            self.cassette_library_dir
        )
        additional_matchers = kwargs.get('additional_matchers', ())

        if cassette_library_dir:
            def add_cassette_library_dir(path):
                # Only prefix paths that do not already start with the
                # library directory.
                if not path.startswith(cassette_library_dir):
                    return os.path.join(cassette_library_dir, path)
                return path
            path_transformer = compose(
                add_cassette_library_dir, path_transformer
            )
        elif not func_path_generator:
            # If we don't have a library dir, use the functions
            # location to build a full path for cassettes.
            func_path_generator = self._build_path_from_func_using_module

        # Coerce the per-call patches to a tuple so that lists (or any
        # other iterable) can be concatenated with the stored tuple, the
        # same way ``__init__`` accepts any iterable.
        extra_patches = tuple(kwargs.get('custom_patches', ()))

        merged_config = {
            'serializer': self._get_serializer(serializer_name),
            'persister': self.persister,
            'match_on': self._get_matchers(
                tuple(matcher_names) + tuple(additional_matchers)
            ),
            'record_mode': kwargs.get('record_mode', self.record_mode),
            'before_record_request':
                self._build_before_record_request(kwargs),
            'before_record_response':
                self._build_before_record_response(kwargs),
            'custom_patches': self._custom_patches + extra_patches,
            'inject': kwargs.get('inject_cassette', self.inject_cassette),
            'path_transformer': path_transformer,
            'func_path_generator': func_path_generator,
        }
        path = kwargs.get('path')
        if path:
            merged_config['path'] = path
        return merged_config

    def _build_before_record_response(self, options):
        """Compose the response hooks into a single callable.

        ``options`` overrides the instance defaults.  The returned function
        runs the hooks in order and stops as soon as one returns None.
        """
        user_hooks = options.get(
            'before_record_response', self.before_record_response
        )
        decode_compressed_response = options.get(
            'decode_compressed_response', self.decode_compressed_response
        )
        filter_functions = []
        if decode_compressed_response:
            filter_functions.append(filters.decode_response)
        if user_hooks:
            # A single callable is accepted as well as an iterable of them.
            if not isinstance(user_hooks, collections_abc.Iterable):
                user_hooks = (user_hooks,)
            filter_functions.extend(user_hooks)

        def before_record_response(response):
            for function in filter_functions:
                if response is None:
                    break
                response = function(response)
            return response
        return before_record_response

    def _build_before_record_request(self, options):
        """Compose the request filters and hooks into a single callable.

        ``options`` overrides the instance defaults.  Filters are applied
        in this order: header replacement, query-parameter replacement,
        post-data replacement, ignored hosts, then the user supplied
        ``before_record_request`` / ``before_record`` hooks.  The returned
        function works on a shallow copy of the request and stops as soon
        as a filter returns None.
        """
        filter_functions = []
        filter_headers = options.get(
            'filter_headers', self.filter_headers
        )
        filter_query_parameters = options.get(
            'filter_query_parameters', self.filter_query_parameters
        )
        filter_post_data_parameters = options.get(
            'filter_post_data_parameters', self.filter_post_data_parameters
        )
        user_hooks = options.get(
            "before_record_request",
            options.get("before_record", self.before_record_request)
        )
        ignore_hosts = options.get(
            'ignore_hosts', self.ignore_hosts
        )
        ignore_localhost = options.get(
            'ignore_localhost', self.ignore_localhost
        )

        # Each filter entry is either a (name, replacement) tuple or a bare
        # name, which is paired with a None replacement.
        if filter_headers:
            replacements = [h if isinstance(h, tuple) else (h, None)
                            for h in filter_headers]
            filter_functions.append(
                functools.partial(
                    filters.replace_headers,
                    replacements=replacements,
                )
            )
        if filter_query_parameters:
            replacements = [p if isinstance(p, tuple) else (p, None)
                            for p in filter_query_parameters]
            filter_functions.append(functools.partial(
                filters.replace_query_parameters,
                replacements=replacements,
            ))
        if filter_post_data_parameters:
            replacements = [p if isinstance(p, tuple) else (p, None)
                            for p in filter_post_data_parameters]
            filter_functions.append(
                functools.partial(
                    filters.replace_post_data_parameters,
                    replacements=replacements,
                )
            )

        hosts_to_ignore = set(ignore_hosts)
        if ignore_localhost:
            hosts_to_ignore.update(('localhost', '0.0.0.0', '127.0.0.1'))
        if hosts_to_ignore:
            filter_functions.append(self._build_ignore_hosts(hosts_to_ignore))

        if user_hooks:
            # A single callable is accepted as well as an iterable of them.
            if not isinstance(user_hooks, collections_abc.Iterable):
                user_hooks = (user_hooks,)
            filter_functions.extend(user_hooks)

        def before_record_request(request):
            request = copy.copy(request)
            for function in filter_functions:
                if request is None:
                    break
                request = function(request)
            return request
        return before_record_request

    @staticmethod
    def _build_ignore_hosts(hosts_to_ignore):
        """Return a filter that yields None for requests whose ``host``
        attribute is in ``hosts_to_ignore`` and the request otherwise."""
        def filter_ignored_hosts(request):
            if hasattr(request, 'host') and request.host in hosts_to_ignore:
                return
            return request
        return filter_ignored_hosts

    @staticmethod
    def _build_path_from_func_using_module(function):
        """Return a path made of the directory of the file defining
        ``function`` joined with the function's name."""
        return os.path.join(os.path.dirname(inspect.getfile(function)),
                            function.__name__)

    def register_serializer(self, name, serializer):
        """Register ``serializer`` under ``name``, replacing any existing
        entry."""
        self.serializers[name] = serializer

    def register_matcher(self, name, matcher):
        """Register ``matcher`` under ``name``, replacing any existing
        entry."""
        self.matchers[name] = matcher

    def register_persister(self, persister):
        """Replace the persister used for merged configurations."""
        # Singleton, no name required
        self.persister = persister

    def test_case(self, predicate=None):
        """Build a base class via ``six.with_metaclass`` and
        ``auto_decorate`` using ``use_cassette``.

        ``predicate`` defaults to ``is_test_method``.
        """
        predicate = predicate or self.is_test_method
        return six.with_metaclass(auto_decorate(self.use_cassette, predicate))