Skip to content

Commit

Permalink
Maint/tests (#204)
Browse files Browse the repository at this point in the history
* Remove need for specific line of pytest warning

* Add documentation on usage of fixtures and mocks

* Add verified user replacements

* Explicitly assert user is returned for type checker

* Extract users to separate module

* Make indentation consistent

* Assert data is returned

* Use constant instead of fixture
  • Loading branch information
PGijsbers authored Oct 14, 2024
1 parent dd9682c commit 23c5df2
Show file tree
Hide file tree
Showing 10 changed files with 176 additions and 37 deletions.
98 changes: 98 additions & 0 deletions docs/contributing/tests.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Writing Tests

tl;dr:
- Setting up the `py_api` fixture to test directly against a REST API endpoint is really slow, only use it for migration/integration tests.
- Getting a database fixture and doing a database call is slow, consider mocking if appropriate.

## Overhead from Fixtures
Sometimes, you want to interact with the REST API through the `py_api` fixture,
or want access to a database with `user_test` or `expdb_test` fixtures.
Be warned that these come with considerable relative overhead, which adds up when running thousands of tests.

```python
@pytest.mark.parametrize('execution_number', range(5000))
def test_private_dataset_owner_access(
execution_number,
expdb_test: Connection,
user_test: Connection,
py_api: TestClient,
) -> None:
fetch_user(ApiKey.REGULAR_USER, user_test) # accesses only the user db
get_estimation_procedures(expdb_test) # accesses only the experiment db
py_api.get("/does/not/exist") # only queries the api
pass
```

When individually adding/removing components, we measure (for 5000 repeats, n=1):

| expdb | user | api | exp call | user call | api get | time (s) |
|-------|------|-----|----------|-----------|---------|----------:|
||||||| 1.78 |
||||||| 3.45 |
||||||| 3.22 |
||||||| 298.48 |
||||||| 4.44 |
||||||| 285.69 |
||||||| 4.91 |
||||||| 5.81 |
||||||| 307.91 |

Adding a fixture that just returns some value adds only minimal overhead (1.91s),
so the burden comes from establishing the database connection itself.

We make the following observations:

- Adding a database fixture adds the same overhead as instantiating an entirely new test.
- Overhead of adding multiple database fixtures is not free
- The `py_api` fixture adds two orders of magnitude more overhead

We want our tests to be fast, so we want to avoid using these fixtures when we reasonably can.
We restrict usage of `py_api` fixtures to integration/migration tests, since it is very slow.
These only run on CI before merges.
For database fixtures

We will write some fixtures that can be used to e.g., get a `User` without accessing the database.
The validity of these users will be tested against the database in only a single test.

### Mocking
Mocking can help us reduce the reliance on database connections in tests.
A mocked function can prevent accessing the database, and instead return a predefined value instead.

It has a few upsides:
- It's faster than using a database fixture (see below).
- The test is not dependent on the database: you can run the test without a database.

But it also has downsides:
- Behavior changes in the database, such as schema changes, are not automatically reflected in the tests.
- The database layer (e.g., queries) are not actually tested.

Basically, the mocked behavior may not match real behavior when executed on a database.
For this reason, for each mocked entity, we should add a test that verifies that if the database layer
is invoked with the database, it returns the expected output that matches the mock.
This is additional overhead in development, but hopefully it pays back in more granular test feedback and faster tests.

On the speed of mocks, consider these two tests:

```diff
@pytest.mark.parametrize('execution_number', range(5000))
def test_private_dataset_owner_access(
execution_number,
admin,
+ mocker,
- expdb_test: Connection,
) -> None:
+ mock = mocker.patch('database.datasets.get')
+ class Dataset(NamedTuple):
+ uploader: int
+ visibility: Visibility
+ mock.return_value = Dataset(uploader=1, visibility=Visibility.PRIVATE)

_get_dataset_raise_otherwise(
dataset_id=1,
user=admin,
- expdb=expdb_test,
+ expdb=None,
)
```
There is only a single database call in the test. It fetches a record on an indexed field and does not require any joins.
Despite the database call being very light, the database-included test is ~50% slower than the mocked version (3.50s vs 5.04s).
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ nav:
- Contributing:
- contributing/index.md
- Development: contributing/contributing.md
- Tests: contributing/tests.md
- Documentation: contributing/documentation.md
- Project Overview: contributing/project_overview.md
- Changes: migration.md
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,5 @@ markers = [
"mut: executes a mutation on the database (in a transaction which is rolled back)",
]
filterwarnings = [
'ignore:A private pytest class or function was used.:DeprecationWarning:tests.conftest:119',
'ignore:A private pytest class or function was used.:DeprecationWarning:tests.conftest:',
]
8 changes: 0 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import contextlib
import json
from collections.abc import Iterator
from enum import StrEnum
from pathlib import Path
from typing import Any, NamedTuple

Expand All @@ -18,13 +17,6 @@
from routers.dependencies import expdb_connection, userdb_connection


class ApiKey(StrEnum):
ADMIN: str = "AD000000000000000000000000000000"
REGULAR_USER: str = "00000000000000000000000000000000"
OWNER_USER: str = "DA1A0000000000000000000000000000"
INVALID: str = "11111111111111111111111111111111"


@contextlib.contextmanager
def automatic_rollback(engine: Engine) -> Iterator[Connection]:
with engine.connect() as connection:
Expand Down
2 changes: 1 addition & 1 deletion tests/routers/openml/dataset_tag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from database.datasets import get_tags_for
from tests import constants
from tests.conftest import ApiKey
from tests.users import ApiKey


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/routers/openml/datasets_list_datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from starlette.testclient import TestClient

from tests import constants
from tests.conftest import ApiKey
from tests.users import ApiKey


def _assert_empty_result(
Expand Down
56 changes: 31 additions & 25 deletions tests/routers/openml/datasets_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from http import HTTPStatus
from typing import Any

import pytest
from fastapi import HTTPException
from sqlalchemy import Connection
from starlette.testclient import TestClient

from schemas.datasets.openml import DatasetStatus
from tests.conftest import ApiKey
from database.users import User
from routers.openml.datasets import get_dataset
from schemas.datasets.openml import DatasetMetadata, DatasetStatus
from tests.users import NO_USER, OWNER_USER, SOME_USER, ApiKey


@pytest.mark.parametrize(
Expand Down Expand Up @@ -66,32 +69,35 @@ def test_get_dataset(py_api: TestClient) -> None:


@pytest.mark.parametrize(
("api_key", "response_code"),
"user",
[
(None, HTTPStatus.FORBIDDEN),
("a" * 32, HTTPStatus.FORBIDDEN),
NO_USER,
SOME_USER,
],
)
def test_private_dataset_no_user_no_access(
py_api: TestClient,
api_key: str | None,
response_code: int,
def test_private_dataset_no_owner_no_access(
user: User | None,
expdb_test: Connection,
) -> None:
query = f"?api_key={api_key}" if api_key else ""
response = py_api.get(f"/datasets/130{query}")

assert response.status_code == response_code
assert response.json()["detail"] == {"code": "112", "message": "No access granted"}


@pytest.mark.skip("Not sure how to include apikey in test yet.")
def test_private_dataset_owner_access(
py_api: TestClient,
dataset_130: dict[str, Any],
) -> None:
response = py_api.get("/v2/datasets/130?api_key=...")
assert response.status_code == HTTPStatus.OK
assert dataset_130 == response.json()
with pytest.raises(HTTPException) as e:
get_dataset(
dataset_id=130,
user=user,
user_db=None,
expdb_db=expdb_test,
)
assert e.value.status_code == HTTPStatus.FORBIDDEN
assert e.value.detail == {"code": "112", "message": "No access granted"} # type: ignore[comparison-overlap]


def test_private_dataset_owner_access(expdb_test: Connection, user_test: Connection) -> None:
dataset = get_dataset(
dataset_id=130,
user=OWNER_USER,
user_db=user_test,
expdb_db=expdb_test,
)
assert isinstance(dataset, DatasetMetadata)


@pytest.mark.skip("Not sure how to include apikey in test yet.")
Expand Down
2 changes: 1 addition & 1 deletion tests/routers/openml/migration/datasets_migration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from starlette.testclient import TestClient

from core.conversions import nested_remove_single_element_list
from tests.conftest import ApiKey
from tests.users import ApiKey


@pytest.mark.parametrize(
Expand Down
27 changes: 27 additions & 0 deletions tests/routers/openml/users_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
from sqlalchemy import Connection

from database.users import User
from routers.dependencies import fetch_user
from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey


@pytest.mark.parametrize(
("api_key", "user"),
[
(ApiKey.ADMIN, ADMIN_USER),
(ApiKey.OWNER_USER, OWNER_USER),
(ApiKey.REGULAR_USER, SOME_USER),
],
)
def test_fetch_user(api_key: str, user: User, user_test: Connection) -> None:
db_user = fetch_user(api_key, user_data=user_test)
assert db_user is not None
assert user.user_id == db_user.user_id
assert user.groups == db_user.groups


def test_fetch_user_invalid_key_returns_none(user_test: Connection) -> None:
assert fetch_user(api_key=None, user_data=user_test) is None
invalid_key = "f" * 32
assert fetch_user(api_key=invalid_key, user_data=user_test) is None
15 changes: 15 additions & 0 deletions tests/users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from enum import StrEnum

from database.users import User, UserGroup

NO_USER = None
SOME_USER = User(user_id=2, _database=None, _groups=[UserGroup.READ_WRITE])
OWNER_USER = User(user_id=16, _database=None, _groups=[UserGroup.READ_WRITE])
ADMIN_USER = User(user_id=1, _database=None, _groups=[UserGroup.ADMIN, UserGroup.READ_WRITE])


class ApiKey(StrEnum):
ADMIN: str = "AD000000000000000000000000000000"
REGULAR_USER: str = "00000000000000000000000000000000"
OWNER_USER: str = "DA1A0000000000000000000000000000"
INVALID: str = "11111111111111111111111111111111"

0 comments on commit 23c5df2

Please sign in to comment.