Skip to content

Commit

Permalink
Sbachmei/mic 5549/mypy results context (#538)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevebachmeier authored Nov 13, 2024
1 parent 2a81a36 commit 881d6b8
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 55 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**3.2.1 - TBD/TBD/TBD**

- Fix mypy errors in vivarium/framework/results/context.py

**3.2.0 - 11/12/24**

- Feature: Supports passing callables directly when building lookup tables
Expand Down
26 changes: 13 additions & 13 deletions docs/source/concepts/results.rst
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ A couple other more specific and commonly used observations are provided as well
that gathers new results and concatenates them to any existing results.

Ideally, all concrete classes should inherit from the
:class:`BaseObservation <vivarium.framework.results.observation.BaseObservation>`
:class:`Observation <vivarium.framework.results.observation.Observation>`
abstract base class, which contains the common attributes between observation types:

.. list-table:: **Common Observation Attributes**
Expand All @@ -312,40 +312,40 @@ abstract base class, which contains the common attributes between observation ty

* - Attribute
- Description
* - | :attr:`name <vivarium.framework.results.observation.BaseObservation.name>`
* - | :attr:`name <vivarium.framework.results.observation.Observation.name>`
- | Name of the observation. It will also be the name of the output results file
| for this particular observation.
* - | :attr:`pop_filter <vivarium.framework.results.observation.BaseObservation.pop_filter>`
* - | :attr:`pop_filter <vivarium.framework.results.observation.Observation.pop_filter>`
- | A Pandas query filter string to filter the population down to the simulants
| who should be considered for the observation.
* - | :attr:`when <vivarium.framework.results.observation.BaseObservation.when>`
* - | :attr:`when <vivarium.framework.results.observation.Observation.when>`
- | Name of the lifecycle phase the observation should happen. Valid values are:
| "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics".
* - | :attr:`results_initializer <vivarium.framework.results.observation.BaseObservation.results_initializer>`
* - | :attr:`results_initializer <vivarium.framework.results.observation.Observation.results_initializer>`
- | Method or function that initializes the raw observation results
| prior to starting the simulation. This could return, for example, an empty
| DataFrame or one with a complete set of stratifications as the index and
| all values set to 0.0.
* - | :attr:`results_gatherer <vivarium.framework.results.observation.BaseObservation.results_gatherer>`
* - | :attr:`results_gatherer <vivarium.framework.results.observation.Observation.results_gatherer>`
- | Method or function that gathers the new observation results.
* - | :attr:`results_updater <vivarium.framework.results.observation.BaseObservation.results_updater>`
* - | :attr:`results_updater <vivarium.framework.results.observation.Observation.results_updater>`
- | Method or function that updates existing raw observation results with newly
| gathered results.
* - | :attr:`results_formatter <vivarium.framework.results.observation.BaseObservation.results_formatter>`
* - | :attr:`results_formatter <vivarium.framework.results.observation.Observation.results_formatter>`
- | Method or function that formats the raw observation results.
* - | :attr:`stratifications <vivarium.framework.results.observation.BaseObservation.stratifications>`
* - | :attr:`stratifications <vivarium.framework.results.observation.Observation.stratifications>`
- | Optional tuple of column names for the observation to stratify by.
* - | :attr:`to_observe <vivarium.framework.results.observation.BaseObservation.to_observe>`
* - | :attr:`to_observe <vivarium.framework.results.observation.Observation.to_observe>`
- | Method or function that determines whether to perform an observation on this Event.

The **BaseObservation** also contains the
:meth:`observe <vivarium.framework.results.observation.BaseObservation.observe>`
The **Observation** also contains the
:meth:`observe <vivarium.framework.results.observation.Observation.observe>`
method which is called at each :ref:`event <event_concept>` and :ref:`time step <time_concept>`
to determine whether or not the observation should be recorded, and if so, gathers
the results and stores them in the results system.

.. note::
All four observation types discussed above inherit from the **BaseObservation**
All four observation types discussed above inherit from the **Observation**
abstract base class. What differentiates them are the assigned attributes
(e.g. defining the **results_updater** to be an adding method for the
**AddingObservation**) or adding other attributes as necessary (e.g.
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ exclude = [
'src/vivarium/framework/lookup/manager.py',
'src/vivarium/framework/population/manager.py',
'src/vivarium/framework/population/population_view.py',
'src/vivarium/framework/results/context.py',
'src/vivarium/framework/results/interface.py',
'src/vivarium/framework/results/manager.py',
'src/vivarium/framework/results/observer.py',
Expand Down
54 changes: 27 additions & 27 deletions src/vivarium/framework/results/context.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,29 @@
# mypy: ignore-errors
"""
===============
Results Context
===============
"""

from __future__ import annotations

from collections import defaultdict
from typing import Callable, Generator, List, Optional, Tuple, Type, Union
from collections.abc import Callable, Generator
from typing import Any

import pandas as pd
from pandas.core.groupby.generic import DataFrameGroupBy

from vivarium.framework.engine import Builder
from vivarium.framework.event import Event
from vivarium.framework.results.exceptions import ResultsConfigurationError
from vivarium.framework.results.observation import BaseObservation
from vivarium.framework.results.observation import Observation
from vivarium.framework.results.stratification import (
Stratification,
get_mapped_col_name,
get_original_col_name,
)
from vivarium.types import ScalarValue
from vivarium.types import ScalarMapper, VectorMapper


class ResultsContext:
Expand Down Expand Up @@ -52,10 +54,12 @@ class ResultsContext:
"""

def __init__(self) -> None:
self.default_stratifications: List[str] = []
self.stratifications: List[Stratification] = []
self.default_stratifications: list[str] = []
self.stratifications: list[Stratification] = []
self.excluded_categories: dict[str, list[str]] = {}
self.observations: defaultdict = defaultdict(lambda: defaultdict(list))
self.observations: defaultdict[
str, defaultdict[tuple[str, tuple[str, ...] | None], list[Observation]]
] = defaultdict(lambda: defaultdict(list))

@property
def name(self) -> str:
Expand All @@ -73,7 +77,7 @@ def setup(self, builder: Builder) -> None:
)

# noinspection PyAttributeOutsideInit
def set_default_stratifications(self, default_grouping_columns: List[str]) -> None:
def set_default_stratifications(self, default_grouping_columns: list[str]) -> None:
"""Set the default stratifications to be used by stratified observations.
Parameters
Expand All @@ -96,15 +100,10 @@ def set_default_stratifications(self, default_grouping_columns: List[str]) -> No
def add_stratification(
self,
name: str,
sources: List[str],
categories: List[str],
excluded_categories: Optional[List[str]],
mapper: Optional[
Union[
Callable[[Union[pd.Series, pd.DataFrame]], pd.Series],
Callable[[ScalarValue], str],
]
],
sources: list[str],
categories: list[str],
excluded_categories: list[str] | None,
mapper: VectorMapper | ScalarMapper | None,
is_vectorized: bool,
) -> None:
"""Add a stratification to the results context.
Expand Down Expand Up @@ -187,11 +186,11 @@ def add_stratification(

def register_observation(
self,
observation_type: Type[BaseObservation],
observation_type: type[Observation],
name: str,
pop_filter: str,
when: str,
**kwargs,
**kwargs: Any,
) -> None:
"""Add an observation to the results context.
Expand Down Expand Up @@ -242,10 +241,10 @@ def register_observation(
def gather_results(
self, population: pd.DataFrame, lifecycle_phase: str, event: Event
) -> Generator[
Tuple[
Optional[pd.DataFrame],
Optional[str],
Optional[Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame]],
tuple[
pd.DataFrame | None,
str | None,
Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame] | None,
],
None,
None,
Expand Down Expand Up @@ -302,6 +301,7 @@ def gather_results(
if filtered_pop.empty:
yield None, None, None
else:
pop: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str]
if stratification_names is None:
pop = filtered_pop
else:
Expand All @@ -317,7 +317,7 @@ def _filter_population(
self,
population: pd.DataFrame,
pop_filter: str,
stratification_names: Optional[tuple[str, ...]],
stratification_names: tuple[str, ...] | None,
) -> pd.DataFrame:
"""Filter out simulants not to observe."""
pop = population.query(pop_filter) if pop_filter else population.copy()
Expand All @@ -334,8 +334,8 @@ def _filter_population(

@staticmethod
def _get_groups(
stratifications: Tuple[str, ...], filtered_pop: pd.DataFrame
) -> DataFrameGroupBy:
stratifications: tuple[str, ...], filtered_pop: pd.DataFrame
) -> DataFrameGroupBy[tuple[str, ...] | str]:
"""Group the population by stratification.
Notes
Expand All @@ -356,7 +356,7 @@ def _get_groups(
)
else:
pop_groups = filtered_pop.groupby(lambda _: "all")
return pop_groups
return pop_groups # type: ignore[return-value]

def _rename_stratification_columns(self, results: pd.DataFrame) -> None:
"""Convert the temporary stratified mapped index names back to their original names."""
Expand Down
5 changes: 3 additions & 2 deletions src/vivarium/framework/results/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@

from collections import defaultdict
from enum import Enum
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, Union

import pandas as pd

from vivarium.framework.event import Event
from vivarium.framework.results.context import ResultsContext
from vivarium.framework.results.observation import Observation
from vivarium.framework.values import Pipeline
from vivarium.manager import Manager
from vivarium.types import ScalarValue
Expand Down Expand Up @@ -301,7 +302,7 @@ def _bin_data(data: Union[pd.Series, pd.DataFrame]) -> pd.Series:

def register_observation(
self,
observation_type,
observation_type: Type[Observation],
is_stratified: bool,
name: str,
pop_filter: str,
Expand Down
17 changes: 9 additions & 8 deletions src/vivarium/framework/results/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
An observation is a class object that records simulation results; they are responsible
for initializing, gathering, updating, and formatting results.
The provided :class:`BaseObservation` class is an abstract base class that should
The provided :class:`Observation` class is an abstract base class that should
be subclassed by concrete observations. While there are no required abstract methods
to define when subclassing, the class does provide common attributes as well
as an `observe` method that determines whether to observe results for a given event.
Expand All @@ -24,7 +24,6 @@
from abc import ABC
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

import pandas as pd
from pandas.api.types import CategoricalDtype
Expand All @@ -37,7 +36,7 @@


@dataclass
class BaseObservation(ABC):
class Observation(ABC):
"""An abstract base dataclass to be inherited by concrete observations.
This class includes an :meth:`observe <observe>` method that determines whether
Expand All @@ -60,7 +59,8 @@ class BaseObservation(ABC):
DataFrame or one with a complete set of stratifications as the index and
all values set to 0.0."""
results_gatherer: Callable[
[pd.DataFrame | DataFrameGroupBy[str], tuple[str, ...] | None], pd.DataFrame
[pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str], tuple[str, ...] | None],
pd.DataFrame,
]
"""Method or function that gathers the new observation results."""
results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame]
Expand All @@ -76,7 +76,7 @@ class BaseObservation(ABC):
def observe(
self,
event: Event,
df: pd.DataFrame | DataFrameGroupBy[str],
df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str],
stratifications: tuple[str, ...] | None,
) -> pd.DataFrame | None:
"""Determine whether to observe the given event, and if so, gather the results.
Expand All @@ -100,7 +100,7 @@ def observe(
return self.results_gatherer(df, stratifications)


class UnstratifiedObservation(BaseObservation):
class UnstratifiedObservation(Observation):
"""Concrete class for observing results that are not stratified.
The parent class `stratifications` are set to None and the `results_initializer`
Expand Down Expand Up @@ -139,7 +139,8 @@ def __init__(
to_observe: Callable[[Event], bool] = lambda event: True,
):
def _wrap_results_gatherer(
df: pd.DataFrame | DataFrameGroupBy[str], _: tuple[str, ...] | None
df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str],
_: tuple[str, ...] | None,
) -> pd.DataFrame:
if isinstance(df, DataFrameGroupBy):
raise TypeError(
Expand Down Expand Up @@ -181,7 +182,7 @@ def create_empty_df(
return pd.DataFrame()


class StratifiedObservation(BaseObservation):
class StratifiedObservation(Observation):
"""Concrete class for observing stratified results.
The parent class `results_initializer` and `results_gatherer` methods are
Expand Down
2 changes: 1 addition & 1 deletion src/vivarium/framework/results/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
=========
An observer is a component that is responsible for registering
:class:`observations <vivarium.framework.results.observation.BaseObservation>`
:class:`observations <vivarium.framework.results.observation.Observation>`
to the simulation.
The provided :class:`Observer` class is an abstract base class that should be subclassed
Expand Down
7 changes: 4 additions & 3 deletions src/vivarium/framework/results/stratification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@
===============
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable
from typing import Any

import pandas as pd
from pandas.api.types import CategoricalDtype

from vivarium.types import ScalarMapper, VectorMapper

STRATIFICATION_COLUMN_SUFFIX: str = "mapped_values"

# TODO: Parameterizing pandas objects fails below python 3.12
VectorMapper = Callable[[pd.DataFrame], pd.Series] # type: ignore [type-arg]
ScalarMapper = Callable[[pd.Series], str] # type: ignore [type-arg]


@dataclass
Expand Down
4 changes: 4 additions & 0 deletions src/vivarium/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Callable
from datetime import datetime, timedelta
from numbers import Number
from typing import Union
Expand Down Expand Up @@ -25,3 +26,6 @@
float,
int,
]

VectorMapper = Callable[[pd.DataFrame], pd.Series] # type: ignore [type-arg]
ScalarMapper = Callable[[pd.Series], str] # type: ignore [type-arg]

0 comments on commit 881d6b8

Please sign in to comment.