From 64397d7eccb163924d6bb2527f94b6be8d7d6cde Mon Sep 17 00:00:00 2001 From: Olutobi Owoputi Date: Wed, 2 Dec 2015 12:25:36 -0800 Subject: [PATCH] add decode_compressed_response option and filter --- vcr/config.py | 9 ++++++++- vcr/filters.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/vcr/config.py b/vcr/config.py index 728b17f..ae99346 100644 --- a/vcr/config.py +++ b/vcr/config.py @@ -35,7 +35,8 @@ class VCR(object): before_record_response=None, filter_post_data_parameters=(), match_on=('method', 'scheme', 'host', 'port', 'path', 'query'), before_record=None, inject_cassette=False, serializer='yaml', - cassette_library_dir=None, func_path_generator=None): + cassette_library_dir=None, func_path_generator=None, + decode_compressed_response=False): self.serializer = serializer self.match_on = match_on self.cassette_library_dir = cassette_library_dir @@ -67,6 +68,7 @@ class VCR(object): self.inject_cassette = inject_cassette self.path_transformer = path_transformer self.func_path_generator = func_path_generator + self.decode_compressed_response = decode_compressed_response self._custom_patches = tuple(custom_patches) def _get_serializer(self, serializer_name): @@ -163,7 +165,12 @@ class VCR(object): before_record_response = options.get( 'before_record_response', self.before_record_response ) + decode_compressed_response = options.get( + 'decode_compressed_response', self.decode_compressed_response + ) filter_functions = [] + if decode_compressed_response: + filter_functions.append(filters.decode_response) if before_record_response: if not isinstance(before_record_response, collections.Iterable): before_record_response = (before_record_response,) diff --git a/vcr/filters.py b/vcr/filters.py index 6070d79..e6e10fd 100644 --- a/vcr/filters.py +++ b/vcr/filters.py @@ -1,6 +1,7 @@ from six import BytesIO, text_type from six.moves.urllib.parse import urlparse, urlencode, urlunparse import json +import zlib def 
def decode_response(response):
    """
    Return ``response`` with a gzip- or deflate-compressed body decoded.

    If the first content-encoding value is ``gzip`` or ``deflate``:
      1. decompress the response body
      2. remove that value from the content-encoding header, deleting the
         header once no values are left
      3. set the content-length header to the decompressed length

    Header names and the encoding token are matched case-insensitively, so
    ``Content-Encoding: GZIP`` is handled like ``content-encoding: gzip``.

    :param response: dict with a ``headers`` mapping (header name -> list
        of values) and a ``body`` dict holding the payload under ``string``.
    :returns: a modified deep copy of ``response`` when the body was
        decoded; otherwise ``response`` itself, untouched. An empty body is
        left alone because there is nothing to decompress.
    :raises zlib.error: if a non-empty body is not valid data for the
        declared encoding.
    """
    import copy  # function-scope import keeps this block self-contained

    def find_header(headers, name):
        """Return the key of ``headers`` equal to ``name`` ignoring case."""
        for key in headers:
            if key.lower() == name:
                return key
        return None

    def decompress_body(body, encoding):
        """Returns decompressed body according to encoding using zlib.
        to (de-)compress gzip format, use wbits = zlib.MAX_WBITS | 16
        """
        if encoding == 'gzip':
            return zlib.decompress(body, zlib.MAX_WBITS | 16)
        # encoding == 'deflate'
        return zlib.decompress(body)

    encoding_key = find_header(response['headers'], 'content-encoding')
    if encoding_key is None:
        return response

    encodings = response['headers'][encoding_key]
    if not encodings or encodings[0].lower() not in ('gzip', 'deflate'):
        return response

    if not response['body']['string']:
        # zlib rejects an empty stream; keep bodiless responses as recorded.
        return response

    # Deep copy: a shallow copy would share the header lists and the body
    # dict with the caller's response, so the edits below would leak out.
    response = copy.deepcopy(response)
    headers = response['headers']

    raw_encoding = headers[encoding_key][0]
    headers[encoding_key].remove(raw_encoding)
    if not headers[encoding_key]:
        del headers[encoding_key]

    new_body = decompress_body(response['body']['string'],
                               raw_encoding.lower())
    response['body']['string'] = new_body

    # Header values are lists of strings; reuse the existing header's
    # spelling when there is one so no duplicate header is introduced.
    length_key = find_header(headers, 'content-length') or 'content-length'
    headers[length_key] = [str(len(new_body))]
    return response