1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-09 17:15:35 +00:00

Merge pull request #192 from agriffis/insensitive-headers

Make request.headers always a CaseInsensitiveDict.
This commit is contained in:
Ivan 'Goat' Malison
2015-08-28 14:47:45 -07:00
8 changed files with 99 additions and 53 deletions

View File

@@ -17,11 +17,7 @@ def _request_with_auth(url, username, password):
def _find_header(cassette, header): def _find_header(cassette, header):
for request in cassette.requests: return any(header in request.headers for request in cassette.requests)
for k in request.headers:
if header.lower() == k.lower():
return True
return False
def test_filter_basic_auth(tmpdir): def test_filter_basic_auth(tmpdir):

View File

@@ -73,7 +73,7 @@ def test_remove_nonexistent_post_data_parameters():
def test_remove_json_post_data_parameters(): def test_remove_json_post_data_parameters():
body = b'{"id": "secret", "foo": "bar", "baz": "qux"}' body = b'{"id": "secret", "foo": "bar", "baz": "qux"}'
request = Request('POST', 'http://google.com', body, {}) request = Request('POST', 'http://google.com', body, {})
request.add_header('Content-Type', 'application/json') request.headers['Content-Type'] = 'application/json'
remove_post_data_parameters(request, ['id']) remove_post_data_parameters(request, ['id'])
request_body_json = json.loads(request.body.decode('utf-8')) request_body_json = json.loads(request.body.decode('utf-8'))
expected_json = json.loads(b'{"foo": "bar", "baz": "qux"}'.decode('utf-8')) expected_json = json.loads(b'{"foo": "bar", "baz": "qux"}'.decode('utf-8'))
@@ -83,7 +83,7 @@ def test_remove_json_post_data_parameters():
def test_remove_all_json_post_data_parameters(): def test_remove_all_json_post_data_parameters():
body = b'{"id": "secret", "foo": "bar"}' body = b'{"id": "secret", "foo": "bar"}'
request = Request('POST', 'http://google.com', body, {}) request = Request('POST', 'http://google.com', body, {})
request.add_header('Content-Type', 'application/json') request.headers['Content-Type'] = 'application/json'
remove_post_data_parameters(request, ['id', 'foo']) remove_post_data_parameters(request, ['id', 'foo'])
assert request.body == b'{}' assert request.body == b'{}'
@@ -91,6 +91,6 @@ def test_remove_all_json_post_data_parameters():
def test_remove_nonexistent_json_post_data_parameters(): def test_remove_nonexistent_json_post_data_parameters():
body = b'{}' body = b'{}'
request = Request('POST', 'http://google.com', body, {}) request = Request('POST', 'http://google.com', body, {})
request.add_header('Content-Type', 'application/json') request.headers['Content-Type'] = 'application/json'
remove_post_data_parameters(request, ['id']) remove_post_data_parameters(request, ['id'])
assert request.body == b'{}' assert request.body == b'{}'

View File

@@ -1,6 +1,6 @@
import pytest import pytest
from vcr.request import Request from vcr.request import Request, HeadersDict
def test_str(): def test_str():
@@ -12,11 +12,16 @@ def test_headers():
headers = {'X-Header1': ['h1'], 'X-Header2': 'h2'} headers = {'X-Header1': ['h1'], 'X-Header2': 'h2'}
req = Request('GET', 'http://go.com/', '', headers) req = Request('GET', 'http://go.com/', '', headers)
assert req.headers == {'X-Header1': 'h1', 'X-Header2': 'h2'} assert req.headers == {'X-Header1': 'h1', 'X-Header2': 'h2'}
req.headers['X-Header1'] = 'h11'
req.add_header('X-Header1', 'h11')
assert req.headers == {'X-Header1': 'h11', 'X-Header2': 'h2'} assert req.headers == {'X-Header1': 'h11', 'X-Header2': 'h2'}
def test_add_header_deprecated():
req = Request('GET', 'http://go.com/', '', {})
pytest.deprecated_call(req.add_header, 'foo', 'bar')
assert req.headers == {'foo': 'bar'}
@pytest.mark.parametrize("uri, expected_port", [ @pytest.mark.parametrize("uri, expected_port", [
('http://go.com/', 80), ('http://go.com/', 80),
('http://go.com:80/', 80), ('http://go.com:80/', 80),
@@ -36,3 +41,30 @@ def test_uri():
req = Request('GET', 'http://go.com:80/', '', {}) req = Request('GET', 'http://go.com:80/', '', {})
assert req.uri == 'http://go.com:80/' assert req.uri == 'http://go.com:80/'
def test_HeadersDict():
# Simple test of CaseInsensitiveDict
h = HeadersDict()
assert h == {}
h['Content-Type'] = 'application/json'
assert h == {'Content-Type': 'application/json'}
assert h['content-type'] == 'application/json'
assert h['CONTENT-TYPE'] == 'application/json'
# Test feature of HeadersDict: devolve list to first element
h = HeadersDict()
assert h == {}
h['x'] = ['foo', 'bar']
assert h == {'x': 'foo'}
# Test feature of HeadersDict: preserve original key case
h = HeadersDict()
assert h == {}
h['Content-Type'] = 'application/json'
assert h == {'Content-Type': 'application/json'}
h['content-type'] = 'text/plain'
assert h == {'Content-Type': 'text/plain'}
h['CONtent-tyPE'] = 'whoa'
assert h == {'Content-Type': 'whoa'}

View File

@@ -1,19 +1,16 @@
from six import BytesIO, text_type from six import BytesIO, text_type
from six.moves.urllib.parse import urlparse, urlencode, urlunparse from six.moves.urllib.parse import urlparse, urlencode, urlunparse
import copy
import json import json
from .compat import collections from .compat import collections
def remove_headers(request, headers_to_remove): def remove_headers(request, headers_to_remove):
headers = copy.copy(request.headers) new_headers = request.headers.copy()
headers_to_remove = [h.lower() for h in headers_to_remove] for k in headers_to_remove:
keys = [k for k in headers if k.lower() in headers_to_remove] if k in new_headers:
if keys: del new_headers[k]
for k in keys: request.headers = new_headers
headers.pop(k)
request.headers = headers
return request return request
@@ -30,8 +27,7 @@ def remove_query_parameters(request, query_parameters_to_remove):
def remove_post_data_parameters(request, post_data_parameters_to_remove): def remove_post_data_parameters(request, post_data_parameters_to_remove):
if request.method == 'POST' and not isinstance(request.body, BytesIO): if request.method == 'POST' and not isinstance(request.body, BytesIO):
if ('Content-Type' in request.headers and if request.headers.get('Content-Type') == 'application/json':
request.headers['Content-Type'] == 'application/json'):
json_data = json.loads(request.body.decode('utf-8')) json_data = json.loads(request.body.decode('utf-8'))
for k in list(json_data.keys()): for k in list(json_data.keys()):
if k in post_data_parameters_to_remove: if k in post_data_parameters_to_remove:

View File

@@ -1,6 +1,6 @@
import json import json
from six.moves import urllib, xmlrpc_client from six.moves import urllib, xmlrpc_client
from .util import CaseInsensitiveDict, read_body from .util import read_body
import logging import logging
@@ -66,9 +66,8 @@ def _identity(x):
def _get_transformer(request): def _get_transformer(request):
headers = CaseInsensitiveDict(request.headers)
for checker, transformer in _checker_transformer_pairs: for checker, transformer in _checker_transformer_pairs:
if checker(headers): return transformer if checker(request.headers): return transformer
else: else:
return _identity return _identity

View File

@@ -1,27 +1,12 @@
import warnings
from six import BytesIO, text_type from six import BytesIO, text_type
from six.moves.urllib.parse import urlparse, parse_qsl from six.moves.urllib.parse import urlparse, parse_qsl
from .util import CaseInsensitiveDict
class Request(object): class Request(object):
""" """
VCR's representation of a request. VCR's representation of a request.
There is a weird quirk in HTTP. You can send the same header twice. For
this reason, headers are represented by a dict, with lists as the values.
However, it appears that HTTPlib is completely incapable of sending the
same header twice. This puts me in a weird position: I want to be able to
accurately represent HTTP headers in cassettes, but I don't want the extra
step of always having to do [0] in the general case, i.e.
request.headers['key'][0]
In addition, some servers sometimes send the same header more than once,
and httplib *can* deal with this situation.
Futhermore, I wanted to keep the request and response cassette format as
similar as possible.
For this reason, in cassettes I keep a dict with lists as keys, but once
deserialized into VCR, I keep them as plain, naked dicts.
""" """
def __init__(self, method, uri, body, headers): def __init__(self, method, uri, body, headers):
@@ -32,9 +17,17 @@ class Request(object):
self.body = body.read() self.body = body.read()
else: else:
self.body = body self.body = body
self.headers = {} self.headers = headers
for key in headers:
self.add_header(key, headers[key]) @property
def headers(self):
return self._headers
@headers.setter
def headers(self, value):
if not isinstance(value, HeadersDict):
value = HeadersDict(value)
self._headers = value
@property @property
def body(self): def body(self):
@@ -47,11 +40,10 @@ class Request(object):
self._body = value self._body = value
def add_header(self, key, value): def add_header(self, key, value):
# see class docstring for an explanation warnings.warn("Request.add_header is deprecated. "
if isinstance(value, (tuple, list)): "Please assign to request.headers instead.",
self.headers[key] = value[0] DeprecationWarning)
else: self.headers[key] = value
self.headers[key] = value
@property @property
def scheme(self): def scheme(self):
@@ -105,3 +97,35 @@ class Request(object):
@classmethod @classmethod
def _from_dict(cls, dct): def _from_dict(cls, dct):
return Request(**dct) return Request(**dct)
class HeadersDict(CaseInsensitiveDict):
"""
There is a weird quirk in HTTP. You can send the same header twice. For
this reason, headers are represented by a dict, with lists as the values.
However, it appears that HTTPlib is completely incapable of sending the
same header twice. This puts me in a weird position: I want to be able to
accurately represent HTTP headers in cassettes, but I don't want the extra
step of always having to do [0] in the general case, i.e.
request.headers['key'][0]
In addition, some servers sometimes send the same header more than once,
and httplib *can* deal with this situation.
Futhermore, I wanted to keep the request and response cassette format as
similar as possible.
For this reason, in cassettes I keep a dict with lists as keys, but once
deserialized into VCR, I keep them as plain, naked dicts.
"""
def __setitem__(self, key, value):
if isinstance(value, (tuple, list)):
value = value[0]
# Preserve the case from the first time this key was set.
old = self._store.get(key.lower())
if old:
key = old[0]
super(HeadersDict, self).__setitem__(key, value)

View File

@@ -188,8 +188,7 @@ class VCRConnection(object):
log.debug('Got {0}'.format(self._vcr_request)) log.debug('Got {0}'.format(self._vcr_request))
def putheader(self, header, *values): def putheader(self, header, *values):
for value in values: self._vcr_request.headers[header] = values
self._vcr_request.add_header(header, value)
def send(self, data): def send(self, data):
''' '''

View File

@@ -15,7 +15,7 @@ def vcr_fetch_impl(cassette, real_fetch_impl):
@functools.wraps(real_fetch_impl) @functools.wraps(real_fetch_impl)
def new_fetch_impl(self, request, callback): def new_fetch_impl(self, request, callback):
headers = dict(request.headers) headers = request.headers.copy()
if request.user_agent: if request.user_agent:
headers.setdefault('User-Agent', request.user_agent) headers.setdefault('User-Agent', request.user_agent)