1
0
mirror of https://github.com/kevin1024/vcrpy.git synced 2025-12-08 16:53:23 +00:00

Add ability to filter post data parameters

This commit is contained in:
Edward Stone
2015-03-30 22:30:03 -07:00
committed by Ivan Malison
parent 0dd7b05990
commit 0def349420
5 changed files with 79 additions and 3 deletions

View File

@@ -294,6 +294,15 @@ with my_vcr.use_cassette('test.yml', filter_query_parameters=['api_key']):
requests.get('http://api.com/getdata?api_key=secretstring') requests.get('http://api.com/getdata?api_key=secretstring')
``` ```
### Filter information from HTTP post data
Use the `filter_post_data_parameters` configuration option with a list of query
parameters to filter.
```python
with my_vcr.use_cassette('test.yml', filter_post_data_parameters=['client_secret']):
requests.post('http://api.com/postdata', data={'api_key': 'secretstring'})
```
### Custom Request filtering ### Custom Request filtering
If neither of these covers your request filtering needs, you can register a callback If neither of these covers your request filtering needs, you can register a callback

View File

@@ -1,6 +1,7 @@
import base64 import base64
import pytest import pytest
from six.moves.urllib.request import urlopen, Request from six.moves.urllib.request import urlopen, Request
from six.moves.urllib.parse import urlencode
from six.moves.urllib.error import HTTPError from six.moves.urllib.error import HTTPError
import vcr import vcr
@@ -55,6 +56,16 @@ def test_filter_querystring(tmpdir):
assert 'foo' not in cass.requests[0].url assert 'foo' not in cass.requests[0].url
def test_filter_post_data(tmpdir):
url = 'http://httpbin.org/post'
data = urlencode({'id': 'secret', 'foo': 'bar'}).encode('utf-8')
cass_file = str(tmpdir.join('filter_pd.yaml'))
with vcr.use_cassette(cass_file, filter_post_data_parameters=['id']):
urlopen(url, data)
with vcr.use_cassette(cass_file, filter_post_data_parameters=['id']) as cass:
assert b'id=secret' not in cass.requests[0].body
def test_filter_callback(tmpdir): def test_filter_callback(tmpdir):
url = 'http://httpbin.org/get' url = 'http://httpbin.org/get'
cass_file = str(tmpdir.join('basic_auth_filter.yaml')) cass_file = str(tmpdir.join('basic_auth_filter.yaml'))

View File

@@ -1,4 +1,8 @@
from vcr.filters import remove_headers, remove_query_parameters from vcr.filters import (
remove_headers,
remove_query_parameters,
remove_post_data_parameters
)
from vcr.request import Request from vcr.request import Request
@@ -35,3 +39,31 @@ def test_remove_nonexistent_query_parameters():
request = Request('GET', uri, '', {}) request = Request('GET', uri, '', {})
remove_query_parameters(request, ['w', 'q']) remove_query_parameters(request, ['w', 'q'])
assert request.uri == 'http://g.com/' assert request.uri == 'http://g.com/'
def test_remove_post_data_parameters():
body = b'id=secret&foo=bar'
request = Request('POST', 'http://google.com', body, {})
remove_post_data_parameters(request, ['id'])
assert request.body == b'foo=bar'
def test_preserve_multiple_post_data_parameters():
body = b'id=secret&foo=bar&foo=baz'
request = Request('POST', 'http://google.com', body, {})
remove_post_data_parameters(request, ['id'])
assert request.body == b'foo=bar&foo=baz'
def test_remove_all_post_data_parameters():
body = b'id=secret&foo=bar'
request = Request('POST', 'http://google.com', body, {})
remove_post_data_parameters(request, ['id', 'foo'])
assert request.body == b''
def test_remove_nonexistent_post_data_parameters():
body = b''
request = Request('POST', 'http://google.com', body, {})
remove_post_data_parameters(request, ['id'])
assert request.body == b''

View File

@@ -13,8 +13,8 @@ class VCR(object):
def __init__(self, serializer='yaml', cassette_library_dir=None, def __init__(self, serializer='yaml', cassette_library_dir=None,
record_mode="once", filter_headers=(), custom_patches=(), record_mode="once", filter_headers=(), custom_patches=(),
filter_query_parameters=(), before_record_request=None, filter_query_parameters=(), filter_post_data_parameters=(),
before_record_response=None, ignore_hosts=(), before_record_request=None, before_record_response=None, ignore_hosts=(),
match_on=('method', 'scheme', 'host', 'port', 'path', 'query',), match_on=('method', 'scheme', 'host', 'port', 'path', 'query',),
ignore_localhost=False, before_record=None): ignore_localhost=False, before_record=None):
self.serializer = serializer self.serializer = serializer
@@ -39,6 +39,7 @@ class VCR(object):
self.record_mode = record_mode self.record_mode = record_mode
self.filter_headers = filter_headers self.filter_headers = filter_headers
self.filter_query_parameters = filter_query_parameters self.filter_query_parameters = filter_query_parameters
self.filter_post_data_parameters = filter_post_data_parameters
self.before_record_request = before_record_request or before_record self.before_record_request = before_record_request or before_record
self.before_record_response = before_record_response self.before_record_response = before_record_response
self.ignore_hosts = ignore_hosts self.ignore_hosts = ignore_hosts
@@ -121,6 +122,9 @@ class VCR(object):
filter_query_parameters = options.get( filter_query_parameters = options.get(
'filter_query_parameters', self.filter_query_parameters 'filter_query_parameters', self.filter_query_parameters
) )
filter_post_data_parameters = options.get(
'filter_post_data_parameters', self.filter_post_data_parameters
)
before_record_request = options.get( before_record_request = options.get(
"before_record_request", options.get("before_record", self.before_record_request) "before_record_request", options.get("before_record", self.before_record_request)
) )
@@ -137,6 +141,10 @@ class VCR(object):
filter_functions.append(functools.partial(filters.remove_query_parameters, filter_functions.append(functools.partial(filters.remove_query_parameters,
query_parameters_to_remove=filter_query_parameters)) query_parameters_to_remove=filter_query_parameters))
if filter_post_data_parameters:
filter_functions.append(functools.partial(filters.remove_post_data_parameters,
post_data_parameters_to_remove=filter_post_data_parameters))
hosts_to_ignore = list(ignore_hosts) hosts_to_ignore = list(ignore_hosts)
if ignore_localhost: if ignore_localhost:
hosts_to_ignore.extend(('localhost', '0.0.0.0', '127.0.0.1')) hosts_to_ignore.extend(('localhost', '0.0.0.0', '127.0.0.1'))

View File

@@ -1,4 +1,6 @@
from six import BytesIO
from six.moves.urllib.parse import urlparse, urlencode, urlunparse from six.moves.urllib.parse import urlparse, urlencode, urlunparse
from collections import OrderedDict
import copy import copy
@@ -22,3 +24,17 @@ def remove_query_parameters(request, query_parameters_to_remove):
uri_parts[4] = urlencode(new_query) uri_parts[4] = urlencode(new_query)
request.uri = urlunparse(uri_parts) request.uri = urlunparse(uri_parts)
return request return request
def remove_post_data_parameters(request, post_data_parameters_to_remove):
if request.method == 'POST' and not isinstance(request.body, BytesIO):
post_data = OrderedDict()
for k, sep, v in [p.partition(b'=') for p in request.body.split(b'&')]:
if k in post_data:
post_data[k].append(v)
elif len(k) > 0 and k.decode('utf-8') not in post_data_parameters_to_remove:
post_data[k] = [v]
request.body = b'&'.join(
b'='.join([k, v])
for k, vals in post_data.items() for v in vals)
return request