Skip to content

Commit

Permalink
Change/dataset/tag (#202)
Browse files Browse the repository at this point in the history
* Extract methods that create errors

* Update migration test to make fewer errors persist to database

A lot of elastic search errors seem to be occurring and I am not
sure why.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Avoid infinite recursion if obj is empty sequence

* Add note of change to tags endpoint

* Add type hints

* Remove infinite recursion from an indexed str being a str again

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PGijsbers and pre-commit-ci[bot] authored Oct 9, 2024
1 parent e23f74b commit 5065025
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 26 deletions.
13 changes: 13 additions & 0 deletions docs/migration.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,19 @@ includes datasets which are private.
The `limit` and `offset` parameters can now be used independently, you no longer need
to provide both if you wish to set only one.

### `POST /datasets/tag`
When successful, the "tag" property in the returned response is now always a list, even if only one tag exists for the entity.
For example, after tagging dataset 21 with the tag `"foo"`:
```diff
{
"data_tag": {
"id": "21",
- "tag": "foo"
+ "tag": ["foo"]
}
}
```

## Studies

### `GET /{id_or_alias}`
Expand Down
10 changes: 8 additions & 2 deletions src/core/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,20 @@ def _str_to_num(string: str) -> int | float | str:
def nested_str_to_num(obj: Any) -> Any:
    """Recursively tries to convert all strings in the object to numbers.

    For dictionaries, only the values will be converted. Any other iterable
    is returned as a list of its converted elements, and objects that are
    neither strings nor iterables are returned unchanged.
    """
    # Strings must be handled before the generic Iterable check: a str is
    # itself iterable and every element of it is again a str, so recursing
    # into it would never terminate.
    if isinstance(obj, str):
        return _str_to_num(obj)
    if isinstance(obj, Mapping):
        return {key: nested_str_to_num(val) for key, val in obj.items()}
    if isinstance(obj, Iterable):
        return [nested_str_to_num(val) for val in obj]
    return obj


def nested_num_to_str(obj: Any) -> Any:
"""Recursively tries to convert all numbers in the object to strings.
For dictionaries, only the values will be converted."""
if isinstance(obj, str):
return obj
if isinstance(obj, Mapping):
return {key: nested_num_to_str(val) for key, val in obj.items()}
if isinstance(obj, Iterable):
Expand All @@ -37,6 +39,8 @@ def nested_num_to_str(obj: Any) -> Any:


def nested_remove_nones(obj: Any) -> Any:
if isinstance(obj, str):
return obj
if isinstance(obj, Mapping):
return {
key: nested_remove_nones(val)
Expand All @@ -49,6 +53,8 @@ def nested_remove_nones(obj: Any) -> Any:


def nested_remove_single_element_list(obj: Any) -> Any:
if isinstance(obj, str):
return obj
if isinstance(obj, Mapping):
return {key: nested_remove_single_element_list(val) for key, val in obj.items()}
if isinstance(obj, Sequence):
Expand Down
38 changes: 22 additions & 16 deletions src/routers/openml/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,35 @@ def tag_dataset(
) -> dict[str, dict[str, Any]]:
tags = database.datasets.get_tags_for(data_id, expdb_db)
if tag.casefold() in [t.casefold() for t in tags]:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail={
"code": "473",
"message": "Entity already tagged by this tag.",
"additional_information": f"id={data_id}; tag={tag}",
},
)
raise create_tag_exists_error(data_id, tag)

if user is None:
raise HTTPException(
status_code=HTTPStatus.PRECONDITION_FAILED,
detail={"code": "103", "message": "Authentication failed"},
) from None
database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db)
all_tags = [*tags, tag]
tag_value = all_tags if len(all_tags) > 1 else all_tags[0]
raise create_authentication_failed_error()

database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db)
return {
"data_tag": {"id": str(data_id), "tag": tag_value},
"data_tag": {"id": str(data_id), "tag": [*tags, tag]},
}


def create_authentication_failed_error() -> HTTPException:
    """Build the error reported when a request has no authenticated user.

    The exception is returned rather than raised, so the caller decides
    when to raise it.
    """
    detail = {"code": "103", "message": "Authentication failed"}
    return HTTPException(status_code=HTTPStatus.PRECONDITION_FAILED, detail=detail)


def create_tag_exists_error(data_id: int, tag: str) -> HTTPException:
    """Build the error reported when an entity already carries the given tag.

    The exception is returned rather than raised, so the caller decides
    when to raise it. The offending id and tag are included in the
    ``additional_information`` field of the error detail.
    """
    error_detail = {
        "code": "473",
        "message": "Entity already tagged by this tag.",
        "additional_information": f"id={data_id}; tag={tag}",
    }
    return HTTPException(
        status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
        detail=error_detail,
    )


class DatasetStatusFilter(StrEnum):
ACTIVE = DatasetStatus.ACTIVE
DEACTIVATED = DatasetStatus.DEACTIVATED
Expand Down
2 changes: 1 addition & 1 deletion tests/routers/openml/dataset_tag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_dataset_tag(key: ApiKey, expdb_test: Connection, py_api: TestClient) ->
json={"data_id": dataset_id, "tag": tag},
)
assert response.status_code == HTTPStatus.OK
assert response.json() == {"data_tag": {"id": str(dataset_id), "tag": tag}}
assert response.json() == {"data_tag": {"id": str(dataset_id), "tag": [tag]}}

tags = get_tags_for(id_=dataset_id, connection=expdb_test)
assert tag in tags
Expand Down
21 changes: 14 additions & 7 deletions tests/routers/openml/migration/datasets_migration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from starlette.testclient import TestClient

from core.conversions import nested_remove_single_element_list
from tests.conftest import ApiKey


Expand Down Expand Up @@ -137,7 +138,7 @@ def test_private_dataset_admin_access(py_api: TestClient) -> None:

@pytest.mark.parametrize(
"dataset_id",
[*range(1, 10), 101],
[*range(1, 10), 101, 131],
)
@pytest.mark.parametrize(
"api_key",
Expand All @@ -160,17 +161,22 @@ def test_dataset_tag_response_is_identical(
"/data/tag",
data={"api_key": api_key, "tag": tag, "data_id": dataset_id},
)
if (
original.status_code == HTTPStatus.PRECONDITION_FAILED
and original.json()["error"]["message"] == "An Elastic Search Exception occurred."
):
pytest.skip("Encountered Elastic Search error.")
if original.status_code == HTTPStatus.OK:
already_tagged = (
original.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
and "already tagged" in original.json()["error"]["message"]
)
if not already_tagged:
# undo the tag, because we don't want to persist this change to the database
# Sometimes a change is already committed to the database even if an error occurs.
php_api.post(
"/data/untag",
data={"api_key": api_key, "tag": tag, "data_id": dataset_id},
)
if (
original.status_code != HTTPStatus.OK
and original.json()["error"]["message"] == "An Elastic Search Exception occured."
):
pytest.skip("Encountered Elastic Search error.")
new = py_api.post(
f"/datasets/tag?api_key={api_key}",
json={"data_id": dataset_id, "tag": tag},
Expand All @@ -183,6 +189,7 @@ def test_dataset_tag_response_is_identical(

original = original.json()
new = new.json()
new = nested_remove_single_element_list(new)
assert original == new


Expand Down

0 comments on commit 5065025

Please sign in to comment.