From 50650254738f0b52bf14f636668ff18f04a32f1f Mon Sep 17 00:00:00 2001
From: Pieter Gijsbers
Date: Wed, 9 Oct 2024 21:32:05 +0200
Subject: [PATCH] Change/dataset/tag (#202)

* Extract method that creates errors

* Update migration test to make fewer errors persist to database

A lot of elastic search errors seem to be occurring and I am not sure why.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Avoid infinite recursion if obj is empty sequence

* Add note of change to tags endpoint

* Add type hints

* Remove infinite recursion from an indexed str being a str again

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 docs/migration.md                            | 13 +++++++
 src/core/conversions.py                      | 10 ++++-
 src/routers/openml/datasets.py               | 38 +++++++++++--------
 tests/routers/openml/dataset_tag_test.py     |  2 +-
 .../migration/datasets_migration_test.py     | 21 ++++++----
 5 files changed, 58 insertions(+), 26 deletions(-)

diff --git a/docs/migration.md b/docs/migration.md
index b1a2dc0..78f8761 100644
--- a/docs/migration.md
+++ b/docs/migration.md
@@ -91,6 +91,19 @@ includes datasets which are private.
 The `limit` and `offset` parameters can now be used independently,
 you no longer need to provide both if you wish to set only one.
 
+### `POST /datasets/tag`
+When successful, the "tag" property in the returned response is now always a list, even if only one tag exists for the entity.
+For example, after tagging dataset 21 with the tag `"foo"`:
+```diff
+{
+  "data_tag": {
+    "id": "21",
+-   "tag": "foo"
++   "tag": ["foo"]
+  }
+}
+```
+
 ## Studies
 
 ### `GET /{id_or_alias}`
diff --git a/src/core/conversions.py b/src/core/conversions.py
index 7c0d7fd..1e1fbe1 100644
--- a/src/core/conversions.py
+++ b/src/core/conversions.py
@@ -15,18 +15,20 @@ def _str_to_num(string: str) -> int | float | str:
 def nested_str_to_num(obj: Any) -> Any:
     """Recursively tries to convert all strings in the object to numbers.
     For dictionaries, only the values will be converted."""
+    if isinstance(obj, str):
+        return _str_to_num(obj)
     if isinstance(obj, Mapping):
         return {key: nested_str_to_num(val) for key, val in obj.items()}
     if isinstance(obj, Iterable):
         return [nested_str_to_num(val) for val in obj]
-    if isinstance(obj, str):
-        return _str_to_num(obj)
     return obj
 
 
 def nested_num_to_str(obj: Any) -> Any:
     """Recursively tries to convert all numbers in the object to strings.
     For dictionaries, only the values will be converted."""
+    if isinstance(obj, str):
+        return obj
     if isinstance(obj, Mapping):
         return {key: nested_num_to_str(val) for key, val in obj.items()}
     if isinstance(obj, Iterable):
@@ -37,6 +39,8 @@ def nested_num_to_str(obj: Any) -> Any:
 
 
 def nested_remove_nones(obj: Any) -> Any:
+    if isinstance(obj, str):
+        return obj
     if isinstance(obj, Mapping):
         return {
             key: nested_remove_nones(val)
@@ -49,6 +53,8 @@ def nested_remove_nones(obj: Any) -> Any:
 
 
 def nested_remove_single_element_list(obj: Any) -> Any:
+    if isinstance(obj, str):
+        return obj
     if isinstance(obj, Mapping):
         return {key: nested_remove_single_element_list(val) for key, val in obj.items()}
     if isinstance(obj, Sequence):
diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py
index b22e920..dda2511 100644
--- a/src/routers/openml/datasets.py
+++ b/src/routers/openml/datasets.py
@@ -37,29 +37,35 @@ def tag_dataset(
 ) -> dict[str, dict[str, Any]]:
     tags = database.datasets.get_tags_for(data_id, expdb_db)
     if tag.casefold() in [t.casefold() for t in tags]:
-        raise HTTPException(
-            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
-            detail={
-                "code": "473",
-                "message": "Entity already tagged by this tag.",
-                "additional_information": f"id={data_id}; tag={tag}",
-            },
-        )
+        raise create_tag_exists_error(data_id, tag)
 
     if user is None:
-        raise HTTPException(
-            status_code=HTTPStatus.PRECONDITION_FAILED,
-            detail={"code": "103", "message": "Authentication failed"},
-        ) from None
-    database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db)
-    all_tags = [*tags, tag]
-    tag_value = all_tags if len(all_tags) > 1 else all_tags[0]
+        raise create_authentication_failed_error()
+    database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db)
 
     return {
-        "data_tag": {"id": str(data_id), "tag": tag_value},
+        "data_tag": {"id": str(data_id), "tag": [*tags, tag]},
     }
 
 
+def create_authentication_failed_error() -> HTTPException:
+    return HTTPException(
+        status_code=HTTPStatus.PRECONDITION_FAILED,
+        detail={"code": "103", "message": "Authentication failed"},
+    )
+
+
+def create_tag_exists_error(data_id: int, tag: str) -> HTTPException:
+    return HTTPException(
+        status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+        detail={
+            "code": "473",
+            "message": "Entity already tagged by this tag.",
+            "additional_information": f"id={data_id}; tag={tag}",
+        },
+    )
+
+
 class DatasetStatusFilter(StrEnum):
     ACTIVE = DatasetStatus.ACTIVE
     DEACTIVATED = DatasetStatus.DEACTIVATED
diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py
index 8d4e1da..c23aa3a 100644
--- a/tests/routers/openml/dataset_tag_test.py
+++ b/tests/routers/openml/dataset_tag_test.py
@@ -36,7 +36,7 @@ def test_dataset_tag(key: ApiKey, expdb_test: Connection, py_api: TestClient) ->
         json={"data_id": dataset_id, "tag": tag},
     )
     assert response.status_code == HTTPStatus.OK
-    assert response.json() == {"data_tag": {"id": str(dataset_id), "tag": tag}}
+    assert response.json() == {"data_tag": {"id": str(dataset_id), "tag": [tag]}}
 
     tags = get_tags_for(id_=dataset_id, connection=expdb_test)
     assert tag in tags
diff --git a/tests/routers/openml/migration/datasets_migration_test.py b/tests/routers/openml/migration/datasets_migration_test.py
index bf5224f..5a67105 100644
--- a/tests/routers/openml/migration/datasets_migration_test.py
+++ b/tests/routers/openml/migration/datasets_migration_test.py
@@ -6,6 +6,7 @@
 import pytest
 from starlette.testclient import TestClient
 
+from core.conversions import nested_remove_single_element_list
 from tests.conftest import ApiKey
 
 
@@ -137,7 +138,7 @@ def test_private_dataset_admin_access(py_api: TestClient) -> None:
 
 @pytest.mark.parametrize(
     "dataset_id",
-    [*range(1, 10), 101],
+    [*range(1, 10), 101, 131],
 )
 @pytest.mark.parametrize(
     "api_key",
@@ -160,17 +161,22 @@ def test_dataset_tag_response_is_identical(
         "/data/tag",
         data={"api_key": api_key, "tag": tag, "data_id": dataset_id},
     )
-    if (
-        original.status_code == HTTPStatus.PRECONDITION_FAILED
-        and original.json()["error"]["message"] == "An Elastic Search Exception occurred."
-    ):
-        pytest.skip("Encountered Elastic Search error.")
-    if original.status_code == HTTPStatus.OK:
+    already_tagged = (
+        original.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
+        and "already tagged" in original.json()["error"]["message"]
+    )
+    if not already_tagged:
         # undo the tag, because we don't want to persist this change to the database
+        # Sometimes a change is already committed to the database even if an error occurs.
         php_api.post(
             "/data/untag",
             data={"api_key": api_key, "tag": tag, "data_id": dataset_id},
         )
+    if (
+        original.status_code != HTTPStatus.OK
+        and original.json()["error"]["message"] == "An Elastic Search Exception occured."
+    ):
+        pytest.skip("Encountered Elastic Search error.")
     new = py_api.post(
         f"/datasets/tag?api_key={api_key}",
         json={"data_id": dataset_id, "tag": tag},
@@ -183,6 +189,7 @@ def test_dataset_tag_response_is_identical(
 
     original = original.json()
     new = new.json()
+    new = nested_remove_single_element_list(new)
     assert original == new
 
 
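Background note (not part of the patch): the `isinstance(obj, str)` guards added to the `nested_*` helpers exist because a Python `str` is itself an `Iterable`/`Sequence` of one-character strings, so the old branch order could recurse without ever reaching a base case; the original trailing `str` branch in `nested_str_to_num` was unreachable for the same reason, and `nested_remove_single_element_list` looped once indexing a single-character string yielded that same string back. Below is a minimal sketch of the two orderings; the `_old`/`_fixed` names are illustrative only, and `_str_to_num` is replaced by a simple digit check.

```python
from collections.abc import Iterable, Mapping
from typing import Any


def nested_str_to_num_old(obj: Any) -> Any:
    # Pre-patch branch order: a str is an Iterable of 1-character strs, so the
    # Iterable branch catches it and nested_str_to_num_old("a") re-enters itself
    # with "a" until RecursionError. (The original's trailing, unreachable str
    # branch is omitted here.)
    if isinstance(obj, Mapping):
        return {key: nested_str_to_num_old(val) for key, val in obj.items()}
    if isinstance(obj, Iterable):
        return [nested_str_to_num_old(val) for val in obj]
    return obj


def nested_str_to_num_fixed(obj: Any) -> Any:
    # Patched branch order: handle plain strings before any container branch.
    if isinstance(obj, str):
        return int(obj) if obj.isdigit() else obj  # stand-in for _str_to_num
    if isinstance(obj, Mapping):
        return {key: nested_str_to_num_fixed(val) for key, val in obj.items()}
    if isinstance(obj, Iterable):
        return [nested_str_to_num_fixed(val) for val in obj]
    return obj


print(nested_str_to_num_fixed({"id": "21", "tag": ["foo", "7"]}))
# {'id': 21, 'tag': ['foo', 7]}
```

With the fixed order the dictionary above converts without recursing into its strings, whereas `nested_str_to_num_old("a")` raises `RecursionError`.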
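On the response-shape change noted in the `docs/migration.md` hunk: the migration test can keep comparing the new REST response against the old PHP output because collapsing single-element lists maps one shape onto the other, which is what the added `nested_remove_single_element_list(new)` call does before the final assert. A small illustrative sketch, assuming it runs inside this repository so `core.conversions` is importable; the tag values are made up.

```python
# Same import the patch adds to datasets_migration_test.py.
from core.conversions import nested_remove_single_element_list

new_style = {"data_tag": {"id": "21", "tag": ["foo"]}}  # new REST response: "tag" is always a list
old_style = {"data_tag": {"id": "21", "tag": "foo"}}    # old PHP response: a single tag is a bare string

# Collapsing single-element lists makes the two shapes directly comparable,
# mirroring what the migration test now does before its assert.
assert nested_remove_single_element_list(new_style) == old_style
```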