From eda64bc3be52959eabe497b62c3662b6fd9d8d90 Mon Sep 17 00:00:00 2001 From: Aron Griffis Date: Mon, 24 Aug 2015 12:58:34 -0400 Subject: [PATCH] Make request.headers always a CaseInsensitiveDict. Previously request.headers was a normal dict (albeit with the request.add_header interface) which meant that some code paths would do case-sensitive matching, for example remove_post_data_parameters which tests for 'Content-Type'. This change allows all code paths to get the same case-insensitive treatment. Additionally request.headers becomes a property to enforce upgrading it to a CaseInsensitiveDict even if assigned. --- tests/integration/test_filter.py | 6 +----- vcr/filters.py | 13 +++++-------- vcr/matchers.py | 5 ++--- vcr/request.py | 19 +++++++++++++++---- vcr/stubs/tornado_stubs.py | 2 +- 5 files changed, 24 insertions(+), 21 deletions(-) diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py index 0a5232d..45c3f78 100644 --- a/tests/integration/test_filter.py +++ b/tests/integration/test_filter.py @@ -17,11 +17,7 @@ def _request_with_auth(url, username, password): def _find_header(cassette, header): - for request in cassette.requests: - for k in request.headers: - if header.lower() == k.lower(): - return True - return False + return any(header in request.headers for request in cassette.requests) def test_filter_basic_auth(tmpdir): diff --git a/vcr/filters.py b/vcr/filters.py index 14159d0..102a95e 100644 --- a/vcr/filters.py +++ b/vcr/filters.py @@ -1,19 +1,16 @@ from six import BytesIO, text_type from six.moves.urllib.parse import urlparse, urlencode, urlunparse -import copy import json from .compat import collections def remove_headers(request, headers_to_remove): - headers = copy.copy(request.headers) - headers_to_remove = [h.lower() for h in headers_to_remove] - keys = [k for k in headers if k.lower() in headers_to_remove] - if keys: - for k in keys: - headers.pop(k) - request.headers = headers + new_headers = request.headers.copy() + for k in headers_to_remove: + if k in new_headers: + del new_headers[k] + request.headers = new_headers return request diff --git a/vcr/matchers.py b/vcr/matchers.py index 6d8cbf8..57e7a02 100644 --- a/vcr/matchers.py +++ b/vcr/matchers.py @@ -1,6 +1,6 @@ import json from six.moves import urllib, xmlrpc_client -from .util import CaseInsensitiveDict, read_body +from .util import read_body import logging @@ -66,9 +66,8 @@ def _identity(x): def _get_transformer(request): - headers = CaseInsensitiveDict(request.headers) for checker, transformer in _checker_transformer_pairs: - if checker(headers): return transformer + if checker(request.headers): return transformer else: return _identity diff --git a/vcr/request.py b/vcr/request.py index 3500044..bc15e6f 100644 --- a/vcr/request.py +++ b/vcr/request.py @@ -1,10 +1,11 @@ from six import BytesIO, text_type from six.moves.urllib.parse import urlparse, parse_qsl +from .util import CaseInsensitiveDict class Request(object): """ - VCR's representation of a request. + VCR's representation of a request. There is a weird quirk in HTTP. You can send the same header twice. For this reason, headers are represented by a dict, with lists as the values. @@ -32,9 +33,19 @@ class Request(object): self.body = body.read() else: self.body = body - self.headers = {} - for key in headers: - self.add_header(key, headers[key]) + self.headers = CaseInsensitiveDict() + for key, value in headers.items(): + self.add_header(key, value) + + @property + def headers(self): + return self._headers + + @headers.setter + def headers(self, value): + if not isinstance(value, CaseInsensitiveDict): + value = CaseInsensitiveDict(value) + self._headers = value @property def body(self): diff --git a/vcr/stubs/tornado_stubs.py b/vcr/stubs/tornado_stubs.py index 5a4ce58..a6422da 100644 --- a/vcr/stubs/tornado_stubs.py +++ b/vcr/stubs/tornado_stubs.py @@ -15,7 +15,7 @@ def vcr_fetch_impl(cassette, real_fetch_impl): @functools.wraps(real_fetch_impl) def new_fetch_impl(self, request, callback): - headers = dict(request.headers) + headers = request.headers.copy() if request.user_agent: headers.setdefault('User-Agent', request.user_agent)