From d3e2e356ffac69a0a4b907ac4d3260273ce08d8e Mon Sep 17 00:00:00 2001 From: Fernando Aureliano <145799342+fsilvamaia@users.noreply.github.com> Date: Tue, 24 Oct 2023 11:35:27 -0300 Subject: [PATCH] create http client class (#135) --- .gitignore | 1 + .../functional/test_http_client_functional.py | 90 +++++++++ tests/unit/test_aws.py | 20 +- tests/unit/test_http_client.py | 157 ++++++++++++++++ tests/unit/test_okta.py | 123 ++++++------ tests/unit/test_user.py | 23 --- tokendito/__init__.py | 2 +- tokendito/aws.py | 11 +- tokendito/http_client.py | 79 ++++++++ tokendito/okta.py | 175 +++++++++++------- tokendito/tool.py | 5 +- tokendito/user.py | 67 +++---- 12 files changed, 549 insertions(+), 204 deletions(-) create mode 100644 tests/functional/test_http_client_functional.py create mode 100644 tests/unit/test_http_client.py create mode 100644 tokendito/http_client.py diff --git a/.gitignore b/.gitignore index 2954c5dd..2ebe9cd3 100644 --- a/.gitignore +++ b/.gitignore @@ -91,6 +91,7 @@ venv/ ENV/ env.bak/ venv.bak/ +.vscode # Spyder project settings .spyderproject diff --git a/tests/functional/test_http_client_functional.py b/tests/functional/test_http_client_functional.py new file mode 100644 index 00000000..d199b861 --- /dev/null +++ b/tests/functional/test_http_client_functional.py @@ -0,0 +1,90 @@ +"""This module contains unit tests for the HTTPClient class.""" +# vim: set filetype=python ts=4 sw=4 +# -*- coding: utf-8 -*- +import pytest +from requests import RequestException +from tokendito import __title__ +from tokendito import __version__ +from tokendito.http_client import HTTPClient + + +@pytest.fixture +def client(): + """Fixture to create and return an HTTPClient instance.""" + client = HTTPClient() + client.session.headers.update({"User-Agent": f"{__title__}/{__version__}"}) + return client + + +def test_get_request(client): + """Test the GET request functionality of HTTPClient.""" + # Make a GET request to the /get endpoint of httpbin which reflects the sent request data + response = client.get("https://httpbin.org/get") + json_data = response.json() + + # Assert that the request was successful and the returned User-Agent matches the one we set + assert response.status_code == 200 + assert json_data["headers"]["User-Agent"] == f"{__title__}/{__version__}" + + +def test_post_request(client): + """Test the POST request functionality of HTTPClient.""" + # Make a POST request to the /post endpoint of httpbin with sample data + response = client.post("https://httpbin.org/post", json={"key": "value"}) + json_data = response.json() + + # Assert that the request was successful and the returned json data matches the data we sent + assert response.status_code == 200 + assert json_data["json"] == {"key": "value"} + + +def test_set_cookies(client): + """Test the ability to set cookies using HTTPClient.""" + # Set a test cookie for the client + client.set_cookies({"test_cookie": "cookie_value"}) + + # Make a request to the /cookies endpoint of httpbin which returns set cookies + response = client.get("https://httpbin.org/cookies") + json_data = response.json() + + # Assert that the cookie we set is correctly returned by the server + assert json_data["cookies"] == {"test_cookie": "cookie_value"} + + +def test_custom_header(client): + """Test the ability to send custom headers using HTTPClient.""" + # Make a GET request with a custom header + response = client.get("https://httpbin.org/get", headers={"X-Test-Header": "TestValue"}) + json_data = response.json() + + # Assert that the custom header was correctly sent + assert json_data["headers"]["X-Test-Header"] == "TestValue" + + +def test_bad_get_request(client, mocker): + """Test GET request failure scenario.""" + mocker.patch("requests.Session.get", side_effect=RequestException("An error occurred")) + with pytest.raises(SystemExit): + client.get("https://httpbin.org/get") + + +def test_bad_post_request(client, mocker): + """Test POST request failure scenario.""" + mocker.patch("requests.Session.post", side_effect=RequestException("An error occurred")) + with pytest.raises(SystemExit): + client.post("https://httpbin.org/post", json={"key": "value"}) + + +def test_reset_session(client): + """Test the reset method to ensure session is reset.""" + # Set a test cookie for the client + client.set_cookies({"test_cookie": "cookie_value"}) + # Reset the session + client.reset() + + # Make a request to the /cookies endpoint of httpbin which returns set cookies + response = client.get("https://httpbin.org/cookies") + json_data = response.json() + + # Assert that the cookies have been cleared + assert json_data["cookies"] == {} diff --git a/tests/unit/test_aws.py b/tests/unit/test_aws.py index 34064fc2..6f214c0b 100644 --- a/tests/unit/test_aws.py +++ b/tests/unit/test_aws.py @@ -1,6 +1,8 @@ # vim: set filetype=python ts=4 sw=4 # -*- coding: utf-8 -*- """Unit tests, and local fixtures for AWS module.""" +from unittest.mock import Mock + import pytest @@ -91,10 +93,18 @@ def test_select_assumeable_role_no_tiles(): @pytest.mark.parametrize("status_code", [(400), (401), (404), (500), (503)]) def test_authenticate_to_roles(status_code, monkeypatch): """Test if function return correct response.""" - import requests from tokendito.aws import authenticate_to_roles + import tokendito.http_client as http_client - mock_get = {"status_code": status_code, "text": "response"} - monkeypatch.setattr(requests, "get", mock_get) - with pytest.raises(SystemExit) as error: - assert authenticate_to_roles([("http://test.url.com", "")], "secret_session_token") == error + # Create a mock response object + mock_response = Mock() + mock_response.status_code = status_code + mock_response.text = "response" + + # Use monkeypatch to replace the HTTP_client.get method with the mock + monkeypatch.setattr(http_client.HTTP_client, "get", lambda *args, **kwargs: mock_response) + + cookies = {"some_cookie": "some_value"} + + with pytest.raises(SystemExit): + authenticate_to_roles([("http://test.url.com", "")], cookies) diff --git a/tests/unit/test_http_client.py b/tests/unit/test_http_client.py new file mode 100644 index 00000000..aa13927d --- /dev/null +++ b/tests/unit/test_http_client.py @@ -0,0 +1,157 @@ +"""Unit tests for the HTTPClient class.""" +# vim: set filetype=python ts=4 sw=4 +# -*- coding: utf-8 -*- +import pytest +import requests +from tokendito import __title__ +from tokendito import __version__ +from tokendito.http_client import HTTPClient + +# Unit test class for the HTTPClient. + + +@pytest.fixture +def client(): + """Fixture for setting up an HTTPClient instance.""" + # Initializing HTTPClient instance without the 'user_agent' parameter + return HTTPClient() + + +def test_init(client): + """Test initialization of HTTPClient instance.""" + # Check if the session property of the client is an instance of requests.Session + assert isinstance(client.session, requests.Session) + + # Check if the User-Agent header was set correctly during initialization + expected_user_agent = f"{__title__}/{__version__}" + assert client.session.headers["User-Agent"] == expected_user_agent + + +def test_set_cookies(client): + """Test setting cookies in the session.""" + cookies = {"test_cookie": "cookie_value"} + client.set_cookies(cookies) + # Check if the provided cookie is set correctly in the session + assert client.session.cookies.get_dict() == cookies + + +def test_get(client, mocker): + """Test GET request method.""" + mock_get = mocker.patch("requests.Session.get") + mock_resp = mocker.Mock() + mock_resp.status_code = 200 + mock_resp.text = "OK" + mock_get.return_value = mock_resp + + response = client.get("http://test.com") + # Check if the response status code and text match the expected values + assert response.status_code == 200 + assert response.text == "OK" + + +def test_post(client, mocker): + """Test POST request method.""" + mock_post = mocker.patch("requests.Session.post") + mock_resp = mocker.Mock() + mock_resp.status_code = 201 + mock_resp.text = "Created" + mock_post.return_value = mock_resp + + response = client.post("http://test.com", json={"key": "value"}) + # Check if the response status code and text match the expected values + assert response.status_code == 201 + assert response.text == "Created" + + +def test_get_failure(client, mocker): + """Test GET request failure scenario.""" + mock_get = mocker.patch("requests.Session.get") + mock_get.side_effect = requests.RequestException("Failed to connect") + + with pytest.raises(SystemExit): + client.get("http://test.com") + + +def test_post_failure(client, mocker): + """Test POST request failure scenario.""" + mock_post = mocker.patch("requests.Session.post") + mock_post.side_effect = requests.RequestException("Failed to connect") + + with pytest.raises(SystemExit): + client.post("http://test.com", json={"key": "value"}) + + +def test_post_with_return_json(client, mocker): + """Test POST request with return_json=True.""" + mock_post = mocker.patch("requests.Session.post") + mock_resp = mocker.Mock() + mock_resp.status_code = 201 + mock_resp.json.return_value = {"status": "Created"} + mock_post.return_value = mock_resp + + response = client.post("http://test.com", json={"key": "value"}, return_json=True) + assert response == {"status": "Created"} + + +def test_reset(client): + """Test the reset method.""" + # Updating the session headers to check if they are reset later + client.session.headers.update({"Test-Header": "Test-Value"}) + + client.reset() + + expected_user_agent = f"{__title__}/{__version__}" + assert "Test-Header" not in client.session.headers + assert client.session.headers["User-Agent"] == expected_user_agent + + +def test_get_generic_exception(client, mocker): + """Test GET request with generic exception.""" + mock_get = mocker.patch("requests.Session.get") + mock_get.side_effect = Exception("Some Exception") + + with pytest.raises(SystemExit): + client.get("http://test.com") + + +def test_post_generic_exception(client, mocker): + """Test POST request with generic exception.""" + mock_post = mocker.patch("requests.Session.post") + mock_post.side_effect = Exception("Some Exception") + + with pytest.raises(SystemExit): + client.post("http://test.com", json={"key": "value"}) + + +def test_post_json_exception(client, mocker): + """Test POST request when json() method raises an exception.""" + mock_post = mocker.patch("requests.Session.post") + mock_resp = mocker.Mock() + mock_resp.status_code = 201 + mock_resp.json.side_effect = Exception("JSON Exception") + mock_post.return_value = mock_resp + + with pytest.raises(SystemExit): + client.post("http://test.com", json={"key": "value"}, return_json=True) + + +def test_get_logging_on_exception(client, mocker): + """Test if logging occurs during exception in GET request.""" + mock_get = mocker.patch("requests.Session.get") + mock_get.side_effect = requests.RequestException("Failed to connect") + mock_logger = mocker.patch("logging.Logger.error") + + with pytest.raises(SystemExit): + client.get("http://test.com") + mock_logger.assert_called() + + +def test_post_logging_on_exception(client, mocker): + """Test if logging occurs during exception in POST request.""" + mock_post = mocker.patch("requests.Session.post") + mock_post.side_effect = requests.RequestException("Failed to connect") + mock_logger = mocker.patch("logging.Logger.error") + + with pytest.raises(SystemExit): + client.post("http://test.com", json={"key": "value"}) + mock_logger.assert_called() diff --git a/tests/unit/test_okta.py b/tests/unit/test_okta.py index 85f69cf4..b2769b22 100644 --- a/tests/unit/test_okta.py +++ b/tests/unit/test_okta.py @@ -4,7 +4,8 @@ from unittest.mock import Mock import pytest -import requests_mock +from tokendito.config import Config +from tokendito.http_client import HTTP_client @pytest.fixture @@ -64,7 +65,12 @@ def test_bad_session_token(mocker, sample_json_response, sample_headers): "mfa_provider, session_token, selected_factor, expected", [ ("DUO", 123, {"_embedded": {}}, 123), - ("OKTA", 345, {"_embedded": {"factor": {"factorType": "push"}}}, 345), + ( + "OKTA", + 345, + {"_embedded": {"factor": {"factorType": "push"}}}, + 345, + ), # Changed expected value to 2 ("GOOGLE", 456, {"_embedded": {"factor": {"factorType": "sms"}}}, 456), ], ) @@ -77,9 +83,14 @@ def test_mfa_provider_type( sample_headers, ): """Test whether function return key on specific MFA provider.""" - from tokendito.config import Config + from tokendito.http_client import HTTP_client from tokendito.okta import mfa_provider_type + mock_response = {"sessionToken": session_token} + mocker.patch.object(HTTP_client, "post", return_value=mock_response) + + mocker.patch("tokendito.duo.duo_api_post", return_value=None) + payload = {"x": "y", "t": "z"} callback_url = "https://www.acme.org" selected_mfa_option = 1 @@ -87,15 +98,13 @@ def test_mfa_provider_type( primary_auth = 1 pytest_config = Config() - mfa_verify = {"sessionToken": session_token} mocker.patch( "tokendito.duo.authenticate_duo", return_value=(payload, sample_headers, callback_url), ) - mocker.patch("tokendito.okta.api_wrapper", return_value=mfa_verify) - mocker.patch("tokendito.okta.push_approval", return_value=mfa_verify) - mocker.patch("tokendito.okta.totp_approval", return_value=mfa_verify) - mocker.patch("tokendito.duo.duo_api_post") + mocker.patch("tokendito.okta.push_approval", return_value={"sessionToken": session_token}) + mocker.patch("tokendito.okta.totp_approval", return_value={"sessionToken": session_token}) + assert ( mfa_provider_type( pytest_config, @@ -114,6 +123,7 @@ def test_mfa_provider_type( def test_bad_mfa_provider_type(mocker, sample_headers): """Test whether function return key on specific MFA provider.""" from tokendito.config import Config + from tokendito.http_client import HTTP_client from tokendito.okta import mfa_provider_type pytest_config = Config() @@ -126,11 +136,15 @@ def test_bad_mfa_provider_type(mocker, sample_headers): mfa_verify = {"sessionToken": "pytest_session_token"} mfa_bad_provider = "bad_provider" + + mock_response = Mock() + mock_response.json.return_value = mfa_verify + mocker.patch( "tokendito.duo.authenticate_duo", return_value=(payload, sample_headers, callback_url), ) - mocker.patch("tokendito.okta.api_wrapper", return_value=mfa_verify) + mocker.patch.object(HTTP_client, "post", return_value=mock_response) mocker.patch("tokendito.okta.totp_approval", return_value=mfa_verify) with pytest.raises(SystemExit) as error: @@ -149,37 +163,6 @@ def test_bad_mfa_provider_type(mocker, sample_headers): ) -def test_api_wrapper(): - """Test whether verify_api_method returns the correct data.""" - from tokendito.okta import api_wrapper - - url = "https://acme.org" - with requests_mock.Mocker() as m: - data = {"response": "ok"} - m.post(url, json=data, status_code=200) - assert api_wrapper(url, data) == data - - with pytest.raises(SystemExit) as error, requests_mock.Mocker() as m: - data = None - m.post(url, json=data, status_code=200) - assert api_wrapper(url, data) == error - - with pytest.raises(SystemExit) as error, requests_mock.Mocker() as m: - data = {"response": "ok", "errorCode": "0xdeadbeef"} - m.post(url, json=data, status_code=200) - assert api_wrapper(url, data) == error - - with pytest.raises(SystemExit) as error, requests_mock.Mocker() as m: - data = "pytest_bad_datatype" - m.post(url, text=data, status_code=403) - assert api_wrapper(url, data) == error - - with pytest.raises(SystemExit) as error, requests_mock.Mocker() as m: - data = {"response": "incorrect", "errorCode": "0xdeadbeef"} - m.post(url, json=data, status_code=403) - assert api_wrapper("http://acme.org", data) == error - - def test_api_error_code_parser(): """Test whether message on specific status equal.""" from tokendito.okta import _status_dict @@ -219,6 +202,7 @@ def test_mfa_index(preset_mfa, output, mocker, sample_json_response): def test_mfa_options(sample_headers, sample_json_response, mocker): """Test handling of MFA approval.""" from tokendito.config import Config + from tokendito.http_client import HTTP_client from tokendito.okta import totp_approval selected_mfa_option = {"factorType": "push"} @@ -227,15 +211,17 @@ def test_mfa_options(sample_headers, sample_json_response, mocker): mfa_challenge_url = "https://pytest" pytest_config = Config(okta={"mfa_response": None}) - # Test that selecting software token returns a session token + mocker.patch("tokendito.user.get_input", return_value="012345") + + mocker.patch.object(HTTP_client, "post", return_value={"sessionToken": "pytest"}) selected_mfa_option = {"factorType": "token:software:totp"} primary_auth["stateToken"] = "pytest" mfa_verify = {"sessionToken": "pytest"} - mocker.patch("tokendito.user.get_input", return_value="012345") - mocker.patch("tokendito.okta.api_wrapper", return_value=mfa_verify) + ret = totp_approval( pytest_config, selected_mfa_option, sample_headers, mfa_challenge_url, payload, primary_auth ) + assert ret == mfa_verify @@ -295,24 +281,25 @@ def test_mfa_challenge_with_no_mfas(sample_headers, sample_json_response): ), ], ) -def test_push_approval(mocker, sample_headers, return_value, side_effect, expected): +def test_push_approval(mocker, return_value, side_effect, expected): """Test push approval.""" from tokendito import okta challenge_url = "https://pytest/api/v1/authn/factors/factorid/verify" + payload = {"some_key": "some_value"} - mocker.patch("tokendito.okta.api_wrapper", return_value=return_value, side_effect=side_effect) - mocker.patch("time.sleep", return_value=0) + mocker.patch.object(HTTP_client, "post", return_value=return_value, side_effect=side_effect) + mocker.patch("time.sleep", return_value=None) if "status" in return_value and return_value["status"] == "SUCCESS": - ret = okta.push_approval(sample_headers, challenge_url, None) + ret = okta.push_approval(challenge_url, payload) assert ret["status"] == "SUCCESS" elif "factorResult" in return_value and return_value["factorResult"] == "WAITING": - ret = okta.push_approval(sample_headers, challenge_url, None) + ret = okta.push_approval(challenge_url, payload) assert ret["status"] == "SUCCESS" else: with pytest.raises(SystemExit) as err: - okta.push_approval(sample_headers, challenge_url, None) + okta.push_approval(challenge_url, payload) assert err.value.code == expected @@ -456,16 +443,19 @@ def test_extract_saml_relaystate(html, expected): def test_get_saml_request(mocker): """Test getting SAML request.""" from tokendito import okta + from tokendito.http_client import HTTP_client - request_wrapper_response = Mock() - auth_properties = {"id": "id", "metadata": "metadata"} - request_wrapper_response.text = ( + mock_response = Mock() + mock_response.text = ( "
" "" ) - mocker.patch("tokendito.user.request_wrapper", return_value=request_wrapper_response) + + mocker.patch.object(HTTP_client, "get", return_value=mock_response) + + auth_properties = {"id": "id", "metadata": "metadata"} assert okta.get_saml_request(auth_properties) == { "base_url": "https://acme.okta.com", @@ -479,8 +469,8 @@ def test_send_saml_request(mocker): """Test sending SAML request.""" from tokendito import okta - request_wrapper_response = Mock() - request_wrapper_response.text = ( + mock_response = Mock() + mock_response.text = ( "" @@ -490,7 +480,7 @@ def test_send_saml_request(mocker): saml_request = {"relay_state": "relay_state", "request": "request", "post_url": "post_url"} cookie = {"sid": "pytestcookie"} - mocker.patch("tokendito.user.request_wrapper", return_value=request_wrapper_response) + mocker.patch("tokendito.http_client.HTTP_client.get", return_value=mock_response) assert okta.send_saml_request(saml_request, cookie) == { "response": "PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0iVVRGLTgiPz4=", @@ -502,9 +492,10 @@ def test_send_saml_request(mocker): def test_send_saml_response(mocker): """Test sending SAML response.""" from tokendito import okta + from tokendito.http_client import HTTP_client - request_wrapper_response = Mock() - request_wrapper_response.cookies = {"sid": "pytestcookie"} + mock_response = Mock() + mock_response.cookies = {"sid": "pytestcookie"} saml_response = { "response": "pytestresponse", @@ -512,8 +503,9 @@ def test_send_saml_response(mocker): "post_url": "https://acme.okta.com/app/okta_org2org/akjlkjlksjx0xmdd/sso/saml", } - mocker.patch("tokendito.user.request_wrapper", return_value=request_wrapper_response) - assert okta.send_saml_response(saml_response) == request_wrapper_response.cookies + mocker.patch.object(HTTP_client, "post", return_value=mock_response) + + assert okta.send_saml_response(saml_response) == mock_response.cookies def test_authenticate(mocker): @@ -552,7 +544,15 @@ def test_local_auth(mocker): """Test local auth method.""" from tokendito import okta from tokendito.config import Config + from tokendito.http_client import HTTP_client + + # Create a fake HTTP response using Mock + mock_response_data = {"status": "SUCCESS", "sessionToken": "pytesttoken"} + + # Patch HTTP_client.post to return the mock response + mocker.patch.object(HTTP_client, "post", return_value=mock_response_data) + # Initialize the configuration pytest_config = Config( okta={ "username": "pytest", @@ -560,10 +560,7 @@ def test_local_auth(mocker): "org": "https://acme.okta.org/", } ) - api_wrapper_response = {"status": "SUCCESS", "sessionToken": "pytesttoken"} - mocker.patch("tokendito.okta.api_wrapper", return_value=api_wrapper_response) - mocker.patch("tokendito.okta.get_session_token", return_value="pytesttoken") assert okta.local_auth(pytest_config) == "pytesttoken" diff --git a/tests/unit/test_user.py b/tests/unit/test_user.py index 295572cd..1e80835f 100644 --- a/tests/unit/test_user.py +++ b/tests/unit/test_user.py @@ -6,7 +6,6 @@ import sys import pytest -import requests_mock @pytest.mark.xfail( @@ -936,25 +935,3 @@ def test_extract_arns(saml, expected): from tokendito import user assert user.extract_arns(saml) == expected - - -def test_request_wrapper(): - """Test whether request_wrapper returns the correct data.""" - from tokendito.user import request_wrapper - - url = "https://acme.org" - - with requests_mock.Mocker() as m: - data = {"response": "ok"} - m.get(url, json=data, status_code=200) - assert request_wrapper("GET", url).json() == data - - with requests_mock.Mocker() as m: - data = {"response": "ok"} - m.post(url, json=data, status_code=200) - assert request_wrapper("POST", url).json() == data - - with pytest.raises(SystemExit) as error, requests_mock.Mocker() as m: - m.get(url, json=data, status_code=500) - request_wrapper("GET", url) - assert error.value.code == 1 diff --git a/tokendito/__init__.py b/tokendito/__init__.py index 20898530..3c8cbe92 100644 --- a/tokendito/__init__.py +++ b/tokendito/__init__.py @@ -1,7 +1,7 @@ # vim: set filetype=python ts=4 sw=4 # -*- coding: utf-8 -*- """Tokendito module initialization.""" -__version__ = "2.1.2" +__version__ = "2.1.3" __title__ = "tokendito" __description__ = "Get AWS STS tokens from Okta SSO" __long_description_content_type__ = "text/markdown" diff --git a/tokendito/aws.py b/tokendito/aws.py index e5f81ccd..ef5634be 100644 --- a/tokendito/aws.py +++ b/tokendito/aws.py @@ -18,6 +18,7 @@ import botocore.session from tokendito import okta from tokendito import user +from tokendito.http_client import HTTP_client logger = logging.getLogger(__name__) @@ -51,19 +52,21 @@ def authenticate_to_roles(urls, cookies=None): :param urls: list of tuples or tuple, with tiles info :param cookies: html cookies + :param user_agent: optional user agent string :return: response text """ + if cookies: + HTTP_client.set_cookies(cookies) # Set cookies if provided + url_list = [urls] if isinstance(urls, tuple) else urls responses = [] tile_count = len(url_list) - plural = "" - if tile_count > 1: - plural = "s" + plural = "s" if tile_count > 1 else "" logger.info(f"Discovering roles in {tile_count} tile{plural}.") for url, label in url_list: - response = user.request_wrapper("GET", url, cookies=cookies) + response = HTTP_client.get(url) # Use the HTTPClient's get method saml_response_string = response.text saml_xml = okta.extract_saml_response(saml_response_string) diff --git a/tokendito/http_client.py b/tokendito/http_client.py new file mode 100644 index 00000000..cb67acc1 --- /dev/null +++ b/tokendito/http_client.py @@ -0,0 +1,79 @@ +# vim: set filetype=python ts=4 sw=4 +# -*- coding: utf-8 -*- +"""This module handles HTTP client operations.""" + +import logging +import sys + +import requests +from tokendito import __title__ +from tokendito import __version__ + +logger = logging.getLogger(__name__) + + +class HTTPClient: + """Handles HTTP client operations.""" + + def __init__(self): + """Initialize the HTTPClient with a session object.""" + user_agent = f"{__title__}/{__version__}" + self.session = requests.Session() + self.session.headers.update({"User-Agent": user_agent}) + + def set_cookies(self, cookies): + """Update session with additional cookies.""" + self.session.cookies.update(cookies) + + def get(self, url, params=None, headers=None): + """Perform a GET request.""" + response = None + try: + logger.debug(f"Sending cookies: {self.session.cookies}") + logger.debug(f"Sending headers: {self.session.headers}") + response = self.session.get(url, params=params, headers=headers) + response.raise_for_status() + logger.debug(f"Received response from {url}: {response.text}") + return response + except requests.RequestException as e: + logger.error(f"Error during GET request to {url}. Error: {e}") + if response: + logger.debug(f"Response Headers: {response.headers}") + logger.debug(f"Response Content: {response.content}") + else: + logger.debug("No response received") + sys.exit(1) + + except Exception as err: + logger.error(f"The get request to {url} failed with {err}") + sys.exit(1) + + def post(self, url, data=None, json=None, headers=None, return_json=False): + """Perform a POST request.""" + try: + response = self.session.post(url, data=data, json=json, headers=headers) + response.raise_for_status() + if return_json is True: + try: + return response.json() + except Exception as err: + logger.error(f"Problem with json response {err}") + sys.exit(1) + else: + return response + except requests.RequestException as e: + logger.error(f"Error during POST request to {url}. Error: {e}") + sys.exit(1) + except Exception as err: + logger.error(f"The post request to {url} failed with {err}") + sys.exit(1) + + def reset(self): + """Reset the session object to its initial state.""" + user_agent = f"{__title__}/{__version__}" + self.session.cookies.clear() + self.session.headers = requests.utils.default_headers() + self.session.headers.update({"User-Agent": user_agent}) + + +HTTP_client = HTTPClient() diff --git a/tokendito/okta.py b/tokendito/okta.py index 238b051f..e23e5981 100644 --- a/tokendito/okta.py +++ b/tokendito/okta.py @@ -17,9 +17,9 @@ import bs4 from bs4 import BeautifulSoup -import requests from tokendito import duo from tokendito import user +from tokendito.http_client import HTTP_client logger = logging.getLogger(__name__) @@ -31,42 +31,6 @@ ) -def api_wrapper(url, payload, headers=None): - """Okta MFA authentication. - - :param url: url to call - :param payload: JSON Payload - :param headers: Headers of the request - :return: Dictionary with authentication response - """ - logger.debug(f"Calling {url} with {headers}") - try: - response = requests.request("POST", url, data=json.dumps(payload), headers=headers) - response.raise_for_status() - except Exception as err: - logger.error(f"There was an error with the call to {url}: {err}") - sys.exit(1) - - logger.debug(f"{response.url} responded with status code {response.status_code}") - - try: - ret = response.json() - except ValueError as e: - logger.error( - f"{type(e).__name__} - Failed to parse response\n" - f"URL: {url}\n" - f"Status: {response.status_code}\n" - f"Content: {response.content}\n" - ) - sys.exit(1) - - if "errorCode" in ret: - api_error_code_parser(ret["errorCode"]) - sys.exit(1) - - return ret - - def api_error_code_parser(status=None): """Status code parsing. @@ -83,19 +47,23 @@ def api_error_code_parser(status=None): def get_auth_properties(userid=None, url=None): - """Make a call to the webfinger endpoint. + """Make a call to the Okta webfinger endpoint to retrieve authentication properties. - :param userid: User for which we are requesting an auth endpoint. - :param url: Site where we are looking up the user. - :returns: dictionary with authentication properties. + :param userid: User's ID for which we are requesting an auth endpoint. + :param url: Okta organization URL where we are looking up the user. + :returns: Dictionary containing authentication properties. """ + # Prepare the payload for the webfinger endpoint request. payload = {"resource": f"okta:acct:{userid}", "rel": "okta:idp"} headers = {"accept": "application/jrd+json"} url = f"{url}/.well-known/webfinger" logger.debug(f"Looking up auth endpoint for {userid} in {url}") - response = user.request_wrapper("GET", url, headers=headers, params=payload) + # Make a GET request to the webfinger endpoint. + response = HTTP_client.get(url, params=payload, headers=headers) + + # Extract properties from the response. try: ret = response.json()["links"][0]["properties"] except (KeyError, ValueError) as e: @@ -103,8 +71,8 @@ def get_auth_properties(userid=None, url=None): logger.debug(f"Response: {response.text}") sys.exit(1) - # Try to get metadata, type, and ID if available, but ensure - # that a dictionary with the correct keys is returned. + # Extract specific authentication properties if available. + # Return a dictionary with 'metadata', 'type', and 'id' keys. properties = {} properties["metadata"] = ret.get("okta:idp:metadata", None) properties["type"] = ret.get("okta:idp:type", None) @@ -121,69 +89,120 @@ def get_saml_request(auth_properties): :param auth_properties: dict with the IdP ID and type. :returns: dict with post_url, relay_state, and base64 encoded saml request. """ + # Prepare the headers for the request to retrieve the SAML request. headers = {"accept": "text/html,application/xhtml+xml,application/xml"} + + # Build the URL based on the metadata and ID provided in the auth properties. base_url = user.get_base_url(auth_properties["metadata"]) url = f"{base_url}/sso/idps/{auth_properties['id']}" logger.debug(f"Getting SAML request from {url}") - response = user.request_wrapper("GET", url, headers=headers) + + # Make a GET request using the HTTP client to retrieve the SAML request. + response = HTTP_client.get(url, headers=headers) + + # Extract the required parameters from the SAML request. saml_request = { "base_url": user.get_base_url(extract_form_post_url(response.text)), "post_url": extract_form_post_url(response.text), "relay_state": extract_saml_relaystate(response.text), "request": extract_saml_request(response.text, raw=True), } + + # Mask sensitive data in the logs for security. user.add_sensitive_value_to_be_masked(saml_request["request"]) + logger.debug(f"SAML request is {saml_request}") return saml_request def send_saml_request(saml_request, cookies): - """Submit SAML request to IdP, and get the response back. + """ + Submit SAML request to IdP, and get the response back. - :param cookies: session cookies with `sid` :param saml_request: dict with IdP post_url, relay_state, and saml_request + :param cookies: session cookies with `sid` :returns: dict with with SP post_url, relay_state, and saml_response """ + HTTP_client.set_cookies(cookies) + + # Define the payload and headers for the request payload = { "relayState": saml_request["relay_state"], "SAMLRequest": saml_request["request"], } - headers = {"accept": "text/html,application/xhtml+xml,application/xml"} + + headers = { + "accept": "text/html,application/xhtml+xml,application/xml", + "Content-Type": "application/json", + } + + # Construct the URL from the provided saml_request url = saml_request["post_url"] + + # Log the SAML request details logger.debug(f"Sending SAML request to {url}") - response = user.request_wrapper("GET", url, headers=headers, data=payload, cookies=cookies) + # Use the HTTP client to make a GET request + response = HTTP_client.get(url, params=payload, headers=headers) + + # Extract relevant information from the response to form the saml_response dictionary saml_response = { "response": extract_saml_response(response.text, raw=True), "relay_state": extract_saml_relaystate(response.text), "post_url": extract_form_post_url(response.text), } + + # Mask sensitive values for logging purposes user.add_sensitive_value_to_be_masked(saml_response["response"]) + + # Log the formed SAML response logger.debug(f"SAML response is {saml_response}") + + # Return the formed SAML response return saml_response def send_saml_response(saml_response): - """Submit SAML response to the SP. + """ + Submit SAML response to the SP. - :param saml_response: dict with with SP post_url, relay_state, and saml_response + :param saml_response: dict with SP post_url, relay_state, and saml_response :returns: `sid` session cookie """ + # Define the payload and headers for the request. payload = { "SAMLResponse": saml_response["response"], "RelayState": saml_response["relay_state"], } - headers = {"accept": "text/html,application/xhtml+xml,application/xml"} + headers = { + "accept": "text/html,application/xhtml+xml,application/xml", + "Content-Type": "application/x-www-form-urlencoded", + } + + # Construct the URL from the provided saml_response. url = saml_response["post_url"] + # Log the SAML response details. logger.debug(f"Sending SAML response back to {url}") - response = user.request_wrapper("POST", url, data=payload, headers=headers) + + # Use the HTTP client to make a POST request. + response = HTTP_client.post(url, data=payload, headers=headers) + + # Extract cookies from the response. session_cookies = response.cookies + + # Get the 'sid' value from the cookies. sid = session_cookies.get("sid") + + # If 'sid' is present, mask its value for logging purposes. if sid is not None: user.add_sensitive_value_to_be_masked(sid) + + # Log the session cookies. logger.debug(f"Have session cookies: {session_cookies}") + + # Return the session cookies. return session_cookies @@ -204,6 +223,7 @@ def get_session_token(config, primary_auth, headers): if status == "SUCCESS" and "sessionToken" in primary_auth: session_token = primary_auth.get("sessionToken") elif status == "MFA_REQUIRED": + # Note: mfa_challenge should also be modified to accept and use http_client session_token = mfa_challenge(config, headers, primary_auth) else: logger.debug(f"Error parsing response: {json.dumps(primary_auth)}") @@ -278,11 +298,18 @@ def local_auth(config): logger.debug(f"Authenticate user to {config.okta['org']}") logger.debug(f"Sending {headers}, {payload} to {config.okta['org']}") - primary_auth = api_wrapper(f"{config.okta['org']}/api/v1/authn", payload, headers) + + primary_auth = HTTP_client.post( + f"{config.okta['org']}/api/v1/authn", json=payload, headers=headers, return_json=True + ) + + if "errorCode" in primary_auth: + api_error_code_parser(primary_auth["errorCode"]) + sys.exit(1) while session_token is None: session_token = get_session_token(config, primary_auth, headers) - logger.info(f"User has been succesfully authenticated to {config.okta['org']}.") + logger.info(f"User has been successfully authenticated to {config.okta['org']}.") return session_token @@ -301,6 +328,7 @@ def saml2_auth(config, auth_properties): saml2_config = deepcopy(config) saml2_config.okta["org"] = saml_request["base_url"] logger.info(f"Authentication is being redirected to {saml2_config.okta['org']}.") + # Try to authenticate using the new configuration. This could cause # recursive calls, which allows for IdP chaining. session_cookies = authenticate(saml2_config) @@ -438,9 +466,12 @@ def mfa_provider_type( if mfa_provider == "DUO": payload, headers, callback_url = duo.authenticate_duo(selected_factor) duo.duo_api_post(callback_url, payload=payload) - mfa_verify = api_wrapper(mfa_challenge_url, payload, headers) + mfa_verify = HTTP_client.post( + mfa_challenge_url, json=payload, headers=headers, return_json=True + ) + elif mfa_provider == "OKTA" and factor_type == "push": - mfa_verify = push_approval(headers, mfa_challenge_url, payload) + mfa_verify = push_approval(mfa_challenge_url, payload) elif mfa_provider in ["OKTA", "GOOGLE"] and factor_type in ["token:software:totp", "sms"]: mfa_verify = totp_approval( config, selected_mfa_option, headers, mfa_challenge_url, payload, primary_auth @@ -505,13 +536,9 @@ def mfa_challenge(config, headers, primary_auth): preset_mfa = config.okta["mfa"] - # This creates a list where each elements looks like provider_factor_id. - # For example, OKTA_push_9yi4bKJNH2WEWQ0x8, GOOGLE_token:software:totp_9yi4bKJNH2WEWQ available_mfas = [f"{d['provider']}_{d['factorType']}_{d['id']}" for d in mfa_options] - index = mfa_index(preset_mfa, available_mfas, mfa_options) - # time to challenge the mfa option selected_mfa_option = mfa_options[index] logger.debug(f"Selected MFA is [{selected_mfa_option}]") @@ -523,10 +550,14 @@ def mfa_challenge(config, headers, primary_auth): "provider": selected_mfa_option["provider"], "profile": selected_mfa_option["profile"], } - selected_factor = api_wrapper(mfa_challenge_url, payload, headers) + + selected_factor = HTTP_client.post( + mfa_challenge_url, json=payload, headers=headers, return_json=True + ) mfa_provider = selected_factor["_embedded"]["factor"]["provider"] logger.debug(f"MFA Challenge URL: [{mfa_challenge_url}] headers: {headers}") + mfa_session_token = mfa_provider_type( config, mfa_provider, @@ -564,8 +595,12 @@ def totp_approval(config, selected_mfa_option, headers, mfa_challenge_url, paylo "stateToken": primary_auth["stateToken"], "passCode": config.okta["mfa_response"], } - # FIXME: This call needs to catch a 403 coming from a bad token - mfa_verify = api_wrapper(mfa_challenge_url, payload, headers) + + # Using the http_client to make the POST request + mfa_verify = HTTP_client.post( + mfa_challenge_url, json=payload, headers=headers, return_json=True + ) + if "sessionToken" in mfa_verify: user.add_sensitive_value_to_be_masked(mfa_verify["sessionToken"]) logger.debug(f"mfa_verify [{json.dumps(mfa_verify)}]") @@ -573,16 +608,15 @@ def totp_approval(config, selected_mfa_option, headers, mfa_challenge_url, paylo return mfa_verify -def push_approval(headers, mfa_challenge_url, payload): +def push_approval(mfa_challenge_url, payload): """Handle push approval from the user. - :param headers: HTTP headers sent to API call :param mfa_challenge_url: MFA challenge url :param payload: payload which needs to be sent :return: Session Token if succeeded or terminates if user wait goes 5 min """ - logger.debug(f"Push approval with headers:{headers} challenge_url:{mfa_challenge_url}") + logger.debug(f"Push approval with challenge_url:{mfa_challenge_url}") user.print("Waiting for an approval from the device...") status = "MFA_CHALLENGE" @@ -590,8 +624,13 @@ def push_approval(headers, mfa_challenge_url, payload): response = {} challenge_displayed = False + headers = {"content-type": "application/json", "accept": "application/json"} + while status == "MFA_CHALLENGE" and result == "WAITING": - response = api_wrapper(mfa_challenge_url, payload, headers) + response = HTTP_client.post( + mfa_challenge_url, json=payload, headers=headers, return_json=True + ) + if "sessionToken" in response: user.add_sensitive_value_to_be_masked(response["sessionToken"]) diff --git a/tokendito/tool.py b/tokendito/tool.py index 35c4334e..c3477a8a 100644 --- a/tokendito/tool.py +++ b/tokendito/tool.py @@ -8,6 +8,7 @@ from tokendito import okta from tokendito import user from tokendito.config import config +from tokendito.http_client import HTTP_client logger = logging.getLogger(__name__) @@ -40,11 +41,13 @@ def cli(args): # Authenticate to okta session_cookies = okta.authenticate(config) + HTTP_client.set_cookies(session_cookies) + if config.okta["tile"]: tile_label = "" config.okta["tile"] = (config.okta["tile"], tile_label) else: - config.okta["tile"] = user.discover_tiles(config.okta["org"], session_cookies) + config.okta["tile"] = user.discover_tiles(config.okta["org"]) # Authenticate to AWS roles auth_tiles = aws.authenticate_to_roles(config.okta["tile"], cookies=session_cookies) diff --git a/tokendito/user.py b/tokendito/user.py index 36d91b7c..955964c6 100644 --- a/tokendito/user.py +++ b/tokendito/user.py @@ -25,6 +25,7 @@ from tokendito import aws from tokendito.config import Config from tokendito.config import config +from tokendito.http_client import HTTP_client # Unfortunately, readline is only available in non-Windows systems. There is no substitution. try: @@ -538,7 +539,7 @@ def get_account_aliases(saml_xml, saml_response_string): encoded_xml = codecs.encode(saml_xml.encode("utf-8"), "base64") aws_response = None try: - aws_response = requests.Session().post(url, data={"SAMLResponse": encoded_xml}) + aws_response = HTTP_client.post(url, data={"SAMLResponse": encoded_xml}) except Exception as request_error: logger.error(f"There was an error retrieving the AWS SAML page: \n{request_error}") logger.debug(json.dumps(aws_response)) @@ -1205,21 +1206,37 @@ def request_cookies(url, session_token): :param session_token: session token, str :returns: cookies object """ + # Construct the URL from the base URL provided. url = f"{url}/api/v1/sessions" - data = json.dumps({"sessionToken": f"{session_token}"}) - response_with_cookie = request_wrapper(method="POST", url=url, data=data) - sess_id = response_with_cookie.json()["id"] + # Define the payload and headers for the request. + data = {"sessionToken": session_token} + headers = {"Content-Type": "application/json", "accept": "application/json"} + + # Log the request details. + logger.debug(f"Requesting session cookies from {url}") + + # Use the HTTP client to make a POST request. + response_json = HTTP_client.post(url, json=data, headers=headers, return_json=True) + + if "id" not in response_json: + logger.error(f"'id' not found in response. Full response: {response_json}") + sys.exit(1) + + sess_id = response_json["id"] add_sensitive_value_to_be_masked(sess_id) - cookies = response_with_cookie.cookies - cookies.update({"sid": f"{sess_id}"}) - logger.debug(f"Session cookies: {cookies}") + # create cookies with sid 'sid'. + cookies = requests.cookies.RequestsCookieJar() + cookies.set("sid", sess_id, domain=urlparse(url).netloc, path="/") + + # Log the session cookies. + logger.debug(f"Received session cookies: {cookies}") return cookies -def discover_tiles(url, cookies): +def discover_tiles(url): """ Discover aws tile url on user's okta dashboard. @@ -1233,7 +1250,9 @@ def discover_tiles(url, cookies): "expand": ["items", "items.resource"], } logger.debug(f"Performing auto-discovery on {url}.") - response_with_tabs = request_wrapper(method="GET", url=url, cookies=cookies, params=params) + + response_with_tabs = HTTP_client.get(url, params=params) + tabs = response_with_tabs.json() aws_tiles = [] @@ -1254,33 +1273,3 @@ def discover_tiles(url, cookies): logger.debug(f"Discovered {len(tile)} URLs.") return tile - - -def request_wrapper(method, url, headers=None, **kwargs): - """ - Wrap 'requests.request' and perform response checks. - - :param method: request method - :param url: request URL - :param headers: request headers - :param kwargs: additional parameters passed to request - :returns: response object - """ - if headers is None: - headers = {"content-type": "application/json", "accept": "application/json"} - - logger.debug(f"Issuing {method} request to {url} with {headers} and {kwargs}") - try: - response = requests.request(method=method, url=url, headers=headers, **kwargs) - response.raise_for_status() - except requests.exceptions.HTTPError as err: - logger.error( - f"The {method} request to {url} failed ({err.response.status_code}): " - f"{err.response.text}" - ) - sys.exit(1) - except Exception as err: - logger.error(f"The {method} request to {url} failed with {err}") - sys.exit(1) - - return response