From f015a565c4b66228acadb1af34538e61ee0afa7d Mon Sep 17 00:00:00 2001 From: Steve Bachmeier Date: Tue, 12 Nov 2024 16:28:25 -0800 Subject: [PATCH] Fix mypy errors in framework/results/context.py --- src/vivarium/framework/results/context.py | 23 ++++++++++--------- src/vivarium/framework/results/manager.py | 5 ++-- src/vivarium/framework/results/observation.py | 13 +++++++---- .../framework/results/stratification.py | 7 +++--- src/vivarium/types.py | 4 ++++ 5 files changed, 32 insertions(+), 20 deletions(-) diff --git a/src/vivarium/framework/results/context.py b/src/vivarium/framework/results/context.py index cabe7cc08..cee07d481 100644 --- a/src/vivarium/framework/results/context.py +++ b/src/vivarium/framework/results/context.py @@ -5,9 +5,11 @@ """ +from __future__ import annotations + from collections import defaultdict from collections.abc import Callable, Generator -from typing import Type +from typing import Any, Type import pandas as pd from pandas.core.groupby.generic import DataFrameGroupBy @@ -21,7 +23,7 @@ get_mapped_col_name, get_original_col_name, ) -from vivarium.types import ScalarValue +from vivarium.types import ScalarMapper, VectorMapper class ResultsContext: @@ -55,7 +57,9 @@ def __init__(self) -> None: self.default_stratifications: list[str] = [] self.stratifications: list[Stratification] = [] self.excluded_categories: dict[str, list[str]] = {} - self.observations: defaultdict = defaultdict(lambda: defaultdict(list)) + self.observations: defaultdict[ + str, defaultdict[tuple[str, tuple[str, ...] | None], list[BaseObservation]] + ] = defaultdict(lambda: defaultdict(list)) @property def name(self) -> str: @@ -99,11 +103,7 @@ def add_stratification( sources: list[str], categories: list[str], excluded_categories: list[str] | None, - mapper: ( - Callable[[pd.Series | pd.DataFrame], pd.Series] - | Callable[[ScalarValue], str] - | None - ), + mapper: VectorMapper | ScalarMapper | None, is_vectorized: bool, ) -> None: """Add a stratification to the results context. @@ -190,7 +190,7 @@ def register_observation( name: str, pop_filter: str, when: str, - **kwargs, + **kwargs: Any, ) -> None: """Add an observation to the results context. @@ -301,6 +301,7 @@ def gather_results( if filtered_pop.empty: yield None, None, None else: + pop: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str] if stratification_names is None: pop = filtered_pop else: @@ -334,7 +335,7 @@ def _filter_population( @staticmethod def _get_groups( stratifications: tuple[str, ...], filtered_pop: pd.DataFrame - ) -> DataFrameGroupBy: + ) -> DataFrameGroupBy[tuple[str, ...] | str]: """Group the population by stratification. Notes @@ -355,7 +356,7 @@ def _get_groups( ) else: pop_groups = filtered_pop.groupby(lambda _: "all") - return pop_groups + return pop_groups # type: ignore[return-value] def _rename_stratification_columns(self, results: pd.DataFrame) -> None: """Convert the temporary stratified mapped index names back to their original names.""" diff --git a/src/vivarium/framework/results/manager.py b/src/vivarium/framework/results/manager.py index 6c18d279d..b9db068c9 100644 --- a/src/vivarium/framework/results/manager.py +++ b/src/vivarium/framework/results/manager.py @@ -8,12 +8,13 @@ from collections import defaultdict from enum import Enum -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, Union import pandas as pd from vivarium.framework.event import Event from vivarium.framework.results.context import ResultsContext +from vivarium.framework.results.observation import BaseObservation from vivarium.framework.values import Pipeline from vivarium.manager import Manager from vivarium.types import ScalarValue @@ -301,7 +302,7 @@ def _bin_data(data: Union[pd.Series, pd.DataFrame]) -> pd.Series: def register_observation( self, - observation_type, + observation_type: Type[BaseObservation], is_stratified: bool, name: str, pop_filter: str, diff --git a/src/vivarium/framework/results/observation.py b/src/vivarium/framework/results/observation.py index f400900d4..79a99e483 100644 --- a/src/vivarium/framework/results/observation.py +++ b/src/vivarium/framework/results/observation.py @@ -24,7 +24,7 @@ from abc import ABC from collections.abc import Callable from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING import pandas as pd from pandas.api.types import CategoricalDtype @@ -35,6 +35,9 @@ VALUE_COLUMN = "value" +if TYPE_CHECKING: + _PandasGroup = pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str] + @dataclass class BaseObservation(ABC): @@ -60,7 +63,8 @@ class BaseObservation(ABC): DataFrame or one with a complete set of stratifications as the index and all values set to 0.0.""" results_gatherer: Callable[ - [pd.DataFrame | DataFrameGroupBy[str], tuple[str, ...] | None], pd.DataFrame + [_PandasGroup, tuple[str, ...] | None], + pd.DataFrame, ] """Method or function that gathers the new observation results.""" results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame] @@ -76,7 +80,7 @@ class BaseObservation(ABC): def observe( self, event: Event, - df: pd.DataFrame | DataFrameGroupBy[str], + df: _PandasGroup, stratifications: tuple[str, ...] | None, ) -> pd.DataFrame | None: """Determine whether to observe the given event, and if so, gather the results. @@ -139,7 +143,8 @@ def __init__( to_observe: Callable[[Event], bool] = lambda event: True, ): def _wrap_results_gatherer( - df: pd.DataFrame | DataFrameGroupBy[str], _: tuple[str, ...] | None + df: _PandasGroup, + _: tuple[str, ...] | None, ) -> pd.DataFrame: if isinstance(df, DataFrameGroupBy): raise TypeError( diff --git a/src/vivarium/framework/results/stratification.py b/src/vivarium/framework/results/stratification.py index 7be52813b..e0d4d1f7a 100644 --- a/src/vivarium/framework/results/stratification.py +++ b/src/vivarium/framework/results/stratification.py @@ -4,19 +4,20 @@ =============== """ + from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable +from typing import Any import pandas as pd from pandas.api.types import CategoricalDtype +from vivarium.types import ScalarMapper, VectorMapper + STRATIFICATION_COLUMN_SUFFIX: str = "mapped_values" # TODO: Parameterizing pandas objects fails below python 3.12 -VectorMapper = Callable[[pd.DataFrame], pd.Series] # type: ignore [type-arg] -ScalarMapper = Callable[[pd.Series], str] # type: ignore [type-arg] @dataclass diff --git a/src/vivarium/types.py b/src/vivarium/types.py index 5d813e312..29f8af0ca 100644 --- a/src/vivarium/types.py +++ b/src/vivarium/types.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from datetime import datetime, timedelta from numbers import Number from typing import Union @@ -24,3 +25,6 @@ Timedelta = Union[pd.Timedelta, timedelta] ClockTime = Union[Time, int] ClockStepSize = Union[Timedelta, int] + +VectorMapper = Callable[[pd.DataFrame], pd.Series] # type: ignore [type-arg] +ScalarMapper = Callable[[pd.Series], str] # type: ignore [type-arg]