Skip to content

Commit

Permalink
Fix mypy errors in framework/results/context.py
Browse files Browse the repository at this point in the history
  • Loading branch information
stevebachmeier committed Nov 13, 2024
1 parent 630b06d commit f015a56
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 20 deletions.
23 changes: 12 additions & 11 deletions src/vivarium/framework/results/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
"""

from __future__ import annotations

from collections import defaultdict
from collections.abc import Callable, Generator
from typing import Type
from typing import Any, Type

import pandas as pd
from pandas.core.groupby.generic import DataFrameGroupBy
Expand All @@ -21,7 +23,7 @@
get_mapped_col_name,
get_original_col_name,
)
from vivarium.types import ScalarValue
from vivarium.types import ScalarMapper, VectorMapper


class ResultsContext:
Expand Down Expand Up @@ -55,7 +57,9 @@ def __init__(self) -> None:
self.default_stratifications: list[str] = []
self.stratifications: list[Stratification] = []
self.excluded_categories: dict[str, list[str]] = {}
self.observations: defaultdict = defaultdict(lambda: defaultdict(list))
self.observations: defaultdict[
str, defaultdict[tuple[str, tuple[str, ...] | None], list[BaseObservation]]
] = defaultdict(lambda: defaultdict(list))

@property
def name(self) -> str:
Expand Down Expand Up @@ -99,11 +103,7 @@ def add_stratification(
sources: list[str],
categories: list[str],
excluded_categories: list[str] | None,
mapper: (
Callable[[pd.Series | pd.DataFrame], pd.Series]
| Callable[[ScalarValue], str]
| None
),
mapper: VectorMapper | ScalarMapper | None,
is_vectorized: bool,
) -> None:
"""Add a stratification to the results context.
Expand Down Expand Up @@ -190,7 +190,7 @@ def register_observation(
name: str,
pop_filter: str,
when: str,
**kwargs,
**kwargs: Any,
) -> None:
"""Add an observation to the results context.
Expand Down Expand Up @@ -301,6 +301,7 @@ def gather_results(
if filtered_pop.empty:
yield None, None, None
else:
pop: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str]
if stratification_names is None:
pop = filtered_pop
else:
Expand Down Expand Up @@ -334,7 +335,7 @@ def _filter_population(
@staticmethod
def _get_groups(
stratifications: tuple[str, ...], filtered_pop: pd.DataFrame
) -> DataFrameGroupBy:
) -> DataFrameGroupBy[tuple[str, ...] | str]:
"""Group the population by stratification.
Notes
Expand All @@ -355,7 +356,7 @@ def _get_groups(
)
else:
pop_groups = filtered_pop.groupby(lambda _: "all")
return pop_groups
return pop_groups # type: ignore[return-value]

def _rename_stratification_columns(self, results: pd.DataFrame) -> None:
"""Convert the temporary stratified mapped index names back to their original names."""
Expand Down
5 changes: 3 additions & 2 deletions src/vivarium/framework/results/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@

from collections import defaultdict
from enum import Enum
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, Union

import pandas as pd

from vivarium.framework.event import Event
from vivarium.framework.results.context import ResultsContext
from vivarium.framework.results.observation import BaseObservation
from vivarium.framework.values import Pipeline
from vivarium.manager import Manager
from vivarium.types import ScalarValue
Expand Down Expand Up @@ -301,7 +302,7 @@ def _bin_data(data: Union[pd.Series, pd.DataFrame]) -> pd.Series:

def register_observation(
self,
observation_type,
observation_type: Type[BaseObservation],
is_stratified: bool,
name: str,
pop_filter: str,
Expand Down
13 changes: 9 additions & 4 deletions src/vivarium/framework/results/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from abc import ABC
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
from typing import TYPE_CHECKING

import pandas as pd
from pandas.api.types import CategoricalDtype
Expand All @@ -35,6 +35,9 @@

VALUE_COLUMN = "value"

if TYPE_CHECKING:
_PandasGroup = pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str]


@dataclass
class BaseObservation(ABC):
Expand All @@ -60,7 +63,8 @@ class BaseObservation(ABC):
DataFrame or one with a complete set of stratifications as the index and
all values set to 0.0."""
results_gatherer: Callable[
[pd.DataFrame | DataFrameGroupBy[str], tuple[str, ...] | None], pd.DataFrame
[_PandasGroup, tuple[str, ...] | None],
pd.DataFrame,
]
"""Method or function that gathers the new observation results."""
results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame]
Expand All @@ -76,7 +80,7 @@ class BaseObservation(ABC):
def observe(
self,
event: Event,
df: pd.DataFrame | DataFrameGroupBy[str],
df: _PandasGroup,
stratifications: tuple[str, ...] | None,
) -> pd.DataFrame | None:
"""Determine whether to observe the given event, and if so, gather the results.
Expand Down Expand Up @@ -139,7 +143,8 @@ def __init__(
to_observe: Callable[[Event], bool] = lambda event: True,
):
def _wrap_results_gatherer(
df: pd.DataFrame | DataFrameGroupBy[str], _: tuple[str, ...] | None
df: _PandasGroup,
_: tuple[str, ...] | None,
) -> pd.DataFrame:
if isinstance(df, DataFrameGroupBy):
raise TypeError(
Expand Down
7 changes: 4 additions & 3 deletions src/vivarium/framework/results/stratification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@
===============
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable
from typing import Any

import pandas as pd
from pandas.api.types import CategoricalDtype

from vivarium.types import ScalarMapper, VectorMapper

STRATIFICATION_COLUMN_SUFFIX: str = "mapped_values"

# TODO: Parameterizing pandas objects fails below python 3.12
VectorMapper = Callable[[pd.DataFrame], pd.Series] # type: ignore [type-arg]
ScalarMapper = Callable[[pd.Series], str] # type: ignore [type-arg]


@dataclass
Expand Down
4 changes: 4 additions & 0 deletions src/vivarium/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Callable
from datetime import datetime, timedelta
from numbers import Number
from typing import Union
Expand All @@ -24,3 +25,6 @@
Timedelta = Union[pd.Timedelta, timedelta]
ClockTime = Union[Time, int]
ClockStepSize = Union[Timedelta, int]

VectorMapper = Callable[[pd.DataFrame], pd.Series] # type: ignore [type-arg]
ScalarMapper = Callable[[pd.Series], str] # type: ignore [type-arg]

0 comments on commit f015a56

Please sign in to comment.