diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index 69d26bf..6e729ad 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -220,6 +220,49 @@ def test_remove_all_json_post_data_parameters(): assert request.body == b"{}" +def test_replace_dict_post_data_parameters(): + # This tests all of: + # 1. keeping a parameter + # 2. removing a parameter + # 3. replacing a parameter + # 4. replacing a parameter using a callable + # 5. removing a parameter using a callable + # 6. replacing a parameter that doesn't exist + body = {"one": "keep", "two": "lose", "three": "change", "four": "shout", "five": "whisper"} + request = Request("POST", "http://google.com", body, {}) + request.headers["Content-Type"] = "application/x-www-form-urlencoded" + replace_post_data_parameters( + request, + [ + ("two", None), + ("three", "tada"), + ("four", lambda key, value, request: value.upper()), + ("five", lambda key, value, request: None), + ("six", "doesntexist"), + ], + ) + expected_data = {"one": "keep", "three": "tada", "four": "SHOUT"} + assert request.body == expected_data + + +def test_remove_dict_post_data_parameters(): + # Test the backward-compatible API wrapper. + body = {"id": "secret", "foo": "bar", "baz": "qux"} + request = Request("POST", "http://google.com", body, {}) + request.headers["Content-Type"] = "application/x-www-form-urlencoded" + remove_post_data_parameters(request, ["id"]) + expected_data = {"foo": "bar", "baz": "qux"} + assert request.body == expected_data + + +def test_remove_all_dict_post_data_parameters(): + body = {"id": "secret", "foo": "bar"} + request = Request("POST", "http://google.com", body, {}) + request.headers["Content-Type"] = "application/x-www-form-urlencoded" + replace_post_data_parameters(request, [("id", None), ("foo", None)]) + assert request.body == {} + + def test_decode_response_uncompressed(): recorded_response = { "status": {"message": "OK", "code": 200}, diff --git a/vcr/filters.py b/vcr/filters.py index 62254ed..8e00b64 100644 --- a/vcr/filters.py +++ b/vcr/filters.py @@ -84,7 +84,17 @@ def replace_post_data_parameters(request, replacements): replacements = dict(replacements) if request.method == "POST" and not isinstance(request.body, BytesIO): - if request.headers.get("Content-Type") == "application/json": + if isinstance(request.body, dict): + new_body = request.body.copy() + for k, rv in replacements.items(): + if k in new_body: + ov = new_body.pop(k) + if callable(rv): + rv = rv(key=k, value=ov, request=request) + if rv is not None: + new_body[k] = rv + request.body = new_body + elif request.headers.get("Content-Type") == "application/json": json_data = json.loads(request.body.decode("utf-8")) for k, rv in replacements.items(): if k in json_data: