add mypy type checking to artifact manager (#530)
rmudambi authored Nov 5, 2024
1 parent 6e8693e commit 8c2285a
Showing 3 changed files with 46 additions and 36 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -1,6 +1,7 @@
 **3.0.17 - 11/04/24**

 - Fix mypy errors in vivarium/framework/configuration.py
+- Fix mypy errors in vivarium/framework/artifact/manager.py

 **3.0.16 - 10/31/24**

1 change: 0 additions & 1 deletion pyproject.toml
@@ -45,7 +45,6 @@ exclude = [
     'src/vivarium/examples/disease_model/observer.py',
     'src/vivarium/examples/disease_model/population.py',
     'src/vivarium/examples/disease_model/risk.py',
-    'src/vivarium/framework/artifact/manager.py',
     'src/vivarium/framework/components/manager.py',
     'src/vivarium/framework/components/parser.py',
     'src/vivarium/framework/engine.py',
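Removing this entry from the mypy `exclude` list is what actually turns type checking on for the module; the annotations in the diff below exist to make that check pass. As a rough illustration (not part of this commit), the check could also be invoked programmatically through mypy's Python API, assuming mypy is installed and the script runs from the repository root:

```python
# Illustrative only: type-check the newly included module via mypy's API.
# Assumes mypy is installed and the working directory is the repo root.
from mypy import api

report, errors, exit_status = api.run(["src/vivarium/framework/artifact/manager.py"])
print(report or errors)
print(f"mypy exited with status {exit_status}")
```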
80 changes: 45 additions & 35 deletions src/vivarium/framework/artifact/manager.py
@@ -1,4 +1,3 @@
-# mypy: ignore-errors
 """
 ====================
 The Artifact Manager
@@ -8,18 +7,22 @@
 for handling complex data bound up in a data artifact.
 """
+from __future__ import annotations

 import re
+from collections.abc import Callable, Sequence
 from pathlib import Path
-from typing import Any, Callable, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any

 import pandas as pd
 from layered_config_tree.main import LayeredConfigTree

 from vivarium.framework.artifact import ArtifactException
 from vivarium.framework.artifact.artifact import Artifact
 from vivarium.manager import Interface, Manager

-_Filter = Union[str, int, Sequence[int], Sequence[str]]
+if TYPE_CHECKING:
+    from vivarium.framework.engine import Builder


 class ArtifactManager(Manager):
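The `if TYPE_CHECKING:` block, combined with `from __future__ import annotations`, imports `Builder` for annotations only: the engine module is never imported at runtime, so the `Builder` annotations added below cannot introduce an import cycle. A minimal, self-contained sketch of the pattern (the `DemoManager` class is hypothetical and not part of vivarium):

```python
from __future__ import annotations  # annotations are evaluated lazily

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers such as mypy; never executed at runtime,
    # so it cannot create an import cycle with the engine module.
    from vivarium.framework.engine import Builder


class DemoManager:
    def setup(self, builder: Builder) -> None:
        # "Builder" is only an annotation here; the class does not need to
        # be importable when this module is loaded.
        print(f"setting up with {builder!r}")


if __name__ == "__main__":
    DemoManager().setup(builder=object())  # type: ignore[arg-type]
```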
@@ -33,14 +36,14 @@ class ArtifactManager:
         }
     }

-    def __init__(self):
+    def __init__(self) -> None:
         self._default_value_column = "value"

     @property
-    def name(self):
+    def name(self) -> str:
         return "artifact_manager"

-    def setup(self, builder):
+    def setup(self, builder: Builder) -> None:
         """Performs this component's simulation setup."""
         self.logger = builder.logging.get_logger(self.name)
         # because not all columns are accessible via artifact filter terms, apply config filters separately
@@ -50,7 +53,7 @@ def setup(self, builder):
         self.artifact = self._load_artifact(builder.configuration)
         builder.lifecycle.add_constraint(self.load, allow_during=["setup"])

-    def _load_artifact(self, configuration: LayeredConfigTree) -> Optional[Artifact]:
+    def _load_artifact(self, configuration: LayeredConfigTree) -> Artifact | None:
         """Loads artifact data.

         Looks up the path to the artifact hdf file, builds a default filter,
@@ -76,7 +79,7 @@ def _load_artifact(self, configuration: LayeredConfigTree) -> Optional[Artifact]
self.logger.info(f"Artifact additional filter terms are {self.config_filter_term}.")
return Artifact(artifact_path, base_filter_terms)

def load(self, entity_key: str, **column_filters: _Filter) -> Any:
def load(self, entity_key: str, **column_filters: int | str | Sequence[int | str]) -> Any:
"""Loads data associated with the given entity key.
Parameters
@@ -93,19 +96,21 @@ def load(self, entity_key: str, **column_filters: _Filter) -> Any:
         The data associated with the given key, filtered down to the
         requested subset if the data is a dataframe.
         """
+        if self.artifact is None:
+            raise ArtifactException("No artifact defined for simulation.")
+
         data = self.artifact.load(entity_key)
         if isinstance(data, pd.DataFrame):  # could be metadata dict
             data = data.reset_index()
-            draw_col = [c for c in data if "draw" in c]
+            draw_col = [c for c in data.columns if "draw" in c]
             if draw_col:
                 data = data.rename(columns={draw_col[0]: self._default_value_column})
-        return (
-            filter_data(data, self.config_filter_term, **column_filters)
-            if isinstance(data, pd.DataFrame)
-            else data
-        )
-
-    def value_columns(self) -> Callable[[Union[str, pd.DataFrame]], List[str]]:
+            data = filter_data(data, self.config_filter_term, **column_filters)
+
+        return data
+
+    def value_columns(self) -> Callable[[str | pd.DataFrame], list[str]]:
         """Returns a function that returns the value columns for the given input.

         The function can be called with either a string or a pandas DataFrame.
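For context, the rewritten `load` body above flattens the index, renames the first draw column to the default value column, and then hands the frame to `filter_data`, which keeps matching rows and drops the filter columns. A self-contained toy sketch of those steps (the table and the `age_group="young"` filter are made up for illustration):

```python
import pandas as pd

# Toy stand-in for artifact data; "draw_0" mimics a draw-level column.
data = pd.DataFrame(
    {"age_group": ["young", "old", "young"], "draw_0": [0.1, 0.2, 0.3]}
).set_index("age_group")

# Mirror the steps in load(): flatten the index, rename the draw column to
# "value", keep rows matching the keyword filter, then drop the filter column.
data = data.reset_index()
draw_col = [c for c in data.columns if "draw" in c]
if draw_col:
    data = data.rename(columns={draw_col[0]: "value"})
mask = data["age_group"] == "young"
data = data.loc[mask].drop(columns=["age_group"])
print(data)  # two rows, a single "value" column
```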
@@ -120,17 +125,17 @@ def value_columns(self) -> Callable[[Union[str, pd.DataFrame]], List[str]]:
"""
return lambda _: [self._default_value_column]

def __repr__(self):
def __repr__(self) -> str:
return "ArtifactManager()"


class ArtifactInterface(Interface):
"""The builder interface for accessing a data artifact."""

def __init__(self, manager: ArtifactManager):
def __init__(self, manager: ArtifactManager) -> None:
self._manager = manager

def load(self, entity_key: str, **column_filters: _Filter) -> pd.DataFrame:
def load(self, entity_key: str, **column_filters: int | str | Sequence[int | str]) -> Any:
"""Loads data associated with a formatted entity key.
The provided entity key must be of the form
@@ -162,7 +167,7 @@ def load(self, entity_key: str, **column_filters: _Filter) -> pd.DataFrame:
"""
return self._manager.load(entity_key, **column_filters)

def value_columns(self) -> Callable[[Union[str, pd.DataFrame]], List[str]]:
def value_columns(self) -> Callable[[str | pd.DataFrame], list[str]]:
"""Returns a function that returns the value columns for the given input.
The function can be called with either a string or a pandas DataFrame.
@@ -175,12 +180,14 @@ def value_columns(self) -> Callable[[Union[str, pd.DataFrame]], List[str]]:
"""
return self._manager.value_columns()

def __repr__(self):
def __repr__(self) -> str:
return "ArtifactManagerInterface()"


def filter_data(
data: pd.DataFrame, config_filter_term: Optional[str] = None, **column_filters: _Filter
data: pd.DataFrame,
config_filter_term: str | None = None,
**column_filters: int | str | Sequence[int | str],
) -> pd.DataFrame:
"""Uses the provided column filters and age_group conditions to subset the raw data."""
data = _config_filter(data, config_filter_term)
@@ -189,15 +196,15 @@ def filter_data(
     return data


-def _config_filter(data, config_filter_term):
+def _config_filter(data: pd.DataFrame, config_filter_term: str | None) -> pd.DataFrame:
     if config_filter_term:
         filter_column = re.split("[<=>]", config_filter_term.split()[0])[0]
         if filter_column in data.columns:
             data = data.query(config_filter_term)
     return data


-def validate_filter_term(config_filter_term):
+def validate_filter_term(config_filter_term: str | None) -> str | None:
     multiple_filter_indicators = [" and ", " or ", "|", "&"]
     if config_filter_term is not None and any(
         x in config_filter_term for x in multiple_filter_indicators
@@ -208,7 +215,9 @@ def validate_filter_term(config_filter_term):
     return config_filter_term


-def _subset_rows(data: pd.DataFrame, **column_filters: _Filter) -> pd.DataFrame:
+def _subset_rows(
+    data: pd.DataFrame, **column_filters: int | str | Sequence[int | str]
+) -> pd.DataFrame:
     """Filters out unwanted rows from the data using the provided filters."""
     extra_filters = set(column_filters.keys()) - set(data.columns)
     if extra_filters:
@@ -218,26 +227,27 @@ def _subset_rows(data: pd.DataFrame, **column_filters: _Filter) -> pd.DataFrame:
         )

     for column, condition in column_filters.items():
-        if column in data.columns:
-            if not isinstance(condition, (list, tuple)):
-                condition = [condition]
-            mask = pd.Series(False, index=data.index)
-            for c in condition:
-                mask |= data[f"{column}"] == c
-            row_indexer = data[mask].index
-            data = data.loc[row_indexer, :]
+        if isinstance(condition, (str, int)):
+            condition = [condition]
+        mask = pd.Series(False, index=data.index)
+        for c in condition:
+            mask |= data[f"{column}"] == c
+        row_indexer = data[mask].index
+        data = data.loc[row_indexer, :]

     return data


-def _subset_columns(data: pd.DataFrame, **column_filters) -> pd.DataFrame:
+def _subset_columns(
+    data: pd.DataFrame, **column_filters: int | str | Sequence[int | str]
+) -> pd.DataFrame:
     """Filters out unwanted columns and default columns from the data using provided filters."""
     columns_to_remove = set(list(column_filters.keys()) + ["draw"])
     columns_to_remove = columns_to_remove.intersection(data.columns)
-    return data.drop(columns=columns_to_remove)
+    return data.drop(columns=list(columns_to_remove))


-def get_base_filter_terms(configuration: LayeredConfigTree):
+def get_base_filter_terms(configuration: LayeredConfigTree) -> list[str]:
     """Parses default filter terms from the artifact configuration."""
     base_filter_terms = []

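The module-level `filter_data` shown above is importable directly, which makes the combined behavior of the config filter term and the keyword column filters easy to see in isolation. A small runnable sketch, assuming vivarium is installed (the column names and filter values are illustrative):

```python
import pandas as pd

from vivarium.framework.artifact.manager import filter_data

# Toy table standing in for artifact data; column names are illustrative.
data = pd.DataFrame(
    {
        "year": [2019, 2020, 2020, 2021],
        "sex": ["Female", "Male", "Female", "Male"],
        "value": [0.1, 0.2, 0.3, 0.4],
    }
)

# The config filter term goes through DataFrame.query; each keyword filter
# keeps rows whose column matches any of the given values, and the filter
# column (plus any "draw" column) is dropped from the result.
result = filter_data(data, config_filter_term="year == 2020", sex="Female")
print(result)  # one row: year 2020, value 0.3
```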
