Skip to content

Commit

Permalink
fix broken vph tests (#533)
Browse files Browse the repository at this point in the history
  • Loading branch information
rmudambi committed Nov 12, 2024
1 parent 93c184b commit 6056147
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 23 deletions.
22 changes: 7 additions & 15 deletions src/vivarium/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,10 @@ def build_lookup_table(
If the data source is invalid.
"""
data = self.get_data(builder, data_source)
# TODO update this to use vivarium.types.LookupTableData once we drop
# support for Python 3.9
if not isinstance(data, (Number, timedelta, datetime, pd.DataFrame, list, tuple)):
raise ConfigurationError(f"Data '{data}' must be a LookupTableData instance.")

if isinstance(data, list):
return builder.lookup.build_table(data, value_columns=list(value_columns))
Expand Down Expand Up @@ -656,7 +660,7 @@ def get_data(
self,
builder: Builder,
data_source: LookupTableData | str | Callable[[Builder], LookupTableData],
) -> float | pd.DataFrame:
) -> Any:
"""Retrieves data from a data source.
If the data source is a float or a DataFrame, it is treated as the data
Expand All @@ -683,12 +687,7 @@ def get_data(
layered_config_tree.exceptions.ConfigurationError
If the data source is invalid.
"""
# TODO update this to use vivarium.types.LookupTableData once we drop
# support for Python 3.9
valid_data_types = (Number, timedelta, datetime, pd.DataFrame, list, tuple)
if isinstance(data_source, valid_data_types):
data = data_source
elif isinstance(data_source, str):
if isinstance(data_source, str):
if "::" in data_source:
module, method = data_source.split("::")
try:
Expand Down Expand Up @@ -716,15 +715,8 @@ def get_data(
elif isinstance(data_source, Callable):
data = data_source(builder)
else:
raise ConfigurationError(
f"Data source is of type '{type(data_source)}'. It must be a "
"LookupTableData instance, a string corresponding to an "
"artifact key, a callable that returns a LookupTableData "
"instance, or a string defining such a callable."
)
data = data_source

if not isinstance(data, valid_data_types):
raise ConfigurationError(f"Data '{data}' must be a LookupTableData instance.")
return data

def _set_population_view(self, builder: "Builder") -> None:
Expand Down
20 changes: 12 additions & 8 deletions src/vivarium/framework/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@
from vivarium.types import ClockTime, LookupTableData


def default_initializer(_builder: Builder) -> LookupTableData:
return 0.0


def _next_state(
index: pd.Index,
event_time: ClockTime,
Expand Down Expand Up @@ -221,7 +217,7 @@ def __init__(
self,
state_id: str,
allow_self_transition: bool = False,
initialization_weights: Callable[[Builder], LookupTableData] = default_initializer,
initialization_weights: Callable[[Builder], LookupTableData] | None = None,
) -> None:
super().__init__()
self.state_id = state_id
Expand Down Expand Up @@ -294,7 +290,10 @@ def allow_self_transitions(self) -> None:
##################

def get_initialization_weights(self, builder: Builder) -> LookupTableData:
return self.initialization_weights(builder)
if self.initialization_weights:
return self.initialization_weights(builder)
else:
return 0.0

def transition_side_effect(self, index: pd.Index, event_time: ClockTime) -> None:
pass
Expand Down Expand Up @@ -516,7 +515,7 @@ def __init__(
self.add_states(states)

states_with_initialization_weights = [
s for s in self.states if s.initialization_weights != default_initializer
state for state in self.states if state.initialization_weights
]

if initial_state is not None:
Expand All @@ -533,7 +532,12 @@ def __init__(

initial_state.initialization_weights = lambda _builder: 1.0

elif not states_with_initialization_weights:
# TODO: [MIC-5403] remove this on_initialize_simulants check once
# VPH's DiseaseModel has a compatible initialization strategy
elif (
type(self).on_initialize_simulants == Machine.on_initialize_simulants
and not states_with_initialization_weights
):
raise ValueError(
"Must specify either an initial state or provide"
" initialization weights to states."
Expand Down
3 changes: 3 additions & 0 deletions src/vivarium/framework/values/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ def _call(self, *args: Any, skip_post_processor: bool = False, **kwargs: Any) ->
def __repr__(self) -> str:
return f"_Pipeline({self.name})"

def __hash__(self) -> int:
return hash(self.name)

def get_value_modifier(
self, modifier: Callable[..., Any], component: Component | None
) -> ValueModifier:
Expand Down

0 comments on commit 6056147

Please sign in to comment.