diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index c426e7c..26bfdfa 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -1,6 +1,6 @@ from vcr.filters import ( remove_headers, replace_headers, - remove_query_parameters, + remove_query_parameters, replace_query_parameters, remove_post_data_parameters ) from vcr.compat import mock @@ -66,25 +66,56 @@ def test_remove_headers(): assert request.headers == {'hello': 'goodbye'} -def test_remove_query_parameters(): - uri = 'http://g.com/?q=cowboys&w=1' +def test_replace_query_parameters(): + # This tests all of: + # 1. keeping a parameter + # 2. removing a parameter + # 3. replacing a parameter + # 4. replacing a parameter using a callable + # 5. removing a parameter using a callable + # 6. replacing a parameter that doesn't exist + uri = 'http://g.com/?one=keep&two=lose&three=change&four=shout&five=whisper' request = Request('GET', uri, '', {}) - remove_query_parameters(request, ['w']) - assert request.uri == 'http://g.com/?q=cowboys' + replace_query_parameters(request, [ + ('two', None), + ('three', 'tada'), + ('four', lambda key, value, request: value.upper()), + ('five', lambda key, value, request: None), + ('six', 'doesntexist'), + ]) + assert request.query == [ + ('four', 'SHOUT'), + ('one', 'keep'), + ('three', 'tada'), + ] def test_remove_all_query_parameters(): uri = 'http://g.com/?q=cowboys&w=1' request = Request('GET', uri, '', {}) - remove_query_parameters(request, ['w', 'q']) + replace_query_parameters(request, [('w', None), ('q', None)]) assert request.uri == 'http://g.com/' -def test_remove_nonexistent_query_parameters(): - uri = 'http://g.com/' +def test_replace_query_parameters_callable(): + # This goes beyond test_replace_query_parameters() to ensure that the + # callable receives the expected arguments. + uri = 'http://g.com/?hey=there' request = Request('GET', uri, '', {}) - remove_query_parameters(request, ['w', 'q']) - assert request.uri == 'http://g.com/' + callme = mock.Mock(return_value='ho') + replace_query_parameters(request, [('hey', callme)]) + assert request.uri == 'http://g.com/?hey=ho' + assert callme.call_args == ((), {'request': request, + 'key': 'hey', + 'value': 'there'}) + + +def test_remove_query_parameters(): + # Test the backward-compatible API wrapper. + uri = 'http://g.com/?q=cowboys&w=1' + request = Request('GET', uri, '', {}) + remove_query_parameters(request, ['w']) + assert request.uri == 'http://g.com/?q=cowboys' def test_remove_post_data_parameters(): diff --git a/vcr/filters.py b/vcr/filters.py index 6433ddb..3fcb9db 100644 --- a/vcr/filters.py +++ b/vcr/filters.py @@ -33,17 +33,42 @@ def remove_headers(request, headers_to_remove): return replace_headers(request, replacements) -def remove_query_parameters(request, query_parameters_to_remove): +def replace_query_parameters(request, replacements): + """ + Replace query parameters in request according to replacements. The + replacements should be a list of (key, value) pairs where the value can be + any of: + 1. A simple replacement string value. + 2. None to remove the given header. + 3. A callable which accepts (key, value, request) and returns a string + value or None. + """ query = request.query - new_query = [(k, v) for (k, v) in query - if k not in query_parameters_to_remove] - if len(new_query) != len(query): - uri_parts = list(urlparse(request.uri)) - uri_parts[4] = urlencode(new_query) - request.uri = urlunparse(uri_parts) + new_query = [] + replacements = dict(replacements) + for k, ov in query: + if k not in replacements: + new_query.append((k, ov)) + else: + rv = replacements[k] + if callable(rv): + rv = rv(key=k, value=ov, request=request) + if rv is not None: + new_query.append((k, rv)) + uri_parts = list(urlparse(request.uri)) + uri_parts[4] = urlencode(new_query) + request.uri = urlunparse(uri_parts) return request +def remove_query_parameters(request, query_parameters_to_remove): + """ + Wrap replace_query_parameters() for API backward compatibility. + """ + replacements = [(k, None) for k in query_parameters_to_remove] + return replace_query_parameters(request, replacements) + + def remove_post_data_parameters(request, post_data_parameters_to_remove): if request.method == 'POST' and not isinstance(request.body, BytesIO): if request.headers.get('Content-Type') == 'application/json':