From cb40a45eba98c6baf218cc069a726f4ffa1033f1 Mon Sep 17 00:00:00 2001 From: Aron Griffis Date: Tue, 25 Aug 2015 07:15:51 -0400 Subject: [PATCH] Add replace_post_data_parameters() --- tests/unit/test_filters.py | 62 ++++++++++++++++++++++++++++---------- vcr/filters.py | 59 +++++++++++++++++++++++++++--------- 2 files changed, 90 insertions(+), 31 deletions(-) diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index 26bfdfa..be1d657 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -1,7 +1,7 @@ from vcr.filters import ( remove_headers, replace_headers, remove_query_parameters, replace_query_parameters, - remove_post_data_parameters + remove_post_data_parameters, replace_post_data_parameters, ) from vcr.compat import mock from vcr.request import Request @@ -118,7 +118,28 @@ def test_remove_query_parameters(): assert request.uri == 'http://g.com/?q=cowboys' +def test_replace_post_data_parameters(): + # This tests all of: + # 1. keeping a parameter + # 2. removing a parameter + # 3. replacing a parameter + # 4. replacing a parameter using a callable + # 5. removing a parameter using a callable + # 6. replacing a parameter that doesn't exist + body = b'one=keep&two=lose&three=change&four=shout&five=whisper' + request = Request('POST', 'http://google.com', body, {}) + replace_post_data_parameters(request, [ + ('two', None), + ('three', 'tada'), + ('four', lambda key, value, request: value.upper()), + ('five', lambda key, value, request: None), + ('six', 'doesntexist'), + ]) + assert request.body == b'one=keep&three=tada&four=SHOUT' + + def test_remove_post_data_parameters(): + # Test the backward-compatible API wrapper. body = b'id=secret&foo=bar' request = Request('POST', 'http://google.com', body, {}) remove_post_data_parameters(request, ['id']) @@ -128,25 +149,42 @@ def test_remove_post_data_parameters(): def test_preserve_multiple_post_data_parameters(): body = b'id=secret&foo=bar&foo=baz' request = Request('POST', 'http://google.com', body, {}) - remove_post_data_parameters(request, ['id']) + replace_post_data_parameters(request, [('id', None)]) assert request.body == b'foo=bar&foo=baz' def test_remove_all_post_data_parameters(): body = b'id=secret&foo=bar' request = Request('POST', 'http://google.com', body, {}) - remove_post_data_parameters(request, ['id', 'foo']) + replace_post_data_parameters(request, [('id', None), ('foo', None)]) assert request.body == b'' -def test_remove_nonexistent_post_data_parameters(): - body = b'' +def test_replace_json_post_data_parameters(): + # This tests all of: + # 1. keeping a parameter + # 2. removing a parameter + # 3. replacing a parameter + # 4. replacing a parameter using a callable + # 5. removing a parameter using a callable + # 6. replacing a parameter that doesn't exist + body = b'{"one": "keep", "two": "lose", "three": "change", "four": "shout", "five": "whisper"}' request = Request('POST', 'http://google.com', body, {}) - remove_post_data_parameters(request, ['id']) - assert request.body == b'' + request.headers['Content-Type'] = 'application/json' + replace_post_data_parameters(request, [ + ('two', None), + ('three', 'tada'), + ('four', lambda key, value, request: value.upper()), + ('five', lambda key, value, request: None), + ('six', 'doesntexist'), + ]) + request_data = json.loads(request.body.decode('utf-8')) + expected_data = json.loads('{"one": "keep", "three": "tada", "four": "SHOUT"}') + assert request_data == expected_data def test_remove_json_post_data_parameters(): + # Test the backward-compatible API wrapper. body = b'{"id": "secret", "foo": "bar", "baz": "qux"}' request = Request('POST', 'http://google.com', body, {}) request.headers['Content-Type'] = 'application/json' @@ -160,13 +198,5 @@ def test_remove_all_json_post_data_parameters(): body = b'{"id": "secret", "foo": "bar"}' request = Request('POST', 'http://google.com', body, {}) request.headers['Content-Type'] = 'application/json' - remove_post_data_parameters(request, ['id', 'foo']) - assert request.body == b'{}' - - -def test_remove_nonexistent_json_post_data_parameters(): - body = b'{}' - request = Request('POST', 'http://google.com', body, {}) - request.headers['Content-Type'] = 'application/json' - remove_post_data_parameters(request, ['id']) + replace_post_data_parameters(request, [('id', None), ('foo', None)]) assert request.body == b'{}' diff --git a/vcr/filters.py b/vcr/filters.py index 3fcb9db..6070d79 100644 --- a/vcr/filters.py +++ b/vcr/filters.py @@ -2,7 +2,6 @@ from six import BytesIO, text_type from six.moves.urllib.parse import urlparse, urlencode, urlunparse import json -from .compat import collections def replace_headers(request, replacements): """ @@ -69,25 +68,55 @@ def remove_query_parameters(request, query_parameters_to_remove): return replace_query_parameters(request, replacements) -def remove_post_data_parameters(request, post_data_parameters_to_remove): +def replace_post_data_parameters(request, replacements): + """ + Replace post data in request--either form data or json--according to + replacements. The replacements should be a list of (key, value) pairs where + the value can be any of: + 1. A simple replacement string value. + 2. None to remove the given header. + 3. A callable which accepts (key, value, request) and returns a string + value or None. + """ + replacements = dict(replacements) if request.method == 'POST' and not isinstance(request.body, BytesIO): if request.headers.get('Content-Type') == 'application/json': json_data = json.loads(request.body.decode('utf-8')) - for k in list(json_data.keys()): - if k in post_data_parameters_to_remove: - del json_data[k] + for k, rv in replacements.items(): + if k in json_data: + ov = json_data.pop(k) + if callable(rv): + rv = rv(key=k, value=ov, request=request) + if rv is not None: + json_data[k] = rv request.body = json.dumps(json_data).encode('utf-8') else: - post_data = collections.OrderedDict() if isinstance(request.body, text_type): request.body = request.body.encode('utf-8') - - for k, sep, v in (p.partition(b'=') for p in request.body.split(b'&')): - if k in post_data: - post_data[k].append(v) - elif len(k) > 0 and k.decode('utf-8') not in post_data_parameters_to_remove: - post_data[k] = [v] - request.body = b'&'.join( - b'='.join([k, v]) - for k, vals in post_data.items() for v in vals) + splits = [p.partition(b'=') for p in request.body.split(b'&')] + new_splits = [] + for k, sep, ov in splits: + if sep is None: + new_splits.append((k, sep, ov)) + else: + rk = k.decode('utf-8') + if rk not in replacements: + new_splits.append((k, sep, ov)) + else: + rv = replacements[rk] + if callable(rv): + rv = rv(key=rk, value=ov.decode('utf-8'), + request=request) + if rv is not None: + new_splits.append((k, sep, rv.encode('utf-8'))) + request.body = b'&'.join(k if sep is None else b''.join([k, sep, v]) + for k, sep, v in new_splits) return request + + +def remove_post_data_parameters(request, post_data_parameters_to_remove): + """ + Wrap replace_post_data_parameters() for API backward compatibility. + """ + replacements = [(k, None) for k in post_data_parameters_to_remove] + return replace_post_data_parameters(request, replacements)