diff --git a/tests/unit_test.py b/tests/unit_test.py index 2e3c5375..57b077c2 100644 --- a/tests/unit_test.py +++ b/tests/unit_test.py @@ -1780,3 +1780,39 @@ def test_assume_role(mocker): with mocker.patch("tokendito.aws.handle_assume_role", return_value={}): with pytest.raises(SystemExit) as error: assert aws.assume_role(pytest_config, role_arn, session_name) == error + + +@pytest.mark.parametrize( + "saml, expected", + [ + ("pytest", {}), + ("pytest,pytest", {}), + ( + 'xsi:type="xs:string">arn:aws:iam::000000000000:saml/name,' + "arn:aws:iam::000000000000:role/name", + {"arn:aws:iam::000000000000:role/name": "arn:aws:iam::000000000000:saml/name"}, + ), + ], +) +def test_extract_arns(saml, expected): + """Test extracting Provider/Role ARN pairs from a SAML document.""" + from tokendito import user + + assert user.extract_arns(saml) == expected + + +def test_select_assumeable_role_no_tiles(): + """Test exiting when there are no assumable roles.""" + from tokendito import aws + + tiles = [ + ( + "https://acme.okta.org/home/amazon_aws/0123456789abcdef0123/456", + "saml_response", + "arn:aws:iam::000000000000:saml/name,arn:aws:iam::000000000000:role/name", + "Tile Label", + ) + ] + with pytest.raises(SystemExit) as err: + aws.select_assumeable_role(tiles) + assert err.value.code == 1 diff --git a/tokendito/__init__.py b/tokendito/__init__.py index 5f7cb01b..167f273b 100644 --- a/tokendito/__init__.py +++ b/tokendito/__init__.py @@ -8,7 +8,7 @@ from platformdirs import user_config_dir -__version__ = "2.1.0" +__version__ = "2.1.1" __title__ = "tokendito" __description__ = "Get AWS STS tokens from Okta SSO" __long_description_content_type__ = "text/markdown" diff --git a/tokendito/aws.py b/tokendito/aws.py index 9b03464a..81b017b3 100644 --- a/tokendito/aws.py +++ b/tokendito/aws.py @@ -204,6 +204,9 @@ def select_assumeable_role(tiles): authenticated_tiles = {} for url, saml_response, saml, label in tiles: roles_and_providers = user.extract_arns(saml) + if not roles_and_providers: + logger.warning(f"Skipping {url}, no valid roles or tile is misconfigured") + continue authenticated_tiles[url] = { "roles": list(roles_and_providers.keys()), "saml": saml, @@ -212,6 +215,10 @@ def select_assumeable_role(tiles): "label": label, } + if not authenticated_tiles: + logger.error("No roles found. Please check with your Okta admin.") + sys.exit(1) + role_arn, _id = user.select_role_arn(authenticated_tiles) role_name = role_arn.split("/")[-1] diff --git a/tokendito/user.py b/tokendito/user.py index 31a2c909..55d1a7d8 100644 --- a/tokendito/user.py +++ b/tokendito/user.py @@ -455,20 +455,18 @@ def extract_arns(saml): """ logger.debug("Decode response string as a SAML decoded value.") + roles_and_providers = {} arn_regex = ">(arn:aws:iam::.*?,arn:aws:iam::.*?)<" # find all provider and role pairs. arns = re.findall(arn_regex, saml) - - if len(arns) == 0: - logger.error("No IAM roles found in SAML response.") - logger.debug(arns) - sys.exit(2) + logger.debug(f"found ARNs: {arns}") # stuff into dict, role is dict key. - roles_and_providers = {i.split(",")[1]: i.split(",")[0] for i in arns} + if arns: + roles_and_providers = {i.split(",")[1]: i.split(",")[0] for i in arns} - logger.debug(f"Collected ARNs: {json.dumps(roles_and_providers)}") + logger.debug(f"Collected ARNs: {roles_and_providers}") return roles_and_providers