From cc9af1d5fbd0db68e11f67c7b24f10589113b5dc Mon Sep 17 00:00:00 2001 From: Diaoul Date: Sat, 11 Jul 2015 23:18:45 +0200 Subject: [PATCH] Use CaseInsensitiveDict in body matcher --- vcr/matchers.py | 11 +++++--- vcr/util.py | 71 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 4 deletions(-) diff --git a/vcr/matchers.py b/vcr/matchers.py index 96c5190..94a9187 100644 --- a/vcr/matchers.py +++ b/vcr/matchers.py @@ -1,5 +1,6 @@ import json from six.moves import urllib, xmlrpc_client +from .util import CaseInsensitiveDict import logging log = logging.getLogger(__name__) @@ -45,12 +46,14 @@ def body(r1, r2): else: r1_body = r1.body r2_body = r2.body - if r1.headers.get('Content-Type') == r2.headers.get('Content-Type') == 'application/x-www-form-urlencoded': + r1_headers = CaseInsensitiveDict(r1.headers) + r2_headers = CaseInsensitiveDict(r2.headers) + if r1_headers.get('Content-Type') == r2_headers.get('Content-Type') == 'application/x-www-form-urlencoded': return urllib.parse.parse_qs(r1_body) == urllib.parse.parse_qs(r2_body) - if r1.headers.get('Content-Type') == r2.headers.get('Content-Type') == 'application/json': + if r1_headers.get('Content-Type') == r2_headers.get('Content-Type') == 'application/json': return json.loads(r1_body) == json.loads(r2_body) - if ('xmlrpc' in r1.headers.get('User-Agent', '') and 'xmlrpc' in r2.headers.get('User-Agent', '') and - r1.headers.get('Content-Type') == r2.headers.get('Content-Type') == 'text/xml'): + if ('xmlrpc' in r1_headers.get('User-Agent', '') and 'xmlrpc' in r2_headers.get('User-Agent', '') and + r1_headers.get('Content-Type') == r2_headers.get('Content-Type') == 'text/xml'): return xmlrpc_client.loads(r1_body) == xmlrpc_client.loads(r2_body) return r1_body == r2_body diff --git a/vcr/util.py b/vcr/util.py index 57f72b1..e3b05d8 100644 --- a/vcr/util.py +++ b/vcr/util.py @@ -1,3 +1,74 @@ +import collections + +# Shamelessly stolen from https://github.com/kennethreitz/requests/blob/master/requests/structures.py +class CaseInsensitiveDict(collections.MutableMapping): + """ + A case-insensitive ``dict``-like object. + Implements all methods and operations of + ``collections.MutableMapping`` as well as dict's ``copy``. Also + provides ``lower_items``. + All keys are expected to be strings. The structure remembers the + case of the last key to be set, and ``iter(instance)``, + ``keys()``, ``items()``, ``iterkeys()``, and ``iteritems()`` + will contain case-sensitive keys. However, querying and contains + testing is case insensitive:: + cid = CaseInsensitiveDict() + cid['Accept'] = 'application/json' + cid['aCCEPT'] == 'application/json' # True + list(cid) == ['Accept'] # True + For example, ``headers['content-encoding']`` will return the + value of a ``'Content-Encoding'`` response header, regardless + of how the header name was originally stored. + If the constructor, ``.update``, or equality comparison + operations are given keys that have equal ``.lower()``s, the + behavior is undefined. + """ + def __init__(self, data=None, **kwargs): + self._store = dict() + if data is None: + data = {} + self.update(data, **kwargs) + + def __setitem__(self, key, value): + # Use the lowercased key for lookups, but store the actual + # key alongside the value. + self._store[key.lower()] = (key, value) + + def __getitem__(self, key): + return self._store[key.lower()][1] + + def __delitem__(self, key): + del self._store[key.lower()] + + def __iter__(self): + return (casedkey for casedkey, mappedvalue in self._store.values()) + + def __len__(self): + return len(self._store) + + def lower_items(self): + """Like iteritems(), but with all lowercase keys.""" + return ( + (lowerkey, keyval[1]) + for (lowerkey, keyval) + in self._store.items() + ) + + def __eq__(self, other): + if isinstance(other, collections.Mapping): + other = CaseInsensitiveDict(other) + else: + return NotImplemented + # Compare insensitively + return dict(self.lower_items()) == dict(other.lower_items()) + + # Copy is required + def copy(self): + return CaseInsensitiveDict(self._store.values()) + + def __repr__(self): + return str(dict(self.items())) + def partition_dict(predicate, dictionary): true_dict = {} false_dict = {}