From 50650254738f0b52bf14f636668ff18f04a32f1f Mon Sep 17 00:00:00 2001
From: Pieter Gijsbers
Date: Wed, 9 Oct 2024 21:32:05 +0200
Subject: [PATCH] Change/dataset/tag (#202)
* Extract methods that create errors
* Update migration test to make fewer errors persist to the database
  A lot of Elastic Search errors seem to be occurring and I am not
  sure why.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Avoid infinite recursion if obj is an empty sequence
* Add note about the change to the tag endpoint
* Add type hints
* Remove infinite recursion caused by an indexed str still being a str (see the sketch below)
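
A minimal sketch of why the `str` check now comes before the `Iterable` check
(a `str` is itself an `Iterable` of one-character `str`s, so recursing into it
never reaches a base case). The conversion below is simplified with `isdigit`;
the real helpers use `_str_to_num`:

```python
from collections.abc import Iterable, Mapping
from typing import Any


def nested_str_to_num(obj: Any) -> Any:
    # Handle str first: without this check, "a" iterates to ["a"] and recurses forever.
    if isinstance(obj, str):
        return int(obj) if obj.isdigit() else obj
    if isinstance(obj, Mapping):
        return {key: nested_str_to_num(val) for key, val in obj.items()}
    if isinstance(obj, Iterable):
        return [nested_str_to_num(val) for val in obj]
    return obj


assert nested_str_to_num({"id": "21", "tag": ["foo"]}) == {"id": 21, "tag": ["foo"]}
```
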
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
docs/migration.md | 13 +++++++
src/core/conversions.py | 10 ++++-
src/routers/openml/datasets.py | 38 +++++++++++--------
tests/routers/openml/dataset_tag_test.py | 2 +-
.../migration/datasets_migration_test.py | 21 ++++++----
5 files changed, 58 insertions(+), 26 deletions(-)
diff --git a/docs/migration.md b/docs/migration.md
index b1a2dc0..78f8761 100644
--- a/docs/migration.md
+++ b/docs/migration.md
@@ -91,6 +91,19 @@ includes datasets which are private.
The `limit` and `offset` parameters can now be used independently; you no longer need
to provide both if you wish to set only one.
+### `POST /datasets/tag`
+When successful, the "tag" property in the returned response is now always a list, even if only one tag exists for the entity.
+For example, after tagging dataset 21 with the tag `"foo"`:
+```diff
+{
+  "data_tag": {
+    "id": "21",
+-   "tag": "foo"
++   "tag": ["foo"]
+  }
+}
+```
+
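+Clients that special-cased the old single-string form can normalize both shapes
+during the transition; a minimal sketch (the helper name is illustrative, not part
+of this API):
+
+```python
+def tags_from_response(body: dict) -> list[str]:
+    """Return the tags as a list, whether the server sent a str or a list."""
+    tag = body["data_tag"]["tag"]
+    return [tag] if isinstance(tag, str) else list(tag)
+
+
+assert tags_from_response({"data_tag": {"id": "21", "tag": "foo"}}) == ["foo"]
+assert tags_from_response({"data_tag": {"id": "21", "tag": ["foo"]}}) == ["foo"]
+```
+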
## Studies
### `GET /{id_or_alias}`
diff --git a/src/core/conversions.py b/src/core/conversions.py
index 7c0d7fd..1e1fbe1 100644
--- a/src/core/conversions.py
+++ b/src/core/conversions.py
@@ -15,18 +15,20 @@ def _str_to_num(string: str) -> int | float | str:
def nested_str_to_num(obj: Any) -> Any:
"""Recursively tries to convert all strings in the object to numbers.
For dictionaries, only the values will be converted."""
+ if isinstance(obj, str):
+ return _str_to_num(obj)
if isinstance(obj, Mapping):
return {key: nested_str_to_num(val) for key, val in obj.items()}
if isinstance(obj, Iterable):
return [nested_str_to_num(val) for val in obj]
- if isinstance(obj, str):
- return _str_to_num(obj)
return obj
def nested_num_to_str(obj: Any) -> Any:
"""Recursively tries to convert all numbers in the object to strings.
For dictionaries, only the values will be converted."""
+ if isinstance(obj, str):
+ return obj
if isinstance(obj, Mapping):
return {key: nested_num_to_str(val) for key, val in obj.items()}
if isinstance(obj, Iterable):
@@ -37,6 +39,8 @@ def nested_num_to_str(obj: Any) -> Any:
def nested_remove_nones(obj: Any) -> Any:
+ if isinstance(obj, str):
+ return obj
if isinstance(obj, Mapping):
return {
key: nested_remove_nones(val)
@@ -49,6 +53,8 @@ def nested_remove_nones(obj: Any) -> Any:
def nested_remove_single_element_list(obj: Any) -> Any:
+ if isinstance(obj, str):
+ return obj
if isinstance(obj, Mapping):
return {key: nested_remove_single_element_list(val) for key, val in obj.items()}
if isinstance(obj, Sequence):
diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py
index b22e920..dda2511 100644
--- a/src/routers/openml/datasets.py
+++ b/src/routers/openml/datasets.py
@@ -37,29 +37,35 @@ def tag_dataset(
) -> dict[str, dict[str, Any]]:
tags = database.datasets.get_tags_for(data_id, expdb_db)
if tag.casefold() in [t.casefold() for t in tags]:
- raise HTTPException(
- status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
- detail={
- "code": "473",
- "message": "Entity already tagged by this tag.",
- "additional_information": f"id={data_id}; tag={tag}",
- },
- )
+ raise create_tag_exists_error(data_id, tag)
if user is None:
- raise HTTPException(
- status_code=HTTPStatus.PRECONDITION_FAILED,
- detail={"code": "103", "message": "Authentication failed"},
- ) from None
- database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db)
- all_tags = [*tags, tag]
- tag_value = all_tags if len(all_tags) > 1 else all_tags[0]
+ raise create_authentication_failed_error()
+ database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db)
return {
- "data_tag": {"id": str(data_id), "tag": tag_value},
+ "data_tag": {"id": str(data_id), "tag": [*tags, tag]},
}
+def create_authentication_failed_error() -> HTTPException:
+ return HTTPException(
+ status_code=HTTPStatus.PRECONDITION_FAILED,
+ detail={"code": "103", "message": "Authentication failed"},
+ )
+
+
+def create_tag_exists_error(data_id: int, tag: str) -> HTTPException:
+ return HTTPException(
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+ detail={
+ "code": "473",
+ "message": "Entity already tagged by this tag.",
+ "additional_information": f"id={data_id}; tag={tag}",
+ },
+ )
+
+
class DatasetStatusFilter(StrEnum):
ACTIVE = DatasetStatus.ACTIVE
DEACTIVATED = DatasetStatus.DEACTIVATED
diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py
index 8d4e1da..c23aa3a 100644
--- a/tests/routers/openml/dataset_tag_test.py
+++ b/tests/routers/openml/dataset_tag_test.py
@@ -36,7 +36,7 @@ def test_dataset_tag(key: ApiKey, expdb_test: Connection, py_api: TestClient) ->
json={"data_id": dataset_id, "tag": tag},
)
assert response.status_code == HTTPStatus.OK
- assert response.json() == {"data_tag": {"id": str(dataset_id), "tag": tag}}
+ assert response.json() == {"data_tag": {"id": str(dataset_id), "tag": [tag]}}
tags = get_tags_for(id_=dataset_id, connection=expdb_test)
assert tag in tags
diff --git a/tests/routers/openml/migration/datasets_migration_test.py b/tests/routers/openml/migration/datasets_migration_test.py
index bf5224f..5a67105 100644
--- a/tests/routers/openml/migration/datasets_migration_test.py
+++ b/tests/routers/openml/migration/datasets_migration_test.py
@@ -6,6 +6,7 @@
import pytest
from starlette.testclient import TestClient
+from core.conversions import nested_remove_single_element_list
from tests.conftest import ApiKey
@@ -137,7 +138,7 @@ def test_private_dataset_admin_access(py_api: TestClient) -> None:
@pytest.mark.parametrize(
"dataset_id",
- [*range(1, 10), 101],
+ [*range(1, 10), 101, 131],
)
@pytest.mark.parametrize(
"api_key",
@@ -160,17 +161,22 @@ def test_dataset_tag_response_is_identical(
"/data/tag",
data={"api_key": api_key, "tag": tag, "data_id": dataset_id},
)
- if (
- original.status_code == HTTPStatus.PRECONDITION_FAILED
- and original.json()["error"]["message"] == "An Elastic Search Exception occurred."
- ):
- pytest.skip("Encountered Elastic Search error.")
- if original.status_code == HTTPStatus.OK:
+ already_tagged = (
+ original.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
+ and "already tagged" in original.json()["error"]["message"]
+ )
+ if not already_tagged:
# undo the tag, because we don't want to persist this change to the database
+ # Sometimes a change is already committed to the database even if an error occurs.
php_api.post(
"/data/untag",
data={"api_key": api_key, "tag": tag, "data_id": dataset_id},
)
+ if (
+ original.status_code != HTTPStatus.OK
+ and original.json()["error"]["message"] == "An Elastic Search Exception occured."
+ ):
+ pytest.skip("Encountered Elastic Search error.")
new = py_api.post(
f"/datasets/tag?api_key={api_key}",
json={"data_id": dataset_id, "tag": tag},
@@ -183,6 +189,7 @@ def test_dataset_tag_response_is_identical(
original = original.json()
new = new.json()
+ new = nested_remove_single_element_list(new)
assert original == new