diff --git a/.gitignore b/.gitignore index d0a4d9ec..78056640 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,5 @@ dist/ .coverage *.egg-info/ pytestdebug.log + +fixtures/ \ No newline at end of file diff --git a/README.md b/README.md index 0a9c954d..7ed85af0 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ ![vcr.py](https://raw.github.com/kevin1024/vcrpy/master/vcr.png) -This is a Python version of [Ruby's VCR library](https://github.com/myronmarston/vcr). +This is a Python version of [Ruby's VCR library](https://github.com/vcr/vcr). [![Build Status](https://secure.travis-ci.org/kevin1024/vcrpy.png?branch=master)](http://travis-ci.org/kevin1024/vcrpy) [![Stories in Ready](https://badge.waffle.io/kevin1024/vcrpy.png?label=ready&title=Ready)](https://waffle.io/kevin1024/vcrpy) @@ -176,13 +176,13 @@ with vcr.use_cassette('fixtures/vcr_cassettes/synopsis.yaml') as cass: The `Cassette` object exposes the following properties which I consider part of the API. The fields are as follows: - * `requests`: A list of vcr.Request objects containing the requests made while - this cassette was being used, ordered by the order that the request was made. + * `requests`: A list of vcr.Request objects corresponding to the http requests + that were made during the recording of the cassette. The requests appear in the + order that they were originally processed. * `responses`: A list of the responses made. - * `play_count`: The number of times this cassette has had a response played - back - * `all_played`: A boolean indicates whether all the responses have been - played back + * `play_count`: The number of times this cassette has played back a response. + * `all_played`: A boolean indicating whether all the responses have been + played back. * `responses_of(request)`: Access the responses that match a given request The `Request` object has the following properties: @@ -215,7 +215,7 @@ Finally, register your class with VCR to use your new serializer. ```python import vcr -BogoSerializer(object): +class BogoSerializer(object): """ Must implement serialize() and deserialize() methods """ @@ -293,12 +293,12 @@ with my_vcr.use_cassette('test.yml', filter_query_parameters=['api_key']): requests.get('http://api.com/getdata?api_key=secretstring') ``` -### Custom request filtering +### Custom Request filtering -If neither of these covers your use case, you can register a callback that will -manipulate the HTTP request before adding it to the cassette. Use the -`before_record` configuration option to so this. Here is an -example that will never record requests to the /login endpoint. +If neither of these covers your request filtering needs, you can register a callback +that will manipulate the HTTP request before adding it to the cassette. Use the +`before_record` configuration option to so this. Here is an example that will + never record requests to the /login endpoint. ```python def before_record_cb(request): @@ -312,6 +312,40 @@ with my_vcr.use_cassette('test.yml'): # your http code here ``` +You can also mutate the response using this callback. For example, you could +remove all query parameters from any requests to the `'/login'` path. + +```python +def scrub_login_request(request): + if request.path == '/login': + request.uri, _ = urllib.splitquery(response.uri) + return request + +my_vcr = vcr.VCR( + before_record=scrub_login_request, +) +with my_vcr.use_cassette('test.yml'): + # your http code here +``` + +### Custom Response Filtering + +VCR.py also suports response filtering with the `before_record_response` keyword +argument. It's usage is similar to that of `before_record`: + +```python +def scrub_string(string, replacement=''): + def before_record_reponse(response): + return response['body']['string] = response['body']['string].replace(string, replacement) + return scrub_string + +my_vcr = vcr.VCR( + before_record=scrub_string(settings.USERNAME, 'username'), +) +with my_vcr.use_cassette('test.yml'): + # your http code here +``` + ## Ignore requests If you would like to completely ignore certain requests, you can do it in a @@ -335,7 +369,7 @@ to `brew install libyaml` [[Homebrew](http://mxcl.github.com/homebrew/)]) ## Ruby VCR compatibility -I'm not trying to match the format of the Ruby VCR YAML files. Cassettes +VCR.py does not aim to match the format of the Ruby VCR YAML files. Cassettes generated by Ruby's VCR are not compatible with VCR.py. ## Running VCR's test suite @@ -356,7 +390,7 @@ installed. Also, in order for the boto tests to run, you will need an AWS key. Refer to the [boto documentation](http://boto.readthedocs.org/en/latest/getting_started.html) for -how to set this up. I have marked the boto tests as optional in Travis so you +how to set this up. I have marked the boto tests as optional in Travis so you don't have to worry about them failing if you submit a pull request. @@ -423,6 +457,8 @@ API in version 1.0.x ## Changelog + * 1.1.0 Add `before_record_response`. Fix several bugs related to the context + management of cassettes. * 1.0.3: Fix an issue with requests 2.4 and make sure case sensitivity is consistent across python versions * 1.0.2: Fix an issue with requests 2.3 diff --git a/setup.py b/setup.py index dc2f629b..df313706 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ def run_tests(self): setup( name='vcrpy', - version='1.0.3', + version='1.1.0', description=( "Automatically mock your HTTP interactions to simplify and " "speed up testing" @@ -41,7 +41,7 @@ def run_tests(self): 'vcr.compat': 'vcr/compat', 'vcr.persisters': 'vcr/persisters', }, - install_requires=['PyYAML', 'contextdecorator', 'six'], + install_requires=['PyYAML', 'mock', 'six', 'contextlib2'], license='MIT', tests_require=['pytest', 'mock', 'pytest-localserver'], cmdclass={'test': PyTest}, diff --git a/tests/integration/test_requests.py b/tests/integration/test_requests.py index 36131554..9f6484fd 100644 --- a/tests/integration/test_requests.py +++ b/tests/integration/test_requests.py @@ -24,30 +24,30 @@ def scheme(request): def test_status_code(scheme, tmpdir): '''Ensure that we can read the status code''' url = scheme + '://httpbin.org/' - with vcr.use_cassette(str(tmpdir.join('atts.yaml'))) as cass: + with vcr.use_cassette(str(tmpdir.join('atts.yaml'))): status_code = requests.get(url).status_code - with vcr.use_cassette(str(tmpdir.join('atts.yaml'))) as cass: + with vcr.use_cassette(str(tmpdir.join('atts.yaml'))): assert status_code == requests.get(url).status_code def test_headers(scheme, tmpdir): '''Ensure that we can read the headers back''' url = scheme + '://httpbin.org/' - with vcr.use_cassette(str(tmpdir.join('headers.yaml'))) as cass: + with vcr.use_cassette(str(tmpdir.join('headers.yaml'))): headers = requests.get(url).headers - with vcr.use_cassette(str(tmpdir.join('headers.yaml'))) as cass: + with vcr.use_cassette(str(tmpdir.join('headers.yaml'))): assert headers == requests.get(url).headers def test_body(tmpdir, scheme): '''Ensure the responses are all identical enough''' url = scheme + '://httpbin.org/bytes/1024' - with vcr.use_cassette(str(tmpdir.join('body.yaml'))) as cass: + with vcr.use_cassette(str(tmpdir.join('body.yaml'))): content = requests.get(url).content - with vcr.use_cassette(str(tmpdir.join('body.yaml'))) as cass: + with vcr.use_cassette(str(tmpdir.join('body.yaml'))): assert content == requests.get(url).content @@ -55,10 +55,10 @@ def test_auth(tmpdir, scheme): '''Ensure that we can handle basic auth''' auth = ('user', 'passwd') url = scheme + '://httpbin.org/basic-auth/user/passwd' - with vcr.use_cassette(str(tmpdir.join('auth.yaml'))) as cass: + with vcr.use_cassette(str(tmpdir.join('auth.yaml'))): one = requests.get(url, auth=auth) - with vcr.use_cassette(str(tmpdir.join('auth.yaml'))) as cass: + with vcr.use_cassette(str(tmpdir.join('auth.yaml'))): two = requests.get(url, auth=auth) assert one.content == two.content assert one.status_code == two.status_code @@ -81,10 +81,10 @@ def test_post(tmpdir, scheme): '''Ensure that we can post and cache the results''' data = {'key1': 'value1', 'key2': 'value2'} url = scheme + '://httpbin.org/post' - with vcr.use_cassette(str(tmpdir.join('requests.yaml'))) as cass: + with vcr.use_cassette(str(tmpdir.join('requests.yaml'))): req1 = requests.post(url, data).content - with vcr.use_cassette(str(tmpdir.join('requests.yaml'))) as cass: + with vcr.use_cassette(str(tmpdir.join('requests.yaml'))): req2 = requests.post(url, data).content assert req1 == req2 @@ -93,7 +93,7 @@ def test_post(tmpdir, scheme): def test_redirects(tmpdir, scheme): '''Ensure that we can handle redirects''' url = scheme + '://httpbin.org/redirect-to?url=bytes/1024' - with vcr.use_cassette(str(tmpdir.join('requests.yaml'))) as cass: + with vcr.use_cassette(str(tmpdir.join('requests.yaml'))): content = requests.get(url).content with vcr.use_cassette(str(tmpdir.join('requests.yaml'))) as cass: @@ -124,11 +124,11 @@ def test_gzip(tmpdir, scheme): url = scheme + '://httpbin.org/gzip' response = requests.get(url) - with vcr.use_cassette(str(tmpdir.join('gzip.yaml'))) as cass: + with vcr.use_cassette(str(tmpdir.join('gzip.yaml'))): response = requests.get(url) assert_is_json(response.content) - with vcr.use_cassette(str(tmpdir.join('gzip.yaml'))) as cass: + with vcr.use_cassette(str(tmpdir.join('gzip.yaml'))): assert_is_json(response.content) @@ -143,9 +143,65 @@ def test_session_and_connection_close(tmpdir, scheme): with vcr.use_cassette(str(tmpdir.join('session_connection_closed.yaml'))): session = requests.session() - resp = session.get('http://httpbin.org/get', headers={'Connection': 'close'}) - resp = session.get('http://httpbin.org/get', headers={'Connection': 'close'}) + session.get('http://httpbin.org/get', headers={'Connection': 'close'}) + session.get('http://httpbin.org/get', headers={'Connection': 'close'}) + def test_https_with_cert_validation_disabled(tmpdir): with vcr.use_cassette(str(tmpdir.join('cert_validation_disabled.yaml'))): requests.get('https://httpbin.org', verify=False) + + +def test_session_can_make_requests_after_requests_unpatched(tmpdir): + with vcr.use_cassette(str(tmpdir.join('test_session_after_unpatched.yaml'))): + session = requests.session() + session.get('http://httpbin.org/get') + + with vcr.use_cassette(str(tmpdir.join('test_session_after_unpatched.yaml'))): + session = requests.session() + session.get('http://httpbin.org/get') + + session.get('http://httpbin.org/status/200') + + +def test_session_created_before_use_cassette_is_patched(tmpdir, scheme): + url = scheme + '://httpbin.org/bytes/1024' + # Record arbitrary, random data to the cassette + with vcr.use_cassette(str(tmpdir.join('session_created_outside.yaml'))): + session = requests.session() + body = session.get(url).content + + # Create a session outside of any cassette context manager + session = requests.session() + # Make a request to make sure that a connectionpool is instantiated + session.get(scheme + '://httpbin.org/get') + + with vcr.use_cassette(str(tmpdir.join('session_created_outside.yaml'))): + # These should only be the same if the patching succeeded. + assert session.get(url).content == body + + +def test_nested_cassettes_with_session_created_before_nesting(scheme, tmpdir): + ''' + This tests ensures that a session that was created while one cassette was + active is patched to the use the responses of a second cassette when it + is enabled. + ''' + url = scheme + '://httpbin.org/bytes/1024' + with vcr.use_cassette(str(tmpdir.join('first_nested.yaml'))): + session = requests.session() + first_body = session.get(url).content + with vcr.use_cassette(str(tmpdir.join('second_nested.yaml'))): + second_body = session.get(url).content + third_body = requests.get(url).content + + with vcr.use_cassette(str(tmpdir.join('second_nested.yaml'))): + session = requests.session() + assert session.get(url).content == second_body + with vcr.use_cassette(str(tmpdir.join('first_nested.yaml'))): + assert session.get(url).content == first_body + assert session.get(url).content == third_body + + # Make sure that the session can now get content normally. + session.get('http://www.reddit.com') + diff --git a/tests/unit/test_cassettes.py b/tests/unit/test_cassettes.py index b6120724..2cab15f3 100644 --- a/tests/unit/test_cassettes.py +++ b/tests/unit/test_cassettes.py @@ -1,7 +1,13 @@ +import copy + +from six.moves import http_client as httplib +import contextlib2 +import mock import pytest import yaml -import mock + from vcr.cassette import Cassette +from vcr.patch import force_reset from vcr.errors import UnhandledHTTPRequestError @@ -68,6 +74,46 @@ def test_cassette_cant_read_same_request_twice(): a.play_response('foo') +def make_get_request(): + conn = httplib.HTTPConnection("www.python.org") + conn.request("GET", "/index.html") + return conn.getresponse() + + +@mock.patch('vcr.cassette.requests_match', return_value=True) +@mock.patch('vcr.cassette.load_cassette', lambda *args, **kwargs: (('foo',), (mock.MagicMock(),))) +@mock.patch('vcr.cassette.Cassette.can_play_response_for', return_value=True) +@mock.patch('vcr.stubs.VCRHTTPResponse') +def test_function_decorated_with_use_cassette_can_be_invoked_multiple_times(*args): + decorated_function = Cassette.use('test')(make_get_request) + for i in range(2): + decorated_function() + + +def test_arg_getter_functionality(): + arg_getter = mock.Mock(return_value=('test', {})) + context_decorator = Cassette.use_arg_getter(arg_getter) + + with context_decorator as cassette: + assert cassette._path == 'test' + + arg_getter.return_value = ('other', {}) + + with context_decorator as cassette: + assert cassette._path == 'other' + + arg_getter.return_value = ('', {'filter_headers': ('header_name',)}) + + @context_decorator + def function(): + pass + + with mock.patch.object(Cassette, 'load') as cassette_load: + function() + cassette_load.assert_called_once_with(arg_getter.return_value[0], + **arg_getter.return_value[1]) + + def test_cassette_not_all_played(): a = Cassette('test') a.append('foo', 'bar') @@ -80,3 +126,58 @@ def test_cassette_all_played(): a.append('foo', 'bar') a.play_response('foo') assert a.all_played + + +def test_before_record_response(): + before_record_response = mock.Mock(return_value='mutated') + cassette = Cassette('test', before_record_response=before_record_response) + cassette.append('req', 'res') + + before_record_response.assert_called_once_with('res') + assert cassette.responses[0] == 'mutated' + + +def assert_get_response_body_is(value): + conn = httplib.HTTPConnection("www.python.org") + conn.request("GET", "/index.html") + assert conn.getresponse().read().decode('utf8') == value + + +@mock.patch('vcr.cassette.requests_match', _mock_requests_match) +@mock.patch('vcr.cassette.Cassette.can_play_response_for', return_value=True) +@mock.patch('vcr.cassette.Cassette._save', return_value=True) +def test_nesting_cassette_context_managers(*args): + first_response = {'body': {'string': b'first_response'}, 'headers': {}, + 'status': {'message': 'm', 'code': 200}} + + second_response = copy.deepcopy(first_response) + second_response['body']['string'] = b'second_response' + + with contextlib2.ExitStack() as exit_stack: + first_cassette = exit_stack.enter_context(Cassette.use('test')) + exit_stack.enter_context(mock.patch.object(first_cassette, 'play_response', + return_value=first_response)) + assert_get_response_body_is('first_response') + + # Make sure a second cassette can supercede the first + with Cassette.use('test') as second_cassette: + with mock.patch.object(second_cassette, 'play_response', return_value=second_response): + assert_get_response_body_is('second_response') + + # Now the first cassette should be back in effect + assert_get_response_body_is('first_response') + + +def test_nesting_context_managers_by_checking_references_of_http_connection(): + original = httplib.HTTPConnection + with Cassette.use('test'): + first_cassette_HTTPConnection = httplib.HTTPConnection + with Cassette.use('test'): + second_cassette_HTTPConnection = httplib.HTTPConnection + assert second_cassette_HTTPConnection is not first_cassette_HTTPConnection + with Cassette.use('test'): + assert httplib.HTTPConnection is not second_cassette_HTTPConnection + with force_reset(): + assert httplib.HTTPConnection is original + assert httplib.HTTPConnection is second_cassette_HTTPConnection + assert httplib.HTTPConnection is first_cassette_HTTPConnection diff --git a/tests/unit/test_vcr.py b/tests/unit/test_vcr.py new file mode 100644 index 00000000..c1dc6f03 --- /dev/null +++ b/tests/unit/test_vcr.py @@ -0,0 +1,28 @@ +import mock + +from vcr import VCR + + +def test_vcr_use_cassette(): + filter_headers = mock.Mock() + test_vcr = VCR(filter_headers=filter_headers) + with mock.patch('vcr.cassette.Cassette.load') as mock_cassette_load: + @test_vcr.use_cassette('test') + def function(): + pass + assert mock_cassette_load.call_count == 0 + function() + assert mock_cassette_load.call_args[1]['filter_headers'] is filter_headers + + # Make sure that calls to function now use cassettes with the + # new filter_header_settings + test_vcr.filter_headers = ('a',) + function() + assert mock_cassette_load.call_args[1]['filter_headers'] == test_vcr.filter_headers + + # Ensure that explicitly provided arguments still supercede + # those on the vcr. + new_filter_headers = mock.Mock() + + with test_vcr.use_cassette('test', filter_headers=new_filter_headers) as cassette: + assert cassette._filter_headers == new_filter_headers diff --git a/tox.ini b/tox.ini index bf7a8400..5797c48b 100644 --- a/tox.ini +++ b/tox.ini @@ -40,220 +40,150 @@ deps = pytest pytest-localserver PyYAML + ipdb [testenv:py26requests1] basepython = python2.6 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==1.2.3 [testenv:py27requests1] basepython = python2.7 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==1.2.3 [testenv:py33requests1] basepython = python3.3 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==1.2.3 [testenv:pypyrequests1] basepython = pypy deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==1.2.3 [testenv:py26requests24] basepython = python2.6 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==2.4.0 [testenv:py27requests24] basepython = python2.7 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==2.4.0 [testenv:py33requests24] basepython = python3.4 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==2.4.0 [testenv:py34requests24] basepython = python3.4 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==2.4.0 [testenv:pypyrequests24] basepython = pypy deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==2.4.0 [testenv:py26requests23] basepython = python2.6 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==2.3.0 [testenv:py27requests23] basepython = python2.7 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==2.3.0 [testenv:py33requests23] basepython = python3.4 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==2.3.0 [testenv:py34requests23] basepython = python3.4 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==2.3.0 [testenv:pypyrequests23] basepython = pypy deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==2.3.0 [testenv:py26requests22] basepython = python2.6 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==2.2.1 [testenv:py27requests22] basepython = python2.7 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==2.2.1 [testenv:py33requests22] basepython = python3.4 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==2.2.1 [testenv:py34requests22] basepython = python3.4 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==2.2.1 + [testenv:pypyrequests22] basepython = pypy deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} requests==2.2.1 [testenv:py26httplib2] basepython = python2.6 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} httplib2 [testenv:py27httplib2] basepython = python2.7 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} httplib2 [testenv:py33httplib2] basepython = python3.4 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} httplib2 [testenv:py34httplib2] basepython = python3.4 deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} httplib2 [testenv:pypyhttplib2] basepython = pypy deps = - mock - pytest - pytest-localserver - PyYAML + {[testenv]deps} httplib2 diff --git a/vcr/cassette.py b/vcr/cassette.py index 36deed32..d5c33b8c 100644 --- a/vcr/cassette.py +++ b/vcr/cassette.py @@ -1,14 +1,14 @@ '''The container for recorded requests and responses''' +import logging +import contextlib2 try: from collections import Counter except ImportError: from .compat.counter import Counter -from contextdecorator import ContextDecorator - # Internal imports -from .patch import install, reset +from .patch import CassettePatcherBuilder from .persist import load_cassette, save_cassette from .filters import filter_request from .serializers import yamlserializer @@ -16,7 +16,50 @@ from .errors import UnhandledHTTPRequestError -class Cassette(ContextDecorator): +log = logging.getLogger(__name__) + + +class CassetteContextDecorator(contextlib2.ContextDecorator): + """Context manager/decorator that handles installing the cassette and + removing cassettes. + + This class defers the creation of a new cassette instance until the point at + which it is installed by context manager or decorator. The fact that a new + cassette is used with each application prevents the state of any cassette + from interfering with another. + """ + + @classmethod + def from_args(cls, cassette_class, path, **kwargs): + return cls(cassette_class, lambda: (path, kwargs)) + + def __init__(self, cls, args_getter): + self.cls = cls + self._args_getter = args_getter + self.__finish = None + + def _patch_generator(self, cassette): + with contextlib2.ExitStack() as exit_stack: + for patcher in CassettePatcherBuilder(cassette).build(): + exit_stack.enter_context(patcher) + log.debug('Entered context for cassette at {0}.'.format(cassette._path)) + yield cassette + log.debug('Exiting context for cassette at {0}.'.format(cassette._path)) + # TODO(@IvanMalison): Hmmm. it kind of feels like this should be somewhere else. + cassette._save() + + def __enter__(self): + assert self.__finish is None + path, kwargs = self._args_getter() + self.__finish = self._patch_generator(self.cls.load(path, **kwargs)) + return next(self.__finish) + + def __exit__(self, *args): + next(self.__finish, None) + self.__finish = None + + +class Cassette(object): '''A container for recorded requests and responses''' @classmethod @@ -26,27 +69,29 @@ def load(cls, path, **kwargs): new_cassette._load() return new_cassette - def __init__(self, - path, - serializer=yamlserializer, - record_mode='once', - match_on=(uri, method), - filter_headers=(), - filter_query_parameters=(), - before_record=None, - ignore_hosts=(), - ignore_localhost=() - ): + @classmethod + def use_arg_getter(cls, arg_getter): + return CassetteContextDecorator(cls, arg_getter) + + @classmethod + def use(cls, *args, **kwargs): + return CassetteContextDecorator.from_args(cls, *args, **kwargs) + + def __init__(self, path, serializer=yamlserializer, record_mode='once', + match_on=(uri, method), filter_headers=(), + filter_query_parameters=(), before_record=None, before_record_response=None, + ignore_hosts=(), ignore_localhost=()): self._path = path self._serializer = serializer self._match_on = match_on self._filter_headers = filter_headers self._filter_query_parameters = filter_query_parameters self._before_record = before_record + self._before_record_response = before_record_response self._ignore_hosts = ignore_hosts if ignore_localhost: self._ignore_hosts = list(set( - self._ignore_hosts + ['localhost', '0.0.0.0', '127.0.0.1'] + list(self._ignore_hosts) + ['localhost', '0.0.0.0', '127.0.0.1'] )) # self.data is the list of (req, resp) tuples @@ -94,6 +139,8 @@ def append(self, request, response): request = self._filter_request(request) if not request: return + if self._before_record_response: + response = self._before_record_response(response) self.data.append((request, response)) self.dirty = True @@ -185,12 +232,3 @@ def __contains__(self, request): for response in self._responses(request): return True return False - - def __enter__(self): - '''Patch the fetching libraries we know about''' - install(self) - return self - - def __exit__(self, typ, value, traceback): - self._save() - reset() diff --git a/vcr/config.py b/vcr/config.py index 0a1615a8..8308b48b 100644 --- a/vcr/config.py +++ b/vcr/config.py @@ -1,3 +1,4 @@ +import functools import os from .cassette import Cassette from .serializers import yamlserializer, jsonserializer @@ -9,18 +10,19 @@ def __init__(self, serializer='yaml', cassette_library_dir=None, record_mode="once", - filter_headers=[], - filter_query_parameters=[], + filter_headers=(), + filter_query_parameters=(), before_record=None, - match_on=[ + before_record_response=None, + match_on=( 'method', 'scheme', 'host', 'port', 'path', 'query', - ], - ignore_hosts=[], + ), + ignore_hosts=(), ignore_localhost=False, ): self.serializer = serializer @@ -46,6 +48,7 @@ def __init__(self, self.filter_headers = filter_headers self.filter_query_parameters = filter_query_parameters self.before_record = before_record + self.before_record_response = before_record_response self.ignore_hosts = ignore_hosts self.ignore_localhost = ignore_localhost @@ -72,13 +75,16 @@ def _get_matchers(self, matcher_names): return matchers def use_cassette(self, path, **kwargs): + args_getter = functools.partial(self.get_path_and_merged_config, path, **kwargs) + return Cassette.use_arg_getter(args_getter) + + def get_path_and_merged_config(self, path, **kwargs): serializer_name = kwargs.get('serializer', self.serializer) matcher_names = kwargs.get('match_on', self.match_on) cassette_library_dir = kwargs.get( 'cassette_library_dir', self.cassette_library_dir ) - if cassette_library_dir: path = os.path.join(cassette_library_dir, path) @@ -95,6 +101,9 @@ def use_cassette(self, path, **kwargs): "before_record": kwargs.get( "before_record", self.before_record ), + "before_record_response": kwargs.get( + "before_record_response", self.before_record_response + ), "ignore_hosts": kwargs.get( 'ignore_hosts', self.ignore_hosts ), @@ -102,8 +111,7 @@ def use_cassette(self, path, **kwargs): 'ignore_localhost', self.ignore_localhost ), } - - return Cassette.load(path, **merged_config) + return path, merged_config def register_serializer(self, name, serializer): self.serializers[name] = serializer diff --git a/vcr/patch.py b/vcr/patch.py index 2d40362d..6ab8627c 100644 --- a/vcr/patch.py +++ b/vcr/patch.py @@ -1,4 +1,9 @@ '''Utilities for patching in cassettes''' +import functools +import itertools + +import contextlib2 +import mock from .stubs import VCRHTTPConnection, VCRHTTPSConnection from six.moves import http_client as httplib @@ -8,139 +13,290 @@ _HTTPConnection = httplib.HTTPConnection _HTTPSConnection = httplib.HTTPSConnection + +# Try to save the original types for requests try: - # Try to save the original types for requests import requests.packages.urllib3.connectionpool as cpool +except ImportError: # pragma: no cover + pass +else: _VerifiedHTTPSConnection = cpool.VerifiedHTTPSConnection _cpoolHTTPConnection = cpool.HTTPConnection _cpoolHTTPSConnection = cpool.HTTPSConnection -except ImportError: # pragma: no cover - pass + +# Try to save the original types for urllib3 try: - # Try to save the original types for urllib3 import urllib3 - _VerifiedHTTPSConnection = urllib3.connectionpool.VerifiedHTTPSConnection except ImportError: # pragma: no cover pass +else: + _VerifiedHTTPSConnection = urllib3.connectionpool.VerifiedHTTPSConnection + +# Try to save the original types for httplib2 try: - # Try to save the original types for httplib2 import httplib2 +except ImportError: # pragma: no cover + pass +else: _HTTPConnectionWithTimeout = httplib2.HTTPConnectionWithTimeout _HTTPSConnectionWithTimeout = httplib2.HTTPSConnectionWithTimeout _SCHEME_TO_CONNECTION = httplib2.SCHEME_TO_CONNECTION -except ImportError: # pragma: no cover - pass + +# Try to save the original types for boto try: - # Try to save the original types for boto import boto.https_connection - _CertValidatingHTTPSConnection = \ - boto.https_connection.CertValidatingHTTPSConnection except ImportError: # pragma: no cover pass +else: + _CertValidatingHTTPSConnection = boto.https_connection.CertValidatingHTTPSConnection -def install(cassette): - """ - Patch all the HTTPConnections references we can find! - This replaces the actual HTTPConnection with a VCRHTTPConnection - object which knows how to save to / read from cassettes - """ - httplib.HTTPConnection = VCRHTTPConnection - httplib.HTTPSConnection = VCRHTTPSConnection - httplib.HTTPConnection.cassette = cassette - httplib.HTTPSConnection.cassette = cassette +class CassettePatcherBuilder(object): - # patch requests v1.x - try: - import requests.packages.urllib3.connectionpool as cpool + def _build_patchers_from_mock_triples_decorator(function): + @functools.wraps(function) + def wrapped(self, *args, **kwargs): + return self._build_patchers_from_mock_triples(function(self, *args, **kwargs)) + return wrapped + + def __init__(self, cassette): + self._cassette = cassette + self._class_to_cassette_subclass = {} + + def build(self): + return itertools.chain(self._httplib(), self._requests(), + self._urllib3(), self._httplib2(), + self._boto()) + + def _build_patchers_from_mock_triples(self, mock_triples): + for args in mock_triples: + patcher = self._build_patcher(*args) + if patcher: + yield patcher + + def _build_patcher(self, obj, patched_attribute, replacement_class): + if not hasattr(obj, patched_attribute): + return + + return mock.patch.object(obj, patched_attribute, + self._recursively_apply_get_cassette_subclass( + replacement_class)) + + def _recursively_apply_get_cassette_subclass(self, replacement_dict_or_obj): + if isinstance(replacement_dict_or_obj, dict): + for key, replacement_obj in replacement_dict_or_obj.items(): + replacement_obj = self._recursively_apply_get_cassette_subclass( + replacement_obj) + replacement_dict_or_obj[key] = replacement_obj + return replacement_dict_or_obj + if hasattr(replacement_dict_or_obj, 'cassette'): + replacement_dict_or_obj = self._get_cassette_subclass( + replacement_dict_or_obj) + return replacement_dict_or_obj + + def _get_cassette_subclass(self, klass): + if klass.cassette is not None: + return klass + if klass not in self._class_to_cassette_subclass: + subclass = self._build_cassette_subclass(klass) + self._class_to_cassette_subclass[klass] = subclass + return self._class_to_cassette_subclass[klass] + + def _build_cassette_subclass(self, base_class): + bases = (base_class,) + if not issubclass(base_class, object): # Check for old style class + bases += (object,) + return type('{0}{1}'.format(base_class.__name__, self._cassette._path), + bases, dict(cassette=self._cassette)) + + @_build_patchers_from_mock_triples_decorator + def _httplib(self): + yield httplib, 'HTTPConnection', VCRHTTPConnection + yield httplib, 'HTTPSConnection', VCRHTTPSConnection + + def _requests(self): + try: + import requests.packages.urllib3.connectionpool as cpool + except ImportError: # pragma: no cover + return () from .stubs.requests_stubs import VCRRequestsHTTPConnection, VCRRequestsHTTPSConnection - cpool.VerifiedHTTPSConnection = VCRRequestsHTTPSConnection - cpool.HTTPConnection = VCRRequestsHTTPConnection - cpool.VerifiedHTTPSConnection.cassette = cassette - cpool.HTTPConnection = VCRHTTPConnection - cpool.HTTPConnection.cassette = cassette - # patch requests v2.x - cpool.HTTPConnectionPool.ConnectionCls = VCRRequestsHTTPConnection - cpool.HTTPConnectionPool.cassette = cassette - cpool.HTTPSConnectionPool.ConnectionCls = VCRRequestsHTTPSConnection - cpool.HTTPSConnectionPool.cassette = cassette - except ImportError: # pragma: no cover - pass + http_connection_remover = ConnectionRemover( + self._get_cassette_subclass(VCRRequestsHTTPConnection) + ) + https_connection_remover = ConnectionRemover( + self._get_cassette_subclass(VCRRequestsHTTPSConnection) + ) + mock_triples = ( + (cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection), + (cpool, 'VerifiedHTTPSConnection', VCRRequestsHTTPSConnection), + (cpool, 'HTTPConnection', VCRRequestsHTTPConnection), + (cpool, 'HTTPSConnection', VCRRequestsHTTPSConnection), + (cpool.HTTPConnectionPool, 'ConnectionCls', VCRRequestsHTTPConnection), + (cpool.HTTPSConnectionPool, 'ConnectionCls', VCRRequestsHTTPSConnection), + ) + # These handle making sure that sessions only use the + # connections of the appropriate type. + mock_triples += ((cpool.HTTPConnectionPool, '_get_conn', + self._patched_get_conn(cpool.HTTPConnectionPool, + lambda : cpool.HTTPConnection)), + (cpool.HTTPSConnectionPool, '_get_conn', + self._patched_get_conn(cpool.HTTPSConnectionPool, + lambda : cpool.HTTPSConnection)), + (cpool.HTTPConnectionPool, '_new_conn', + self._patched_new_conn(cpool.HTTPConnectionPool, + http_connection_remover)), + (cpool.HTTPSConnectionPool, '_new_conn', + self._patched_new_conn(cpool.HTTPSConnectionPool, + https_connection_remover))) - # patch urllib3 - try: - import urllib3.connectionpool as cpool - from .stubs.urllib3_stubs import VCRVerifiedHTTPSConnection - cpool.VerifiedHTTPSConnection = VCRVerifiedHTTPSConnection - cpool.VerifiedHTTPSConnection.cassette = cassette - cpool.HTTPConnection = VCRHTTPConnection - cpool.HTTPConnection.cassette = cassette - except ImportError: # pragma: no cover - pass + return itertools.chain(self._build_patchers_from_mock_triples(mock_triples), + (http_connection_remover, https_connection_remover)) - # patch httplib2 - try: - import httplib2 as cpool - from .stubs.httplib2_stubs import VCRHTTPConnectionWithTimeout - from .stubs.httplib2_stubs import VCRHTTPSConnectionWithTimeout - cpool.HTTPConnectionWithTimeout = VCRHTTPConnectionWithTimeout - cpool.HTTPSConnectionWithTimeout = VCRHTTPSConnectionWithTimeout - cpool.SCHEME_TO_CONNECTION = { - 'http': VCRHTTPConnectionWithTimeout, - 'https': VCRHTTPSConnectionWithTimeout - } - except ImportError: # pragma: no cover - pass + def _patched_get_conn(self, connection_pool_class, connection_class_getter): + get_conn = connection_pool_class._get_conn + @functools.wraps(get_conn) + def patched_get_conn(pool, timeout=None): + connection = get_conn(pool, timeout) + connection_class = pool.ConnectionCls if hasattr(pool, 'ConnectionCls') \ + else connection_class_getter() + while not isinstance(connection, connection_class): + connection = get_conn(pool, timeout) + return connection + return patched_get_conn - # patch boto - try: - import boto.https_connection as cpool - from .stubs.boto_stubs import VCRCertValidatingHTTPSConnection - cpool.CertValidatingHTTPSConnection = VCRCertValidatingHTTPSConnection - cpool.CertValidatingHTTPSConnection.cassette = cassette - except ImportError: # pragma: no cover - pass + def _patched_new_conn(self, connection_pool_class, connection_remover): + new_conn = connection_pool_class._new_conn + @functools.wraps(new_conn) + def patched_new_conn(pool): + new_connection = new_conn(pool) + connection_remover.add_connection_to_pool_entry(pool, new_connection) + return new_connection + return patched_new_conn + + @_build_patchers_from_mock_triples_decorator + def _urllib3(self): + try: + import urllib3.connectionpool as cpool + except ImportError: # pragma: no cover + pass + else: + from .stubs.urllib3_stubs import VCRVerifiedHTTPSConnection + + yield cpool, 'VerifiedHTTPSConnection', VCRVerifiedHTTPSConnection + yield cpool, 'HTTPConnection', VCRHTTPConnection + + @_build_patchers_from_mock_triples_decorator + def _httplib2(self): + try: + import httplib2 as cpool + except ImportError: # pragma: no cover + pass + else: + from .stubs.httplib2_stubs import VCRHTTPConnectionWithTimeout + from .stubs.httplib2_stubs import VCRHTTPSConnectionWithTimeout + + yield cpool, 'HTTPConnectionWithTimeout', VCRHTTPConnectionWithTimeout + yield cpool, 'HTTPSConnectionWithTimeout', VCRHTTPSConnectionWithTimeout + yield cpool, 'SCHEME_TO_CONNECTION', {'http': VCRHTTPConnectionWithTimeout, + 'https': VCRHTTPSConnectionWithTimeout} + @_build_patchers_from_mock_triples_decorator + def _boto(self): + try: + import boto.https_connection as cpool + except ImportError: # pragma: no cover + pass + else: + from .stubs.boto_stubs import VCRCertValidatingHTTPSConnection + yield cpool, 'CertValidatingHTTPSConnection', VCRCertValidatingHTTPSConnection -def reset(): - '''Undo all the patching''' - httplib.HTTPConnection = _HTTPConnection - httplib.HTTPSConnection = _HTTPSConnection + +class ConnectionRemover(object): + + def __init__(self, connection_class): + self._connection_class = connection_class + self._connection_pool_to_connections = {} + + def add_connection_to_pool_entry(self, pool, connection): + if isinstance(connection, self._connection_class): + self._connection_pool_to_connections.setdefault(pool, set()).add(connection) + + def remove_connection_to_pool_entry(self, pool, connection): + if isinstance(connection, self._connection_class): + self._connection_pool_to_connections[self._connection_class].remove(connection) + + def __enter__(self): + return self + + def __exit__(self, *args): + for pool, connections in self._connection_pool_to_connections.items(): + readd_connections = [] + while not pool.pool.empty() and connections: + connection = pool.pool.get() + if isinstance(connection, self._connection_class): + connections.remove(connection) + else: + readd_connections.append(connection) + for connection in readd_connections: + pool._put_conn(connection) + + +def reset_patchers(): + yield mock.patch.object(httplib, 'HTTPConnection', _HTTPConnection) + yield mock.patch.object(httplib, 'HTTPSConnection', _HTTPSConnection) try: import requests.packages.urllib3.connectionpool as cpool - # unpatch requests v1.x - cpool.VerifiedHTTPSConnection = _VerifiedHTTPSConnection - cpool.HTTPConnection = _cpoolHTTPConnection - # unpatch requests v2.x - cpool.HTTPConnectionPool.ConnectionCls = _cpoolHTTPConnection - cpool.HTTPSConnection = _cpoolHTTPSConnection - cpool.HTTPSConnectionPool.ConnectionCls = _cpoolHTTPSConnection except ImportError: # pragma: no cover pass + else: + # unpatch requests v1.x + yield mock.patch.object(cpool, 'VerifiedHTTPSConnection', _VerifiedHTTPSConnection) + yield mock.patch.object(cpool, 'HTTPConnection', _cpoolHTTPConnection) + # unpatch requests v2.x + if hasattr(cpool.HTTPConnectionPool, 'ConnectionCls'): + yield mock.patch.object(cpool.HTTPConnectionPool, 'ConnectionCls', + _cpoolHTTPConnection) + yield mock.patch.object(cpool.HTTPSConnectionPool, 'ConnectionCls', + _cpoolHTTPSConnection) + + if hasattr(cpool, 'HTTPSConnection'): + yield mock.patch.object(cpool, 'HTTPSConnection', _cpoolHTTPSConnection) try: import urllib3.connectionpool as cpool - cpool.VerifiedHTTPSConnection = _VerifiedHTTPSConnection - cpool.HTTPConnection = _HTTPConnection - cpool.HTTPSConnection = _HTTPSConnection - cpool.HTTPConnectionPool.ConnectionCls = _HTTPConnection - cpool.HTTPSConnectionPool.ConnectionCls = _HTTPSConnection except ImportError: # pragma: no cover pass + else: + yield mock.patch.object(cpool, 'VerifiedHTTPSConnection', _VerifiedHTTPSConnection) + yield mock.patch.object(cpool, 'HTTPConnection', _HTTPConnection) + yield mock.patch.object(cpool, 'HTTPSConnection', _HTTPSConnection) + yield mock.patch.object(cpool.HTTPConnectionPool, 'ConnectionCls', _HTTPConnection) + yield mock.patch.object(cpool.HTTPSConnectionPool, 'ConnectionCls', _HTTPSConnection) try: import httplib2 as cpool - cpool.HTTPConnectionWithTimeout = _HTTPConnectionWithTimeout - cpool.HTTPSConnectionWithTimeout = _HTTPSConnectionWithTimeout - cpool.SCHEME_TO_CONNECTION = _SCHEME_TO_CONNECTION except ImportError: # pragma: no cover pass + else: + yield mock.patch.object(cpool, 'HTTPConnectionWithTimeout', _HTTPConnectionWithTimeout) + yield mock.patch.object(cpool, 'HTTPSConnectionWithTimeout', _HTTPSConnectionWithTimeout) + yield mock.patch.object(cpool, 'SCHEME_TO_CONNECTION', _SCHEME_TO_CONNECTION) try: import boto.https_connection as cpool - cpool.CertValidatingHTTPSConnection = _CertValidatingHTTPSConnection except ImportError: # pragma: no cover pass + else: + yield mock.patch.object(cpool, 'CertValidatingHTTPSConnection', + _CertValidatingHTTPSConnection) + + +@contextlib2.contextmanager +def force_reset(): + with contextlib2.ExitStack() as exit_stack: + for patcher in reset_patchers(): + exit_stack.enter_context(patcher) + yield diff --git a/vcr/stubs/__init__.py b/vcr/stubs/__init__.py index 45af7ea8..df67b1d9 100644 --- a/vcr/stubs/__init__.py +++ b/vcr/stubs/__init__.py @@ -119,7 +119,7 @@ def getheader(self, header, default=None): return default -class VCRConnection: +class VCRConnection(object): # A reference to the cassette that's currently being patched in cassette = None @@ -205,7 +205,7 @@ def endheaders(self, *args, **kwargs): pass def getresponse(self, _=False): - '''Retrieve a the response''' + '''Retrieve the response''' # Check to see if the cassette has a response for this request. If so, # then return it if self.cassette.can_play_response_for(self._vcr_request): @@ -295,10 +295,9 @@ def __init__(self, *args, **kwargs): # need to temporarily reset here because the real connection # inherits from the thing that we are mocking out. Take out # the reset if you want to see what I mean :) - from vcr.patch import install, reset - reset() - self.real_connection = self._baseclass(*args, **kwargs) - install(self.cassette) + from vcr.patch import force_reset + with force_reset(): + self.real_connection = self._baseclass(*args, **kwargs) class VCRHTTPConnection(VCRConnection):