diff --git a/papermill/iorw.py b/papermill/iorw.py index 14a0122c..7feeb0e7 100644 --- a/papermill/iorw.py +++ b/papermill/iorw.py @@ -165,9 +165,24 @@ def get_handler(self, path, extensions=None): class HttpHandler: + @classmethod + def _get_auth_kwargs(cls): + """Gets the Authorization header from PAPERMILL_HTTP_AUTH_HEADER. + A valid example value Basic dW5hbWU6cGFzc3dvcmQK""" + auth_header = os.environ.get('PAPERMILL_HTTP_AUTH_HEADER', None) + if auth_header: + return {'headers': {'Authorization': auth_header}} + return {} + + @classmethod + def _get_read_kwargs(cls): + kwargs = cls._get_auth_kwargs() or {'headers': {}} + kwargs['headers']['Accept'] = os.environ.get('PAPERMILL_HTTP_ACCEPT_HEADER', 'application/json') + return kwargs + @classmethod def read(cls, path): - return requests.get(path, headers={'Accept': 'application/json'}).text + return requests.get(path, **cls._get_read_kwargs()).text @classmethod def listdir(cls, path): @@ -175,7 +190,7 @@ def listdir(cls, path): @classmethod def write(cls, buf, path): - result = requests.put(path, json=json.loads(buf)) + result = requests.put(path, json=json.loads(buf), **cls._get_auth_kwargs()) result.raise_for_status() @classmethod diff --git a/papermill/tests/test_iorw.py b/papermill/tests/test_iorw.py index ab09f01a..7d80caff 100644 --- a/papermill/tests/test_iorw.py +++ b/papermill/tests/test_iorw.py @@ -320,6 +320,38 @@ def test_read(self): self.assertEqual(HttpHandler.read(path), text) mock_get.assert_called_once_with(path, headers={'Accept': 'application/json'}) + def test_read_with_auth(self): + """ + Tests that the `read` function performs a request to the giving path + with authentication from the environment variables and returns the response. + """ + path = 'http://example.com' + text = 'request test response' + auth = 'Basic dW5hbWU6cGFzc3dvcmQK' + + with patch.dict(os.environ, clear=True) as env, patch('papermill.iorw.requests.get') as mock_get: + env['PAPERMILL_HTTP_AUTH_HEADER'] = auth + mock_get.return_value = Mock(text=text) + + self.assertEqual(HttpHandler.read(path), text) + mock_get.assert_called_once_with(path, headers={'Accept': 'application/json', 'Authorization': auth}) + + def test_read_with_accept_header(self): + """ + Tests that the `read` function performs a request to the giving path + with an accept type from env variables and returns the response. + """ + path = 'http://example.com' + text = 'request test response' + accept_type = 'test accept type' + + with patch.dict(os.environ, clear=True) as env, patch('papermill.iorw.requests.get') as mock_get: + env['PAPERMILL_HTTP_ACCEPT_HEADER'] = accept_type + mock_get.return_value = Mock(text=text) + + self.assertEqual(HttpHandler.read(path), text) + mock_get.assert_called_once_with(path, headers={'Accept': accept_type}) + def test_write(self): """ Tests that the `write` function performs a put request to the given @@ -332,6 +364,21 @@ def test_write(self): HttpHandler.write(buf, path) mock_put.assert_called_once_with(path, json=json.loads(buf)) + def test_write_with_auth(self): + """ + Tests that the `write` function performs a put request to the given + path with authentication from env variables. + """ + path = 'http://example.com' + buf = '{"papermill": true}' + auth = 'token' + + with patch.dict(os.environ, clear=True) as env, patch('papermill.iorw.requests.put') as mock_put: + env['PAPERMILL_HTTP_AUTH_HEADER'] = auth + + HttpHandler.write(buf, path) + mock_put.assert_called_once_with(path, json=json.loads(buf), headers={'Authorization': auth}) + def test_write_failure(self): """ Tests that the `write` function raises on failure to put the buffer.