From 40cd8741f55ea6dade6972146a547f40079640a9 Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Mon, 14 Oct 2024 12:02:31 -0700 Subject: [PATCH 01/22] support callables in get_data (#496) --- CHANGELOG.rst | 4 ++ docs/nitpick-exceptions | 5 ++- src/vivarium/component.py | 84 ++++++++++++++++++++++++--------------- src/vivarium/types.py | 13 +++--- 4 files changed, 65 insertions(+), 41 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 78d869b6..88afe1b2 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,7 @@ +**3.2.0 - TBD** + + - Support passing callables directly when building lookup tables + **3.1.0 - 11/07/24** - Drop support for python 3.9 diff --git a/docs/nitpick-exceptions b/docs/nitpick-exceptions index 22bd2118..ec57a7ef 100644 --- a/docs/nitpick-exceptions +++ b/docs/nitpick-exceptions @@ -46,6 +46,7 @@ py:class layered_config_tree.main.LayeredConfigTree py:class LayeredConfigTree py:exc layered_config_tree.exceptions.ConfigurationError -# TODO: Need to revisit this. Nitpicking here to avoid failing builds on 3.9 in state_machine and testing_utils +# TODO: Need to revisit this. 
Nitpicking here to avoid failing builds on 3.9 py:class Logger -py:class Path \ No newline at end of file +py:class Path +py:class LookupTableData \ No newline at end of file diff --git a/src/vivarium/component.py b/src/vivarium/component.py index f1038d13..444e89b3 100644 --- a/src/vivarium/component.py +++ b/src/vivarium/component.py @@ -13,8 +13,10 @@ import re from abc import ABC from collections.abc import Sequence +from datetime import datetime, timedelta from importlib import import_module from inspect import signature +from numbers import Number from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union import pandas as pd @@ -25,6 +27,7 @@ from vivarium.framework.event import Event from vivarium.framework.lookup import LookupTable from vivarium.framework.population import PopulationError, PopulationView +from vivarium.types import LookupTableData if TYPE_CHECKING: from vivarium.framework.engine import Builder @@ -576,10 +579,9 @@ def build_all_lookup_tables(self, builder: "Builder") -> None: def build_lookup_table( self, - builder: "Builder", - # todo: replace with LookupTableData - data_source: Union[str, float, int, list, pd.DataFrame], - value_columns: Optional[Sequence[str]] = None, + builder: Builder, + data_source: LookupTableData | str | Callable[[Builder], LookupTableData], + value_columns: Sequence[str] | None = None, ) -> LookupTable: """Builds a LookupTable from a data source. @@ -658,11 +660,10 @@ def _get_columns( return value_columns, parameter_columns, key_columns def get_data( - # TODO: replace with LookupTableData self, - builder: "Builder", - data_source: Union[str, float, pd.DataFrame], - ) -> Union[float, pd.DataFrame]: + builder: Builder, + data_source: LookupTableData | str | Callable[[Builder], LookupTableData], + ) -> float | pd.DataFrame: """Retrieves data from a data source. 
If the data source is a float or a DataFrame, it is treated as the data @@ -689,32 +690,49 @@ def get_data( layered_config_tree.exceptions.ConfigurationError If the data source is invalid. """ - if isinstance(data_source, (float, int, list, pd.DataFrame)): - return data_source - - if "::" in data_source: - module, method = data_source.split("::") - try: - if module == "self": - data_getter = getattr(self, method) - else: - data_getter = getattr(import_module(module), method) - except ModuleNotFoundError: - raise ConfigurationError(f"Unable to find module '{module}'.") - except AttributeError: - module_string = ( - f"component {self.name}." if module == "self" else f"module '{module}'." - ) - raise ConfigurationError( - f"There is no method '{method}' for the {module_string}." - ) - - return data_getter(builder) + # TODO update this to use vivarium.types.LookupTableData once we drop + # support for Python 3.9 + valid_data_types = (Number, timedelta, datetime, pd.DataFrame, list, tuple) + if isinstance(data_source, valid_data_types): + data = data_source + elif isinstance(data_source, str): + if "::" in data_source: + module, method = data_source.split("::") + try: + if module == "self": + data_source = getattr(self, method) + else: + data_source = getattr(import_module(module), method) + except ModuleNotFoundError: + raise ConfigurationError(f"Unable to find module '{module}'.") + except AttributeError: + module_string = ( + f"component {self.name}" if module == "self" else f"module '{module}'" + ) + raise ConfigurationError( + f"There is no method '{method}' for the {module_string}." + ) + data = data_source(builder) + else: + try: + data = builder.data.load(data_source) + except ArtifactException: + raise ConfigurationError( + f"Failed to find key '{data_source}' in artifact." + ) + elif isinstance(data_source, Callable): + data = data_source(builder) + else: + raise TypeError( + f"Data source is of type '{type(data_source)}'. 
It must be a " + "LookupTableData instance, a string corresponding to an " + "artifact key, a callable that returns a LookupTableData " + "instance, or a string defining such a callable." + ) - try: - return builder.data.load(data_source) - except ArtifactException: - raise ConfigurationError(f"Failed to find key '{data_source}' in artifact.") + if not isinstance(data, valid_data_types): + raise ConfigurationError(f"Data '{data}' must be a LookupTableData instance.") + return data def _set_population_view(self, builder: "Builder") -> None: """Creates the PopulationView for this component if it needs access to diff --git a/src/vivarium/types.py b/src/vivarium/types.py index 5d813e31..89da4389 100644 --- a/src/vivarium/types.py +++ b/src/vivarium/types.py @@ -8,7 +8,13 @@ NumericArray = npt.NDArray[np.number[npt.NBitBase]] -ScalarValue = Union[Number, timedelta, datetime] +# todo need to use TypeVars here +Time = Union[pd.Timestamp, datetime] +Timedelta = Union[pd.Timedelta, timedelta] +ClockTime = Union[Time, int] +ClockStepSize = Union[Timedelta, int] + +ScalarValue = Union[Number, Timedelta, Time] LookupTableData = Union[ScalarValue, pd.DataFrame, list[ScalarValue], tuple[ScalarValue]] # TODO: For some of the uses of NumberLike, we probably want a TypeVar here instead. 
NumberLike = Union[ @@ -19,8 +25,3 @@ float, int, ] -# TODO: [MIC-5481] need to use TypeVars here -Time = Union[pd.Timestamp, datetime] -Timedelta = Union[pd.Timedelta, timedelta] -ClockTime = Union[Time, int] -ClockStepSize = Union[Timedelta, int] From a20552756dd76c3644b253c09be4db2c9e8b4185 Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Mon, 14 Oct 2024 12:10:55 -0700 Subject: [PATCH 02/22] clean up imported types (#497) --- src/vivarium/framework/state_machine.py | 54 +++++++++++++------------ 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/src/vivarium/framework/state_machine.py b/src/vivarium/framework/state_machine.py index 916b78a6..fdd52614 100644 --- a/src/vivarium/framework/state_machine.py +++ b/src/vivarium/framework/state_machine.py @@ -7,6 +7,7 @@ A state machine implementation for use in ``vivarium`` simulations. """ +from __future__ import annotations from enum import Enum from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple @@ -24,9 +25,9 @@ def _next_state( index: pd.Index, - event_time: "ClockTime", - transition_set: "TransitionSet", - population_view: "PopulationView", + event_time: ClockTime, + transition_set: TransitionSet, + population_view: PopulationView, ) -> None: """Moves a population between different states using information from a `TransitionSet`. @@ -107,20 +108,7 @@ def _process_trigger(trigger): class Transition(Component): - """A process by which an entity might change into a particular state. - - Attributes - ---------- - input_state - The start state of the entity that undergoes the transition. - output_state - The end state of the entity that undergoes the transition. - probability_func - A method or function that describing the probability of this - transition occurring. - triggered - A flag indicating whether this transition is triggered by some event. 
- """ + """A process by which an entity might change into a particular state.""" ##################### # Lifecycle methods # @@ -128,13 +116,27 @@ class Transition(Component): def __init__( self, - input_state: "State", - output_state: "State", + input_state: State, + output_state: State, probability_func: Callable[[pd.Index], pd.Series] = lambda index: pd.Series( 1.0, index=index ), triggered=Trigger.NOT_TRIGGERED, ): + """Initializes a transition between two states. + + Parameters + ---------- + input_state + The start state of the entity that undergoes the transition. + output_state + The end state of the entity that undergoes the transition. + probability_func + A method or function that describing the probability of this + transition occurring. + triggered + A flag indicating whether this transition is triggered by some event. + """ super().__init__() self.input_state = input_state self.output_state = output_state @@ -215,7 +217,7 @@ def set_model(self, model_name: str) -> None: self._model = model_name def next_state( - self, index: pd.Index, event_time: "ClockTime", population_view: "PopulationView" + self, index: pd.Index, event_time: ClockTime, population_view: PopulationView ) -> None: """Moves a population between different states. @@ -231,7 +233,7 @@ def next_state( return _next_state(index, event_time, self.transition_set, population_view) def transition_effect( - self, index: pd.Index, event_time: "ClockTime", population_view: "PopulationView" + self, index: pd.Index, event_time: ClockTime, population_view: PopulationView ) -> None: """Updates the simulation state and triggers any side-effects associated with entering this state. 
@@ -247,7 +249,7 @@ def transition_effect( population_view.update(pd.Series(self.state_id, index=index)) self.transition_side_effect(index, event_time) - def cleanup_effect(self, index: pd.Index, event_time: "ClockTime") -> None: + def cleanup_effect(self, index: pd.Index, event_time: ClockTime) -> None: pass def add_transition(self, transition: Transition) -> None: @@ -267,7 +269,7 @@ def allow_self_transitions(self) -> None: # Helper methods # ################## - def transition_side_effect(self, index: pd.Index, event_time: "ClockTime") -> None: + def transition_side_effect(self, index: pd.Index, event_time: ClockTime) -> None: pass @@ -321,7 +323,7 @@ def __init__( self.extend(transitions) - def setup(self, builder: "Builder") -> None: + def setup(self, builder: Builder) -> None: """Performs this component's simulation setup and return sub-components. Parameters @@ -484,7 +486,7 @@ def add_states(self, states: Iterable[State]) -> None: self.states.append(state) state.set_model(self.state_column) - def transition(self, index: pd.Index, event_time: "ClockTime") -> None: + def transition(self, index: pd.Index, event_time: ClockTime) -> None: """Finds the population in each state and moves them to the next state. 
Parameters @@ -502,7 +504,7 @@ def transition(self, index: pd.Index, event_time: "ClockTime") -> None: self.population_view.subview(self.state_column), ) - def cleanup(self, index: pd.Index, event_time: "ClockTime") -> None: + def cleanup(self, index: pd.Index, event_time: ClockTime) -> None: for state, affected in self._get_state_pops(index): if not affected.empty: state.cleanup_effect(affected.index, event_time) From 606fb0270663fd3a04c937fb9825642b59c1380a Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Mon, 14 Oct 2024 12:17:36 -0700 Subject: [PATCH 03/22] fix bugs preventing Machine from functioning (#498) --- CHANGELOG.rst | 1 + docs/source/tutorials/exploration.rst | 8 + src/vivarium/component.py | 10 +- .../examples/disease_model/disease.py | 36 +-- src/vivarium/framework/state_machine.py | 95 ++++++- tests/framework/components/test_component.py | 9 +- tests/framework/test_state_machine.py | 259 ++++++++++-------- tests/helpers.py | 10 +- 8 files changed, 269 insertions(+), 159 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 88afe1b2..4fd9f1b8 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,5 +1,6 @@ **3.2.0 - TBD** + - Enable Machine to be used directly to model a state machine - Support passing callables directly when building lookup tables **3.1.0 - 11/07/24** diff --git a/docs/source/tutorials/exploration.rst b/docs/source/tutorials/exploration.rst index db0cc97f..0304a581 100644 --- a/docs/source/tutorials/exploration.rst +++ b/docs/source/tutorials/exploration.rst @@ -174,6 +174,14 @@ configuration by simply printing it. component_configs: [] include: component_configs: [] + disease_state.susceptible_to_lower_respiratory_infections: + data_sources: + initialization_weights: + component_configs: + disease_state.infected_with_lower_respiratory_infections: + data_sources: + initialization_weights: + component_configs: What do we see here? The configuration is *hierarchical*. 
There are a set of diff --git a/src/vivarium/component.py b/src/vivarium/component.py index 444e89b3..1831f516 100644 --- a/src/vivarium/component.py +++ b/src/vivarium/component.py @@ -723,11 +723,11 @@ def get_data( elif isinstance(data_source, Callable): data = data_source(builder) else: - raise TypeError( - f"Data source is of type '{type(data_source)}'. It must be a " - "LookupTableData instance, a string corresponding to an " - "artifact key, a callable that returns a LookupTableData " - "instance, or a string defining such a callable." + raise ConfigurationError( + f"Data source '{data_source}' is not a valid data source. It " + f"must be a LookupTableData instance, a string corresponding to " + f"an artifact key, a callable that returns a LookupTableData " + f"instance, or a string defining such a callable." ) if not isinstance(data, valid_data_types): diff --git a/src/vivarium/examples/disease_model/disease.py b/src/vivarium/examples/disease_model/disease.py index 9bdbaf0f..3e62bdff 100644 --- a/src/vivarium/examples/disease_model/disease.py +++ b/src/vivarium/examples/disease_model/disease.py @@ -1,12 +1,10 @@ # mypy: ignore-errors -from typing import List, Optional +from __future__ import annotations import pandas as pd from vivarium import Component from vivarium.framework.engine import Builder -from vivarium.framework.event import Event -from vivarium.framework.population import SimulantData from vivarium.framework.state_machine import Machine, State, Transition from vivarium.framework.utilities import rate_to_probability from vivarium.framework.values import list_combiner, union_post_processor @@ -72,11 +70,11 @@ class DiseaseState(State): ############## @property - def columns_required(self) -> Optional[List[str]]: + def columns_required(self) -> list[str] | None: return [self.model, "alive"] @property - def population_view_query(self) -> Optional[str]: + def population_view_query(self) -> str | None: return f"alive == 'alive' and {self.model} == 
'{self.state_id}'" ##################### @@ -150,26 +148,11 @@ def add_in_excess_mortality( class DiseaseModel(Machine): - ############## - # Properties # - ############## - - @property - def columns_created(self) -> List[str]: - return [self.state_column] - - @property - def columns_required(self) -> Optional[List[str]]: - return ["age", "sex"] ##################### # Lifecycle methods # ##################### - def __init__(self, disease: str, initial_state: DiseaseState, **kwargs): - super().__init__(disease, **kwargs) - self.initial_state = initial_state.state_id - # noinspection PyAttributeOutsideInit def setup(self, builder: Builder) -> None: super().setup(builder) @@ -188,19 +171,6 @@ def setup(self, builder: Builder) -> None: "mortality_rate", modifier=self.delete_cause_specific_mortality ) - ######################## - # Event-driven methods # - ######################## - - def on_initialize_simulants(self, pop_data: SimulantData) -> None: - condition_column = pd.Series( - self.initial_state, index=pop_data.index, name=self.state_column - ) - self.population_view.update(condition_column) - - def on_time_step(self, event: Event) -> None: - self.transition(event.index, event.time) - ################################## # Pipeline sources and modifiers # ################################## diff --git a/src/vivarium/framework/state_machine.py b/src/vivarium/framework/state_machine.py index fdd52614..f18371aa 100644 --- a/src/vivarium/framework/state_machine.py +++ b/src/vivarium/framework/state_machine.py @@ -9,18 +9,24 @@ """ from __future__ import annotations +from collections.abc import Callable from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple import numpy as np import pandas as pd from vivarium import Component +from vivarium.framework.event import Event if TYPE_CHECKING: from vivarium.framework.engine import Builder - from 
vivarium.framework.population import PopulationView - from vivarium.types import ClockTime + from vivarium.framework.population import PopulationView, SimulantData + from vivarium.types import ClockTime, LookupTableData + + +def default_initializer(_builder: Builder) -> LookupTableData: + return 0.0 def _next_state( @@ -191,6 +197,16 @@ class State(Component): # Properties # ############## + @property + def configuration_defaults(self) -> dict[str, Any]: + return { + f"{self.name}": { + "data_sources": { + "initialization_weights": self.get_initialization_weights, + }, + }, + } + @property def model(self) -> str: return self._model @@ -199,12 +215,18 @@ def model(self) -> str: # Lifecycle methods # ##################### - def __init__(self, state_id: str, allow_self_transition: bool = False): + def __init__( + self, + state_id: str, + allow_self_transition: bool = False, + initialization_weights: Callable[[Builder], LookupTableData] = default_initializer, + ) -> None: super().__init__() self.state_id = state_id self.transition_set = TransitionSet( self.state_id, allow_self_transition=allow_self_transition ) + self.initialization_weights = initialization_weights self._model = None self._sub_components = [self.transition_set] @@ -269,6 +291,9 @@ def allow_self_transitions(self) -> None: # Helper methods # ################## + def get_initialization_weights(self, builder: Builder) -> LookupTableData: + return self.initialization_weights(builder) + def transition_side_effect(self, index: pd.Index, event_time: ClockTime) -> None: pass @@ -463,20 +488,78 @@ def sub_components(self): return self.states @property - def columns_required(self) -> Optional[List[str]]: + def columns_created(self) -> List[str]: return [self.state_column] + @property + def initialization_requirements(self) -> Dict[str, List[str]]: + return { + "requires_columns": [], + "requires_values": [], + "requires_streams": [self.randomness.key], + } + ##################### # Lifecycle methods # 
##################### - def __init__(self, state_column: str, states: Iterable[State] = ()): + def __init__( + self, + state_column: str, + states: Iterable[State] = (), + initial_state: State | None = None, + ) -> None: super().__init__() self.states = [] self.state_column = state_column if states: self.add_states(states) + states_with_initialization_weights = [ + s for s in self.states if s.initialization_weights != default_initializer + ] + + if initial_state is not None: + if initial_state not in self.states: + raise ValueError( + f"Initial state '{initial_state}' must be one of the" + f" states: {self.states}." + ) + if states_with_initialization_weights: + raise ValueError( + "Cannot specify both an initial state and provide" + " initialization weights to states." + ) + + initial_state.initialization_weights = lambda _builder: 1.0 + + elif not states_with_initialization_weights: + raise ValueError( + "Must specify either an initial state or provide" + " initialization weights to states." 
+ ) + + def setup(self, builder: Builder) -> None: + self.randomness = builder.randomness.get_stream(self.name) + + def on_initialize_simulants(self, pop_data: SimulantData) -> None: + state_ids = [s.state_id for s in self.states] + state_weights = pd.concat( + [ + state.lookup_tables["initialization_weights"](pop_data.index) + for state in self.states + ], + axis=1, + ).to_numpy() + + initial_states = self.randomness.choice( + pop_data.index, state_ids, state_weights, "initialization" + ).rename(self.state_column) + self.population_view.update(initial_states) + + def on_time_step(self, event: Event) -> None: + self.transition(event.index, event.time) + ################## # Public methods # ################## diff --git a/tests/framework/components/test_component.py b/tests/framework/components/test_component.py index 5c79d83b..9aa85a33 100644 --- a/tests/framework/components/test_component.py +++ b/tests/framework/components/test_component.py @@ -167,12 +167,13 @@ def test_component_initializer_is_not_registered_if_not_defined(): def test_component_initializer_is_registered_and_called_if_defined(): + pop_size = 1000 component = ColumnCreator() - simulation = InteractiveContext(components=[component]) + expected_pop_view = component.get_initial_state(pd.RangeIndex(pop_size)) + + config = {"population": {"population_size": pop_size}} + simulation = InteractiveContext(components=[component], configuration=config) population = simulation.get_population() - expected_pop_view = pd.DataFrame( - {column: 9 for column in component.columns_created}, index=population.index - ) # Assert that simulant initializer has been registered assert component.on_initialize_simulants in simulation._resource diff --git a/tests/framework/test_state_machine.py b/tests/framework/test_state_machine.py index 77f69f49..eec4fa36 100644 --- a/tests/framework/test_state_machine.py +++ b/tests/framework/test_state_machine.py @@ -1,31 +1,19 @@ -from typing import List, Optional +from __future__ 
import annotations import numpy as np import pandas as pd +import pytest +from layered_config_tree import LayeredConfigTree -from vivarium import Component, InteractiveContext +from tests.helpers import ColumnCreator +from vivarium import InteractiveContext +from vivarium.framework.configuration import build_simulation_configuration from vivarium.framework.population import SimulantData from vivarium.framework.state_machine import Machine, State, Transition from vivarium.types import ClockTime -def _population_fixture(column, initial_value): - class PopFixture(Component): - @property - def name(self) -> str: - return f"test_pop_fixture_{column}_{initial_value}" - - @property - def columns_created(self) -> List[str]: - return [column] - - def on_initialize_simulants(self, pop_data: SimulantData) -> None: - self.population_view.update(pd.Series(initial_value, index=pop_data.index)) - - return PopFixture() - - -def test_initialize_allowing_self_transition(): +def test_initialize_allowing_self_transition() -> None: self_transitions = State("self-transitions", allow_self_transition=True) no_self_transitions = State("no-self-transitions", allow_self_transition=False) undefined_self_transitions = State("self-transitions") @@ -35,38 +23,120 @@ def test_initialize_allowing_self_transition(): assert not undefined_self_transitions.transition_set.allow_null_transition -def test_transition(): - done_state = State("done") +def test_initialize_with_initial_state() -> None: start_state = State("start") - start_state.add_transition(Transition(start_state, done_state)) - machine = Machine("state", states=[start_state, done_state]) + other_state = State("other") + machine = Machine("state", states=[start_state, other_state], initial_state=start_state) + simulation = InteractiveContext(components=[machine]) + assert simulation.get_population()["state"].unique() == ["start"] + +def test_initialize_with_scalar_initialization_weights( + base_config: LayeredConfigTree, +) -> None: + 
base_config.update( + {"population": {"population_size": 10000}, "randomness": {"key_columns": []}} + ) + state_a = State("a", initialization_weights=lambda _: 0.2) + state_b = State("b", initialization_weights=lambda _: 0.8) + machine = Machine("state", states=[state_a, state_b]) + simulation = InteractiveContext(components=[machine], configuration=base_config) + + state = simulation.get_population()["state"] + assert np.all(simulation.get_population().state != "start") + assert round((state == "a").mean(), 1) == 0.2 + assert round((state == "b").mean(), 1) == 0.8 + + +@pytest.mark.parametrize( + "use_artifact", [True, False], ids=["with_artifact", "without_artifact"] +) +def test_initialize_with_array_initialization_weights(use_artifact) -> None: + state_weights = { + "state_a.weights": pd.DataFrame( + {"test_column_1": [0, 1, 2], "value": [0.2, 0.7, 0.4]} + ), + "state_b.weights": pd.DataFrame( + {"test_column_1": [0, 1, 2], "value": [0.8, 0.3, 0.6]} + ), + } + + def mock_load(key: str) -> pd.DataFrame: + return state_weights.get(key) + + config = build_simulation_configuration() + config.update( + {"population": {"population_size": 10000}, "randomness ": {"key_columns": []}} + ) + + class TestMachine(Machine): + @property + def initialization_requirements(self) -> dict[str, list[str]]: + # FIXME - MIC-5408: We shouldn't need to specify the columns in the + # lookup tables here, since the component can't know what will be + # specified by the states or the configuration. 
+ return { + "requires_columns": ["test_column_1"], + "requires_values": [], + "requires_streams": [], + } + + def initialization_weights(key: str): + if use_artifact: + return lambda builder: builder.data.load(key) + else: + return lambda _: state_weights[key] + + state_a = State("a", initialization_weights=initialization_weights("state_a.weights")) + state_b = State("b", initialization_weights=initialization_weights("state_b.weights")) + machine = TestMachine("state", states=[state_a, state_b]) simulation = InteractiveContext( - components=[machine, _population_fixture("state", "start")] + components=[machine, ColumnCreator()], configuration=config, setup=False ) - event_time = simulation._clock.time + simulation._clock.step_size - machine.transition(simulation.get_population().index, event_time) - assert np.all(simulation.get_population().state == "done") + simulation._builder.data.load = mock_load + simulation.setup() + + pop = simulation.get_population()[["state", "test_column_1"]] + state_a_weights = state_weights["state_a.weights"] + state_b_weights = state_weights["state_b.weights"] + for i in range(3): + pop_i_state = pop.loc[pop["test_column_1"] == i, "state"] + assert round((pop_i_state == "a").mean(), 1) == state_a_weights.loc[i, "value"] + assert round((pop_i_state == "b").mean(), 1) == state_b_weights.loc[i, "value"] + + +def test_error_if_initialize_with_both_initial_state_and_initialization_weights() -> None: + start_state = State("start") + other_state = State("other", initialization_weights=lambda _: 0.8) + with pytest.raises(ValueError, match="Cannot specify both"): + Machine("state", states=[start_state, other_state], initial_state=start_state) -def test_single_transition(base_config): +def test_error_if_initialize_with_neither_initial_state_nor_initialization_weights() -> None: + with pytest.raises(ValueError, match="Must specify either"): + Machine("state", states=[State("a"), State("b")]) + + +@pytest.mark.parametrize("population_size", [1, 
100]) +def test_transition(base_config: LayeredConfigTree, population_size: int) -> None: base_config.update( - {"population": {"population_size": 1}, "randomness": {"key_columns": []}} + { + "population": {"population_size": population_size}, + "randomness": {"key_columns": []}, + } ) done_state = State("done") start_state = State("start") start_state.add_transition(Transition(start_state, done_state)) - machine = Machine("state", states=[start_state, done_state]) + machine = Machine("state", states=[start_state, done_state], initial_state=start_state) - simulation = InteractiveContext( - components=[machine, _population_fixture("state", "start")], configuration=base_config - ) - event_time = simulation._clock.time + simulation._clock.step_size - machine.transition(simulation.get_population().index, event_time) + simulation = InteractiveContext(components=[machine], configuration=base_config) + assert np.all(simulation.get_population().state == "start") + simulation.step() assert np.all(simulation.get_population().state == "done") -def test_choice(base_config): +def test_no_null_transition(base_config: LayeredConfigTree) -> None: base_config.update( {"population": {"population_size": 10000}, "randomness": {"key_columns": []}} ) @@ -75,108 +145,81 @@ def test_choice(base_config): start_state = State("start") start_state.add_transition( Transition( - start_state, a_state, probability_func=lambda agents: np.full(len(agents), 0.5) + start_state, a_state, probability_func=lambda index: pd.Series(0.4, index=index) ) ) start_state.add_transition( Transition( - start_state, b_state, probability_func=lambda agents: np.full(len(agents), 0.5) + start_state, b_state, probability_func=lambda index: pd.Series(0.6, index=index) ) ) - machine = Machine("state", states=[start_state, a_state, b_state]) - - simulation = InteractiveContext( - components=[machine, _population_fixture("state", "start")], configuration=base_config + machine = Machine( + "state", states=[start_state, 
a_state, b_state], initial_state=start_state ) - event_time = simulation._clock.time + simulation._clock.step_size - machine.transition(simulation.get_population().index, event_time) - a_count = (simulation.get_population().state == "a").sum() - assert round(a_count / len(simulation.get_population()), 1) == 0.5 + simulation = InteractiveContext(components=[machine], configuration=base_config) + assert np.all(simulation.get_population().state == "start") -def test_null_transition(base_config): - base_config.update( - {"population": {"population_size": 10000}, "randomness": {"key_columns": []}} - ) - a_state = State("a") - start_state = State("start") - start_state.add_transition( - Transition( - start_state, a_state, probability_func=lambda agents: np.full(len(agents), 0.5) - ) - ) - start_state.allow_self_transitions() - - machine = Machine("state", states=[start_state, a_state]) + simulation.step() - simulation = InteractiveContext( - components=[machine, _population_fixture("state", "start")], configuration=base_config - ) - event_time = simulation._clock.time + simulation._clock.step_size - machine.transition(simulation.get_population().index, event_time) - a_count = (simulation.get_population().state == "a").sum() - assert round(a_count / len(simulation.get_population()), 1) == 0.5 + state = simulation.get_population()["state"] + assert np.all(simulation.get_population().state != "start") + assert round((state == "a").mean(), 1) == 0.4 + assert round((state == "b").mean(), 1) == 0.6 -def test_no_null_transition(base_config): +def test_null_transition(base_config: LayeredConfigTree) -> None: base_config.update( {"population": {"population_size": 10000}, "randomness": {"key_columns": []}} ) a_state = State("a") - b_state = State("b") - start_state = State("start") - a_transition = Transition( - start_state, a_state, probability_func=lambda index: pd.Series(0.5, index=index) - ) - b_transition = Transition( - start_state, b_state, probability_func=lambda index: 
pd.Series(0.5, index=index) + start_state = State("start", allow_self_transition=True) + start_state.add_transition( + Transition( + start_state, a_state, probability_func=lambda index: pd.Series(0.4, index=index) + ) ) - start_state.transition_set.allow_null_transition = False - start_state.transition_set.extend((a_transition, b_transition)) - machine = Machine("state") - machine.states.extend([start_state, a_state, b_state]) - simulation = InteractiveContext( - components=[machine, _population_fixture("state", "start")], configuration=base_config - ) - event_time = simulation._clock.time + simulation._clock.step_size - machine.transition(simulation.get_population().index, event_time) - a_count = (simulation.get_population().state == "a").sum() - assert round(a_count / len(simulation.get_population()), 1) == 0.5 + machine = Machine("state", states=[start_state, a_state], initial_state=start_state) + simulation = InteractiveContext(components=[machine], configuration=base_config) + simulation.step() + state = simulation.get_population()["state"] + assert round((state == "a").mean(), 1) == 0.4 -def test_side_effects(): - class DoneState(State): - @property - def name(self) -> str: - return "test_done_state" +def test_side_effects() -> None: + class CountingState(State): @property - def columns_required(self) -> Optional[List[str]]: + def columns_created(self) -> list[str]: return ["count"] - def transition_side_effect(self, index: pd.Index, _: ClockTime) -> None: + def on_initialize_simulants(self, pop_data: SimulantData) -> None: + self.population_view.update(pd.Series(0, index=pop_data.index, name="count")) + + def transition_side_effect(self, index: pd.Index[int], _: ClockTime) -> None: pop = self.population_view.get(index) self.population_view.update(pop["count"] + 1) - done_state = DoneState("done") + counting_state = CountingState("counting") start_state = State("start") - start_state.add_transition(Transition(start_state, done_state)) - 
done_state.add_transition(Transition(done_state, start_state)) + start_state.add_transition(Transition(start_state, counting_state)) + counting_state.add_transition(Transition(counting_state, start_state)) - machine = Machine("state", states=[start_state, done_state]) - - simulation = InteractiveContext( - components=[ - machine, - _population_fixture("state", "start"), - _population_fixture("count", 0), - ] + machine = Machine( + "state", states=[start_state, counting_state], initial_state=start_state ) - event_time = simulation._clock.time + simulation._clock.step_size - machine.transition(simulation.get_population().index, event_time) + simulation = InteractiveContext(components=[machine]) + assert np.all(simulation.get_population()["count"] == 0) + + # transitioning to counting state + simulation.step() assert np.all(simulation.get_population()["count"] == 1) - machine.transition(simulation.get_population().index, event_time) + + # transitioning back to start state + simulation.step() assert np.all(simulation.get_population()["count"] == 1) - machine.transition(simulation.get_population().index, event_time) + + # transitioning to counting state again + simulation.step() assert np.all(simulation.get_population()["count"] == 2) diff --git a/tests/helpers.py b/tests/helpers.py index d2cfc0d6..7fafa88a 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any, Dict, List, Optional import pandas as pd @@ -146,10 +148,12 @@ def setup(self, builder: Builder) -> None: builder.randomness.get_stream("stream_1") def on_initialize_simulants(self, pop_data: SimulantData) -> None: - initialization_data = pd.DataFrame( - {column: 9 for column in self.columns_created}, index=pop_data.index + self.population_view.update(self.get_initial_state(pop_data.index)) + + def get_initial_state(self, index: pd.Index[int]) -> pd.DataFrame: + return pd.DataFrame( + {column: [i % 3 for i in index] for column in 
self.columns_created}, index=index ) - self.population_view.update(initialization_data) class LookupCreator(ColumnCreator): From 61d696868942129563bcf0fc2dcf0c24b3079a10 Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Mon, 14 Oct 2024 12:25:38 -0700 Subject: [PATCH 04/22] create simple tests for disease model example (#499) --- pyproject.toml | 1 + setup.py | 1 + tests/conftest.py | 6 +++ tests/examples/__init__.py | 0 tests/examples/test_boids_model.py | 7 +++ tests/examples/test_disease_model.py | 81 ++++++++++++++++++++++++++++ 6 files changed, 96 insertions(+) create mode 100644 tests/examples/__init__.py create mode 100644 tests/examples/test_boids_model.py create mode 100644 tests/examples/test_disease_model.py diff --git a/pyproject.toml b/pyproject.toml index 768b9878..c82f4f91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ exclude = [ 'src/vivarium/interface/interactive.py', 'src/vivarium/testing_utilities.py', 'tests/conftest.py', + 'tests/examples/test_disease_model.py', 'tests/framework/artifact/test_artifact.py', 'tests/framework/artifact/test_hdf.py', 'tests/framework/artifact/test_manager.py', diff --git a/setup.py b/setup.py index 3a379732..58f5ebf0 100644 --- a/setup.py +++ b/setup.py @@ -60,6 +60,7 @@ "pytest", "pytest-cov", "pytest-mock", + "vivarium_testing_utils", ] lint_requirements = [ diff --git a/tests/conftest.py b/tests/conftest.py index 86c24602..7100cacb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ import yaml from _pytest.logging import LogCaptureFixture from loguru import logger +from vivarium_testing_utils import FuzzyChecker from vivarium.framework.configuration import ( build_model_specification, @@ -31,6 +32,11 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_slow) +@pytest.fixture(scope="session") +def fuzzy_checker(): + return FuzzyChecker() + + @pytest.fixture def caplog(caplog: LogCaptureFixture): handler_id = 
logger.add(caplog.handler, format="{message}") diff --git a/tests/examples/__init__.py b/tests/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/examples/test_boids_model.py b/tests/examples/test_boids_model.py new file mode 100644 index 00000000..947829b8 --- /dev/null +++ b/tests/examples/test_boids_model.py @@ -0,0 +1,7 @@ +# todo add tests of boids model +import pytest + + +@pytest.mark.skip("TODO - MIC-5411: Add tests for boids model") +def test_boids_model() -> None: + pass \ No newline at end of file diff --git a/tests/examples/test_disease_model.py b/tests/examples/test_disease_model.py new file mode 100644 index 00000000..9a373c0b --- /dev/null +++ b/tests/examples/test_disease_model.py @@ -0,0 +1,81 @@ +from datetime import datetime, timedelta + +import numpy as np +from layered_config_tree import LayeredConfigTree +from vivarium_testing_utils import FuzzyChecker + +from vivarium import InteractiveContext +from vivarium.framework.utilities import from_yearly + + +def test_disease_model(fuzzy_checker: FuzzyChecker, disease_model_spec) -> None: + config = LayeredConfigTree(disease_model_spec, layers=["base", "override"]) + config.update( + { + "configuration": { + "mortality": { + "mortality_rate": 20.0, + }, + "lower_respiratory_infections": { + "incidence_rate": 25.0, + "remission_rate": 50.0, + "excess_mortality_rate": 0.01, + }, + }, + } + ) + + simulation = InteractiveContext(config) + + pop = simulation.get_population() + expected_columns = { + "tracked", + "alive", + "age", + "sex", + "entrance_time", + "lower_respiratory_infections", + "child_wasting_propensity", + } + assert set(pop.columns) == expected_columns + assert len(pop) == 100_000 + assert np.all(pop["tracked"] == True) + assert np.all(pop["alive"] == "alive") + assert np.all((pop["age"] >= 0) & (pop["age"] <= 5)) + assert np.all(pop["entrance_time"] == datetime(2021, 12, 31, 12)) + + for sex in ["Female", "Male"]: + 
fuzzy_checker.fuzzy_assert_proportion( + observed_numerator=(pop["sex"] == sex).sum(), + observed_denominator=len(pop), + target_proportion=0.5, + # todo: remove this parameter when MIC-5412 is resolved + name=f"{sex}_proportion", + ) + + assert np.all(pop["lower_respiratory_infections"] == "susceptible_to_lower_respiratory_infections") + assert np.all((pop["child_wasting_propensity"] >= 0) & (pop["child_wasting_propensity"] <= 1)) + + simulation.step() + pop = simulation.get_population() + is_alive = pop["alive"] == "alive" + + fuzzy_checker.fuzzy_assert_proportion( + observed_numerator=(len(pop[~is_alive])), + observed_denominator=len(pop), + target_proportion=from_yearly(20, timedelta(days=0.5)), + # todo: remove this parameter when MIC-5412 is resolved + name="alive_proportion", + ) + + has_lri = pop["lower_respiratory_infections"] == "infected_with_lower_respiratory_infections" + fuzzy_checker.fuzzy_assert_proportion( + observed_numerator=(len(pop[is_alive & has_lri])), + observed_denominator=len(pop[is_alive]), + target_proportion=from_yearly(25, timedelta(days=0.5)), + # todo: remove this parameter when MIC-5412 is resolved + name="lri_proportion", + ) + + # todo test remission and excess mortality + # todo test risk factor and intervention From 5a44f040178deeb4b8cc969d80c3f540b9925a02 Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Wed, 16 Oct 2024 12:49:58 -0700 Subject: [PATCH 05/22] create resource package (#503) --- .../api_reference/framework/resource.rst | 1 - .../framework/resource/exceptions.rst | 1 + .../framework/resource/group.rst | 1 + .../framework/resource/index.rst | 11 ++ .../framework/resource/manager.rst | 1 + src/vivarium/framework/engine.py | 113 ++++++++---------- src/vivarium/framework/resource/__init__.py | 23 ++++ src/vivarium/framework/resource/exceptions.py | 7 ++ src/vivarium/framework/resource/group.py | 63 ++++++++++ .../{resource.py => resource/manager.py} | 86 +------------ 
tests/framework/test_resource.py | 12 +- 11 files changed, 165 insertions(+), 154 deletions(-) delete mode 100644 docs/source/api_reference/framework/resource.rst create mode 100644 docs/source/api_reference/framework/resource/exceptions.rst create mode 100644 docs/source/api_reference/framework/resource/group.rst create mode 100644 docs/source/api_reference/framework/resource/index.rst create mode 100644 docs/source/api_reference/framework/resource/manager.rst create mode 100644 src/vivarium/framework/resource/__init__.py create mode 100644 src/vivarium/framework/resource/exceptions.py create mode 100644 src/vivarium/framework/resource/group.py rename src/vivarium/framework/{resource.py => resource/manager.py} (78%) diff --git a/docs/source/api_reference/framework/resource.rst b/docs/source/api_reference/framework/resource.rst deleted file mode 100644 index 4bca0c2d..00000000 --- a/docs/source/api_reference/framework/resource.rst +++ /dev/null @@ -1 +0,0 @@ -.. automodule:: vivarium.framework.resource diff --git a/docs/source/api_reference/framework/resource/exceptions.rst b/docs/source/api_reference/framework/resource/exceptions.rst new file mode 100644 index 00000000..34323072 --- /dev/null +++ b/docs/source/api_reference/framework/resource/exceptions.rst @@ -0,0 +1 @@ +.. automodule:: vivarium.framework.resource.exceptions \ No newline at end of file diff --git a/docs/source/api_reference/framework/resource/group.rst b/docs/source/api_reference/framework/resource/group.rst new file mode 100644 index 00000000..a495628a --- /dev/null +++ b/docs/source/api_reference/framework/resource/group.rst @@ -0,0 +1 @@ +.. 
automodule:: vivarium.framework.resource.group \ No newline at end of file diff --git a/docs/source/api_reference/framework/resource/index.rst b/docs/source/api_reference/framework/resource/index.rst new file mode 100644 index 00000000..dafcfb70 --- /dev/null +++ b/docs/source/api_reference/framework/resource/index.rst @@ -0,0 +1,11 @@ +=================== +Resource Management +=================== + +.. automodule:: vivarium.framework.resource + +.. toctree:: + :maxdepth: 1 + :glob: + + * \ No newline at end of file diff --git a/docs/source/api_reference/framework/resource/manager.rst b/docs/source/api_reference/framework/resource/manager.rst new file mode 100644 index 00000000..fc440150 --- /dev/null +++ b/docs/source/api_reference/framework/resource/manager.rst @@ -0,0 +1 @@ +.. automodule:: vivarium.framework.resource.manager \ No newline at end of file diff --git a/src/vivarium/framework/engine.py b/src/vivarium/framework/engine.py index 0bf49ccf..90acf77f 100644 --- a/src/vivarium/framework/engine.py +++ b/src/vivarium/framework/engine.py @@ -20,7 +20,6 @@ """ -import math from pathlib import Path from pprint import pformat from time import time @@ -360,44 +359,6 @@ class Builder: This is the access point for components through which they are able to utilize a variety of interfaces to interact with the simulation framework. - Attributes - ---------- - configuration : ``LayeredConfigTree`` - Provides access to the :ref:`configuration` - logging : LoggingInterface - Provides access to the :ref:`logging` system. - lookup : LookupTableInterface - Provides access to simulant-specific data via the - :ref:`lookup table` abstraction. - value : ValuesInterface - Provides access to computed simulant attribute values via the - :ref:`value pipeline` system. - event : EventInterface - Provides access to event listeners utilized in the - :ref:`event` system. - population : PopulationInterface - Provides access to simulant state table via the - :ref:`population` system. 
- resources : ResourceInterface - Provides access to the :ref:`resource` system, - which manages dependencies between components. - results : ResultsInterface - Provides access to the :ref:`results` system. - randomness : RandomnessInterface - Provides access to the :ref:`randomness` system. - time : TimeInterface - Provides access to the simulation's :ref:`clock`. - components : ComponentInterface - Provides access to the :ref:`component management` - system, which maintains a reference to all managers and components in - the simulation. - lifecycle : LifeCycleInterface - Provides access to the :ref:`life-cycle` system, - which manages the simulation's execution life-cycle. - data : ArtifactInterface - Provides access to the simulation's input data housed in the - :ref:`data artifact`. - Notes ----- A `Builder` should never be created directly. It will automatically be @@ -405,37 +366,61 @@ class Builder: """ - def __init__(self, configuration, plugin_manager): + def __init__(self, configuration: LayeredConfigTree, plugin_manager): self.configuration = configuration + """Provides access to the :ref:`configuration`""" + + self.logging: LoggingInterface = plugin_manager.get_plugin_interface("logging") + """Provides access to the :ref:`logging` system.""" + + self.lookup: LookupTableInterface = plugin_manager.get_plugin_interface("lookup") + """Provides access to simulant-specific data via the + :ref:`lookup table` abstraction.""" + + self.value: ValuesInterface = plugin_manager.get_plugin_interface("value") + """Provides access to computed simulant attribute values via the + :ref:`value pipeline` system.""" - self.logging = plugin_manager.get_plugin_interface( - "logging" - ) # type: LoggingInterface - self.lookup = plugin_manager.get_plugin_interface( - "lookup" - ) # type: LookupTableInterface - self.value = plugin_manager.get_plugin_interface("value") # type: ValuesInterface - self.event = plugin_manager.get_plugin_interface("event") # type: EventInterface - 
self.population = plugin_manager.get_plugin_interface( + self.event: EventInterface = plugin_manager.get_plugin_interface("event") + """Provides access to event listeners utilized in the + :ref:`event` system.""" + + self.population: PopulationInterface = plugin_manager.get_plugin_interface( "population" - ) # type: PopulationInterface - self.resources = plugin_manager.get_plugin_interface( - "resource" - ) # type: ResourceInterface - self.results = plugin_manager.get_plugin_interface( - "results" - ) # type: ResultsInterface - self.randomness = plugin_manager.get_plugin_interface( + ) + """Provides access to simulant state table via the + :ref:`population` system.""" + + self.resources: ResourceInterface = plugin_manager.get_plugin_interface("resource") + """Provides access to the :ref:`resource` system, + which manages dependencies between components. + """ + + self.results: ResultsInterface = plugin_manager.get_plugin_interface("results") + """Provides access to the :ref:`results` system.""" + + self.randomness: RandomnessInterface = plugin_manager.get_plugin_interface( "randomness" - ) # type: RandomnessInterface - self.time = plugin_manager.get_plugin_interface("clock") # type: TimeInterface - self.components = plugin_manager.get_plugin_interface( + ) + """Provides access to the :ref:`randomness` system.""" + + self.time: TimeInterface = plugin_manager.get_plugin_interface("clock") + """Provides access to the simulation's :ref:`clock`.""" + + self.components: ComponentInterface = plugin_manager.get_plugin_interface( "component_manager" - ) # type: ComponentInterface - self.lifecycle = plugin_manager.get_plugin_interface( - "lifecycle" - ) # type: LifeCycleInterface + ) + """Provides access to the :ref:`component management` + system, which maintains a reference to all managers and components in + the simulation.""" + + self.lifecycle: LifeCycleInterface = plugin_manager.get_plugin_interface("lifecycle") + """Provides access to the :ref:`life-cycle` system, + 
which manages the simulation's execution life-cycle.""" + self.data = plugin_manager.get_plugin_interface("data") # type: ArtifactInterface + """Provides access to the simulation's input data housed in the + :ref:`data artifact`.""" for name, interface in plugin_manager.get_optional_interfaces().items(): setattr(self, name, interface) diff --git a/src/vivarium/framework/resource/__init__.py b/src/vivarium/framework/resource/__init__.py new file mode 100644 index 00000000..d15753b5 --- /dev/null +++ b/src/vivarium/framework/resource/__init__.py @@ -0,0 +1,23 @@ +""" +=================== +Resource Management +=================== + +This module provides a tool to manage dependencies on resources within a +:mod:`vivarium` simulation. These resources take the form of things that can +be created and utilized by components, for example columns in the +:mod:`state table ` +or :mod:`named value pipelines `. + +Because these resources need to be created before they can be used, they are +sensitive to ordering. The intent behind this tool is to provide an interface +that allows other managers to register resources with the resource manager +and in turn ask for ordered sequences of these resources according to their +dependencies or raise exceptions if this is not possible. + +For more information, see the Resource Management +:ref:`concept note`. 
+ +""" + +from vivarium.framework.resource.manager import ResourceInterface, ResourceManager diff --git a/src/vivarium/framework/resource/exceptions.py b/src/vivarium/framework/resource/exceptions.py new file mode 100644 index 00000000..81947cac --- /dev/null +++ b/src/vivarium/framework/resource/exceptions.py @@ -0,0 +1,7 @@ +from vivarium.exceptions import VivariumError + + +class ResourceError(VivariumError): + """Error raised when a dependency requirement is violated.""" + + pass diff --git a/src/vivarium/framework/resource/group.py b/src/vivarium/framework/resource/group.py new file mode 100644 index 00000000..5e0218c2 --- /dev/null +++ b/src/vivarium/framework/resource/group.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from collections.abc import Iterator +from typing import Any + + +class ResourceGroup: + """Resource groups are the nodes in the resource dependency graph. + + A resource group represents the pool of resources produced by a single + callable and all the dependencies necessary to produce that resource. + When thinking of the dependency graph, this represents a vertex and + all in-edges. This is a local-information representation that can be + used to construct the entire dependency graph once all resources are + specified. + + """ + + def __init__( + self, + resource_type: str, + resource_names: list[str], + producer: Any, + dependencies: list[str], + ): + self._resource_type = resource_type + self._resource_names = resource_names + self._producer = producer + self._dependencies = dependencies + + @property + def type(self) -> str: + """The type of resource produced by this resource group's producer. + + Must be one of `RESOURCE_TYPES`. 
+ """ + return self._resource_type + + @property + def names(self) -> list[str]: + """The long names (including type) of all resources in this group.""" + return [f"{self._resource_type}.{name}" for name in self._resource_names] + + @property + def producer(self) -> Any: + """The method or object that produces this group of resources.""" + return self._producer + + @property + def dependencies(self) -> list[str]: + """The long names (including type) of dependencies for this group.""" + return self._dependencies + + def __iter__(self) -> Iterator[str]: + return iter(self.names) + + def __repr__(self) -> str: + resources = ", ".join(self) + return f"ResourceProducer({resources})" + + def __str__(self) -> str: + resources = ", ".join(self) + return f"({resources})" diff --git a/src/vivarium/framework/resource.py b/src/vivarium/framework/resource/manager.py similarity index 78% rename from src/vivarium/framework/resource.py rename to src/vivarium/framework/resource/manager.py index 641de223..d7d6f0c5 100644 --- a/src/vivarium/framework/resource.py +++ b/src/vivarium/framework/resource/manager.py @@ -1,19 +1,7 @@ """ -=================== -Resource Management -=================== - -This module provides a tool to manage dependencies on resources within a -:mod:`vivarium` simulation. These resources take the form of things that can -be created and utilized by components, for example columns in the -:mod:`state table ` -or :mod:`named value pipelines `. - -Because these resources need to be created before they can be used, they are -sensitive to ordering. The intent behind this tool is to provide an interface -that allows other managers to register resources with the resource manager -and in turn ask for ordered sequences of these resources according to their -dependencies or raise exceptions if this is not possible. 
+================ +Resource Manager +================ """ @@ -24,19 +12,14 @@ import networkx as nx -from vivarium.exceptions import VivariumError +from vivarium.framework.resource.exceptions import ResourceError +from vivarium.framework.resource.group import ResourceGroup from vivarium.manager import Interface, Manager if TYPE_CHECKING: from vivarium.framework.engine import Builder -class ResourceError(VivariumError): - """Error raised when a dependency requirement is violated.""" - - pass - - RESOURCE_TYPES = { "value", "value_source", @@ -48,65 +31,6 @@ class ResourceError(VivariumError): NULL_RESOURCE_TYPE = "null" -class ResourceGroup: - """Resource groups are the nodes in the resource dependency graph. - - A resource group represents the pool of resources produced by a single - callable and all the dependencies necessary to produce that resource. - When thinking of the dependency graph, this represents a vertex and - all in-edges. This is a local-information representation that can be - used to construct the entire dependency graph once all resources are - specified. - - """ - - def __init__( - self, - resource_type: str, - resource_names: list[str], - producer: Any, - dependencies: list[str], - ): - self._resource_type = resource_type - self._resource_names = resource_names - self._producer = producer - self._dependencies = dependencies - - @property - def type(self) -> str: - """The type of resource produced by this resource group's producer. - - Must be one of `RESOURCE_TYPES`. 
- """ - return self._resource_type - - @property - def names(self) -> list[str]: - """The long names (including type) of all resources in this group.""" - return [f"{self._resource_type}.{name}" for name in self._resource_names] - - @property - def producer(self) -> Any: - """The method or object that produces this group of resources.""" - return self._producer - - @property - def dependencies(self) -> list[str]: - """The long names (including type) of dependencies for this group.""" - return self._dependencies - - def __iter__(self) -> Iterator[str]: - return iter(self.names) - - def __repr__(self) -> str: - resources = ", ".join(self) - return f"ResourceProducer({resources})" - - def __str__(self) -> str: - resources = ", ".join(self) - return f"({resources})" - - class ResourceManager(Manager): """Manages all the resources needed for population initialization.""" diff --git a/tests/framework/test_resource.py b/tests/framework/test_resource.py index e1c3134e..673abc15 100644 --- a/tests/framework/test_resource.py +++ b/tests/framework/test_resource.py @@ -1,14 +1,10 @@ -import networkx as nx import pytest from vivarium import Component -from vivarium.framework.resource import ( - NULL_RESOURCE_TYPE, - RESOURCE_TYPES, - ResourceError, - ResourceGroup, - ResourceManager, -) +from vivarium.framework.resource import ResourceManager +from vivarium.framework.resource.exceptions import ResourceError +from vivarium.framework.resource.group import ResourceGroup +from vivarium.framework.resource.manager import NULL_RESOURCE_TYPE, RESOURCE_TYPES class ResourceProducer(Component): From b39e89e356064b0b78da6e01bd159fcff642831f Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Wed, 16 Oct 2024 12:57:04 -0700 Subject: [PATCH 06/22] create get_population_initializers method (#504) --- src/vivarium/framework/population/manager.py | 2 +- src/vivarium/framework/resource/manager.py | 21 ++++++++------------ 
tests/framework/components/test_component.py | 10 ++++++++-- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/vivarium/framework/population/manager.py b/src/vivarium/framework/population/manager.py index fac40265..a580b1dd 100644 --- a/src/vivarium/framework/population/manager.py +++ b/src/vivarium/framework/population/manager.py @@ -341,7 +341,7 @@ def _create_simulants( index = new_population.index.difference(self._population.index) self._population = new_population self.adding_simulants = True - for initializer in self.resources: + for initializer in self.resources.get_population_initializers(): initializer( SimulantData(index, population_configuration, self.clock(), self.step_size()) ) diff --git a/src/vivarium/framework/resource/manager.py b/src/vivarium/framework/resource/manager.py index d7d6f0c5..02ce11d0 100644 --- a/src/vivarium/framework/resource/manager.py +++ b/src/vivarium/framework/resource/manager.py @@ -7,7 +7,6 @@ from __future__ import annotations -from collections.abc import Iterator from typing import TYPE_CHECKING, Any import networkx as nx @@ -189,20 +188,16 @@ def _to_graph(self) -> nx.DiGraph: return resource_graph - def __iter__(self) -> Iterator[Any]: - """Returns a dependency-sorted iterable of population initializers. + def get_population_initializers(self) -> list[Any]: + """Returns a dependency-sorted list of population initializers. We exclude all non-initializer dependencies. They were necessary in graph construction, but we only need the column producers at population creation time. 
""" - return iter( - [ - r.producer - for r in self.sorted_nodes - if r.type in {"column", NULL_RESOURCE_TYPE} - ] - ) + return [ + r.producer for r in self.sorted_nodes if r.type in {"column", NULL_RESOURCE_TYPE} + ] def __repr__(self) -> str: out = {} @@ -264,11 +259,11 @@ def add_resources( """ self._manager.add_resources(resource_type, resource_names, producer, dependencies) - def __iter__(self) -> Iterator[Any]: - """Returns a dependency-sorted iterable of population initializers. + def get_population_initializers(self) -> list[Any]: + """Returns a dependency-sorted list of population initializers. We exclude all non-initializer dependencies. They were necessary in graph construction, but we only need the column producers at population creation time. """ - return iter(self._manager) + return self._manager.get_population_initializers() diff --git a/tests/framework/components/test_component.py b/tests/framework/components/test_component.py index 9aa85a33..86ce4df5 100644 --- a/tests/framework/components/test_component.py +++ b/tests/framework/components/test_component.py @@ -163,7 +163,10 @@ def test_component_initializer_is_not_registered_if_not_defined(): simulation = InteractiveContext(components=[component]) # Assert that simulant initializer has been registered - assert component.on_initialize_simulants not in simulation._resource + assert ( + component.on_initialize_simulants + not in simulation._resource.get_population_initializers() + ) def test_component_initializer_is_registered_and_called_if_defined(): @@ -176,7 +179,10 @@ def test_component_initializer_is_registered_and_called_if_defined(): population = simulation.get_population() # Assert that simulant initializer has been registered - assert component.on_initialize_simulants in simulation._resource + assert ( + component.on_initialize_simulants + in simulation._resource.get_population_initializers() + ) # and that created columns are correctly initialized 
pd.testing.assert_frame_equal(population[component.columns_created], expected_pop_view) From a1f7f3c61ff678a372ea0c5863a4a06aba840156 Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Wed, 16 Oct 2024 12:57:48 -0700 Subject: [PATCH 07/22] add name to Pipeline __init__ (#505) --- src/vivarium/framework/values.py | 96 ++++++++++++++++++-------------- tests/framework/test_values.py | 2 +- 2 files changed, 54 insertions(+), 44 deletions(-) diff --git a/src/vivarium/framework/values.py b/src/vivarium/framework/values.py index b2e32401..42e7235c 100644 --- a/src/vivarium/framework/values.py +++ b/src/vivarium/framework/values.py @@ -14,9 +14,9 @@ """ from __future__ import annotations -from collections import defaultdict +from collections.abc import Callable, Iterable from datetime import timedelta -from typing import TYPE_CHECKING, Any, Callable, Iterable, Protocol, TypeVar +from typing import TYPE_CHECKING, Any, Protocol, TypeVar import pandas as pd @@ -118,9 +118,8 @@ def rescale_post_processor(value: NumberLike, manager: ValuesManager) -> NumberL Annual rates, either as a number or something we can broadcast multiplication over like a :mod:`numpy` array or :mod:`pandas` data frame. - time_step - A pandas time delta representing the size of the upcoming time - step. + manager + The ValuesManager for this simulation. Returns ------- @@ -203,24 +202,18 @@ class Pipeline: values that won't be used in the particular simulation. """ - def __init__(self) -> None: - """ - Parameters - ---------- - name - The name of the value represented by this pipeline. - mutators - A list of callables that directly modify the pipeline source or - contribute portions of the value. - post_processor - An optional final transformation to perform on the combined output of - the source and mutators. 
- """ - self.name: str | None = None + def __init__(self, name: str) -> None: + self.name: str = name + """The name of the value represented by this pipeline.""" self.source: Callable[..., Any] | None = None + """The callable source of the value represented by the pipeline.""" self.mutators: list[Callable[..., Any]] = [] + """A list of callables that directly modify the pipeline source or + contribute portions of the value.""" self._combiner: ValueCombiner | None = None self.post_processor: PostProcessor | None = None + """An optional final transformation to perform on the combined output of + the source and mutators.""" self._manager: ValuesManager | None = None def _get_attr_error(self, attribute: str) -> str: @@ -241,32 +234,17 @@ def _get_property(self, property: T | None, property_name: str) -> T: raise DynamicValueError(self._get_attr_error(property_name)) return property - def _set_property(self, property_name: str, new_value: Any) -> None: - private_name = f"_{property_name}" - old_value = getattr(self, private_name) - if old_value is not None: - raise DynamicValueError(self._set_attr_error(property_name, new_value)) - setattr(self, private_name, new_value) - @property def combiner(self) -> ValueCombiner: """A strategy for combining the source and mutator values into the final value represented by the pipeline.""" return self._get_property(self._combiner, "combiner") - @combiner.setter - def combiner(self, combiner: ValueCombiner) -> None: - self._set_property("combiner", combiner) - @property def manager(self) -> ValuesManager: """A reference to the simulation values manager.""" return self._get_property(self._manager, "manager") - @manager.setter - def manager(self, manager: ValuesManager) -> None: - self._set_property("manager", manager) - def __call__(self, *args: Any, skip_post_processor: bool = False, **kwargs: Any) -> Any: """Generates the value represented by this pipeline. 
@@ -311,13 +289,45 @@ def _call(self, *args: Any, skip_post_processor: bool = False, **kwargs: Any) -> def __repr__(self) -> str: return f"_Pipeline({self.name})" + @classmethod + def setup_pipeline( + cls, + pipeline: Pipeline, + source: Callable[..., Any], + combiner: ValueCombiner, + post_processor: PostProcessor | None, + manager: ValuesManager, + ) -> None: + """ + Add a source, combiner, and post-processor to a pipeline. + + Parameters + ---------- + pipeline + The pipeline to configure. + source + The callable source of the value represented by the pipeline. + combiner + A strategy for combining the source and mutator values into the + final value represented by the pipeline. + post_processor + An optional final transformation to perform on the combined output + of the source and mutators. + manager + The simulation values manager. + """ + pipeline.source = source + pipeline._combiner = combiner + pipeline.post_processor = post_processor + pipeline._manager = manager + class ValuesManager(Manager): """Manager for the dynamic value system.""" def __init__(self) -> None: # Pipelines are lazily initialized by _register_value_producer - self._pipelines: dict[str, Pipeline] = defaultdict(Pipeline) + self._pipelines: dict[str, Pipeline] = {} @property def name(self) -> str: @@ -403,12 +413,10 @@ def _register_value_producer( ) -> Pipeline: """Configure the named value pipeline with a source, combiner, and post-processor.""" self.logger.debug(f"Registering value pipeline {value_name}") - pipeline = self._pipelines[value_name] - pipeline.name = value_name - pipeline.source = source - pipeline.combiner = preferred_combiner - pipeline.post_processor = preferred_post_processor - pipeline.manager = self + pipeline = self.get_value(value_name) + Pipeline.setup_pipeline( + pipeline, source, preferred_combiner, preferred_post_processor, self + ) return pipeline def register_value_modifier( @@ -446,7 +454,7 @@ def register_value_modifier( """ modifier_name = 
self._get_modifier_name(modifier) - pipeline = self._pipelines[value_name] # May create a pipeline + pipeline = self.get_value(value_name) pipeline.mutators.append(modifier) name = f"{value_name}.{len(pipeline.mutators)}.{modifier_name}" @@ -471,7 +479,9 @@ def get_value(self, name: str) -> Pipeline: (frequently just a :class:`pandas.Index` representing the simulants). """ - return self._pipelines[name] # May create a pipeline. + pipeline = self._pipelines.get(name) or Pipeline(name) + self._pipelines[name] = pipeline + return pipeline @staticmethod def _convert_dependencies( diff --git a/tests/framework/test_values.py b/tests/framework/test_values.py index 78faef54..f4276b0a 100644 --- a/tests/framework/test_values.py +++ b/tests/framework/test_values.py @@ -123,7 +123,7 @@ def test_rescale_post_processor_variable(manager_with_step_size): def test_unsourced_pipeline(): - pipeline = Pipeline() + pipeline = Pipeline("some_name") assert pipeline.source is None with pytest.raises( DynamicValueError, From 53d1e319b852baadfe84d23f59427616a5c13f35 Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Wed, 16 Oct 2024 13:05:11 -0700 Subject: [PATCH 08/22] enable registration of resources using list of entities (#506) --- src/vivarium/component.py | 52 ++++++----- src/vivarium/examples/disease_model/risk.py | 18 ++-- src/vivarium/framework/population/manager.py | 54 +++++++++++- src/vivarium/framework/state_machine.py | 14 +-- src/vivarium/framework/values.py | 92 ++++++++++++++++---- tests/framework/test_state_machine.py | 8 +- tests/helpers.py | 15 ++-- 7 files changed, 181 insertions(+), 72 deletions(-) diff --git a/src/vivarium/component.py b/src/vivarium/component.py index 1831f516..a773986d 100644 --- a/src/vivarium/component.py +++ b/src/vivarium/component.py @@ -11,6 +11,7 @@ from __future__ import annotations import re +import warnings from abc import ABC from collections.abc import Sequence from datetime import datetime, 
timedelta @@ -24,14 +25,16 @@ from loguru._logger import Logger from vivarium.framework.artifact import ArtifactException -from vivarium.framework.event import Event -from vivarium.framework.lookup import LookupTable -from vivarium.framework.population import PopulationError, PopulationView -from vivarium.types import LookupTableData +from vivarium.framework.population import PopulationError if TYPE_CHECKING: from vivarium.framework.engine import Builder - from vivarium.framework.population import SimulantData + from vivarium.framework.event import Event + from vivarium.framework.lookup import LookupTable + from vivarium.framework.population import PopulationView, SimulantData + from vivarium.framework.randomness import RandomnessStream + from vivarium.framework.values import Pipeline + from vivarium.types import LookupTableData DEFAULT_EVENT_PRIORITY = 5 """The default priority at which events will be triggered.""" @@ -234,22 +237,12 @@ def columns_required(self) -> Optional[List[str]]: return None @property - def initialization_requirements(self) -> Dict[str, List[str]]: - """Provides the names of all values required by this component during - simulant initialization. - - Returns - ------- - A dictionary containing the additional requirements of this - component during simulant initialization. An omitted key or an empty - list for a key implies no requirements for that key during - initialization. 
- """ - return { - "requires_columns": [], - "requires_values": [], - "requires_streams": [], - } + def initialization_requirements( + self, + ) -> list[str | Pipeline | RandomnessStream]: + """A list containing the columns, pipelines, and randomness streams + required by this component's simulant initializer.""" + return [] @property def population_view_query(self) -> Optional[str]: @@ -785,7 +778,7 @@ def _register_post_setup_listener(self, builder: "Builder") -> None: self.post_setup_priority, ) - def _register_simulant_initializer(self, builder: "Builder") -> None: + def _register_simulant_initializer(self, builder: Builder) -> None: """Registers a simulant initializer if this component has defined one. This method allows the component to initialize simulants if it has its @@ -798,11 +791,24 @@ def _register_simulant_initializer(self, builder: "Builder") -> None: builder The builder with which to register the initializer. """ + if isinstance(self.initialization_requirements, list): + initialization_requirements = { + "required_resources": self.initialization_requirements + } + else: + initialization_requirements = self.initialization_requirements + warnings.warn( + "The dict format for initialization_requirements is deprecated." 
+ " You should use provide a list of the required resources.", + DeprecationWarning, + stacklevel=2, + ) + if type(self).on_initialize_simulants != Component.on_initialize_simulants: builder.population.initializes_simulants( self.on_initialize_simulants, creates_columns=self.columns_created, - **self.initialization_requirements, + **initialization_requirements, ) def _register_time_step_prepare_listener(self, builder: "Builder") -> None: diff --git a/src/vivarium/examples/disease_model/risk.py b/src/vivarium/examples/disease_model/risk.py index a8ed931b..835e77df 100644 --- a/src/vivarium/examples/disease_model/risk.py +++ b/src/vivarium/examples/disease_model/risk.py @@ -1,10 +1,16 @@ # mypy: ignore-errors -from typing import Any, Dict, List +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List import pandas as pd from vivarium import Component -from vivarium.framework.engine import Builder + +if TYPE_CHECKING: + from vivarium.framework.engine import Builder + from vivarium.framework.randomness import RandomnessStream + from vivarium.framework.values import Pipeline class Risk(Component): @@ -27,12 +33,8 @@ def columns_created(self) -> List[str]: return [f"{self.risk}_propensity"] @property - def initialization_requirements(self) -> Dict[str, List[str]]: - return { - "requires_columns": [], - "requires_values": [], - "requires_streams": [self.risk], - } + def initialization_requirements(self) -> list[str | Pipeline | RandomnessStream]: + return [self.randomness] ##################### # Lifecycle methods # diff --git a/src/vivarium/framework/population/manager.py b/src/vivarium/framework/population/manager.py index a580b1dd..7b69b459 100644 --- a/src/vivarium/framework/population/manager.py +++ b/src/vivarium/framework/population/manager.py @@ -18,6 +18,8 @@ from vivarium.framework.population.exceptions import PopulationError from vivarium.framework.population.population_view import PopulationView +from 
vivarium.framework.randomness import RandomnessStream +from vivarium.framework.values import Pipeline from vivarium.manager import Interface, Manager from vivarium.types import ClockStepSize, ClockTime @@ -263,6 +265,7 @@ def register_simulant_initializer( requires_columns: str | Sequence[str] = (), requires_values: str | Sequence[str] = (), requires_streams: str | Sequence[str] = (), + required_resources: Sequence[str | Pipeline | RandomnessStream] = (), ) -> None: """Marks a source of initial state information for new simulants. @@ -284,9 +287,21 @@ def register_simulant_initializer( requires_streams The randomness streams necessary to initialize the simulant attributes. + required_resources + The resources that the initializer requires to run. Strings are + interpreted as column names, and Pipelines and RandomnessStreams + are interpreted as value pipelines and randomness streams, + respectively. """ - if isinstance(creates_columns, str): - creates_columns = [creates_columns] + + has_individual_requires = requires_columns or requires_values or requires_streams + + if has_individual_requires and required_resources: + raise ValueError( + "If requires_columns, requires_values, or requires_streams are provided, " + "requirements must be empty." 
+ ) + if isinstance(requires_columns, str): requires_columns = [requires_columns] if isinstance(requires_values, str): @@ -294,12 +309,33 @@ def register_simulant_initializer( if isinstance(requires_streams, str): requires_streams = [requires_streams] - self._initializer_components.add(initializer, list(creates_columns)) + if required_resources: + requires_columns = [] + requires_values = [] + requires_streams = [] + for required_resource in required_resources: + if isinstance(required_resource, str): + requires_columns.append(required_resource) + elif isinstance(required_resource, Pipeline): + requires_values.append(required_resource.name) + elif isinstance(required_resource, RandomnessStream): + requires_streams.append(required_resource.key) + else: + raise TypeError( + "requirements must be a sequence of strings, Pipelines," + f" and RandomnessStreams. Provided: '{type(required_resource)}'." + ) + dependencies = ( [f"column.{name}" for name in requires_columns] + [f"value.{name}" for name in requires_values] + [f"stream.{name}" for name in requires_streams] ) + + if isinstance(creates_columns, str): + creates_columns = [creates_columns] + self._initializer_components.add(initializer, list(creates_columns)) + if "tracked" not in creates_columns: # The population view itself uses the tracked column, so include # to be safe. @@ -455,6 +491,7 @@ def initializes_simulants( requires_columns: str | Sequence[str] = (), requires_values: str | Sequence[str] = (), requires_streams: str | Sequence[str] = (), + required_resources: Sequence[str | Pipeline | RandomnessStream] = (), ) -> None: """Marks a source of initial state information for new simulants. @@ -476,7 +513,16 @@ def initializes_simulants( requires_streams The randomness streams necessary to initialize the simulant attributes. + required_resources + The resources that the initializer requires to run. 
Strings are + interpreted as column names, and Pipelines and RandomnessStreams + are interpreted as value pipelines and randomness streams, """ self._manager.register_simulant_initializer( - initializer, creates_columns, requires_columns, requires_values, requires_streams + initializer, + creates_columns, + requires_columns, + requires_values, + requires_streams, + required_resources, ) diff --git a/src/vivarium/framework/state_machine.py b/src/vivarium/framework/state_machine.py index f18371aa..e91b59f7 100644 --- a/src/vivarium/framework/state_machine.py +++ b/src/vivarium/framework/state_machine.py @@ -17,11 +17,13 @@ import pandas as pd from vivarium import Component -from vivarium.framework.event import Event if TYPE_CHECKING: from vivarium.framework.engine import Builder + from vivarium.framework.event import Event from vivarium.framework.population import PopulationView, SimulantData + from vivarium.framework.randomness import RandomnessStream + from vivarium.framework.values import Pipeline from vivarium.types import ClockTime, LookupTableData @@ -492,12 +494,10 @@ def columns_created(self) -> List[str]: return [self.state_column] @property - def initialization_requirements(self) -> Dict[str, List[str]]: - return { - "requires_columns": [], - "requires_values": [], - "requires_streams": [self.randomness.key], - } + def initialization_requirements( + self, + ) -> list[str | Pipeline | RandomnessStream]: + return [self.randomness] ##################### # Lifecycle methods # diff --git a/src/vivarium/framework/values.py b/src/vivarium/framework/values.py index 42e7235c..56c985e9 100644 --- a/src/vivarium/framework/values.py +++ b/src/vivarium/framework/values.py @@ -5,7 +5,7 @@ The value pipeline system is a vital part of the :mod:`vivarium` infrastructure. It allows for values that determine the behavior of individual -:term:`simulants ` to be constructed across across multiple +:term:`simulants ` to be constructed across multiple :ref:`components `. 
For more information about when and how you should use pipelines in your @@ -14,7 +14,8 @@ """ from __future__ import annotations -from collections.abc import Callable, Iterable +import warnings +from collections.abc import Callable, Iterable, Sequence from datetime import timedelta from typing import TYPE_CHECKING, Any, Protocol, TypeVar @@ -22,6 +23,7 @@ from vivarium.exceptions import VivariumError from vivarium.framework.event import Event +from vivarium.framework.randomness import RandomnessStream from vivarium.framework.utilities import from_yearly from vivarium.manager import Interface, Manager from vivarium.types import NumberLike @@ -377,6 +379,7 @@ def register_value_producer( requires_columns: Iterable[str] = (), requires_values: Iterable[str] = (), requires_streams: Iterable[str] = (), + required_resources: Sequence[str | Pipeline | RandomnessStream] = (), preferred_combiner: ValueCombiner = replace_combiner, preferred_post_processor: PostProcessor | None = None, ) -> Pipeline: @@ -395,7 +398,7 @@ def register_value_producer( # declare that resource at post-setup once all sources and modifiers # are registered. dependencies = self._convert_dependencies( - source, requires_columns, requires_values, requires_streams + source, requires_columns, requires_values, requires_streams, required_resources ) self.resources.add_resources("value_source", [value_name], source, dependencies) self.add_constraint( @@ -426,6 +429,7 @@ def register_value_modifier( requires_columns: Iterable[str] = (), requires_values: Iterable[str] = (), requires_streams: Iterable[str] = (), + required_resources: Sequence[str | Pipeline | RandomnessStream] = (), ) -> None: """Marks a ``Callable`` as the modifier of a named value. @@ -451,6 +455,10 @@ def register_value_modifier( requires_streams A list of the randomness streams that need to be properly sourced before the pipeline modifier is called. 
+ required_resources + A list of resources that need to be properly sourced before the + pipeline modifier is called. This is a list of strings, pipeline + names, or randomness streams. """ modifier_name = self._get_modifier_name(modifier) @@ -460,7 +468,7 @@ def register_value_modifier( name = f"{value_name}.{len(pipeline.mutators)}.{modifier_name}" self.logger.debug(f"Registering {name} as modifier to {value_name}") dependencies = self._convert_dependencies( - modifier, requires_columns, requires_values, requires_streams + modifier, requires_columns, requires_values, requires_streams, required_resources ) self.resources.add_resources("value_modifier", [name], modifier, dependencies) @@ -483,26 +491,54 @@ def get_value(self, name: str) -> Pipeline: self._pipelines[name] = pipeline return pipeline - @staticmethod def _convert_dependencies( + self, func: Callable[..., Any], requires_columns: Iterable[str], requires_values: Iterable[str], requires_streams: Iterable[str], + required_resources: Iterable[str | Pipeline | RandomnessStream], ) -> list[str]: - # If declaring a pipeline as a value source or modifier, columns and - # streams are optional since the pipeline itself will have all the - # appropriate dependencies. In any situation, make sure we don't have - # provide the pipeline function to source/modifier as well as - # explicitly stating the pipeline name in 'requires_values'. if isinstance(func, Pipeline): - dependencies = [f"value.{func.name}"] - else: - dependencies = ( - [f"column.{name}" for name in requires_columns] - + [f"value.{name}" for name in requires_values] - + [f"stream.{name}" for name in requires_streams] + # The dependencies of the pipeline itself will have been declared + # when the pipeline was registered. + return [f"value.{func.name}"] + + if requires_columns or requires_values or requires_streams: + warnings.warn( + "Specifying requirements individually is deprecated. 
You should " + "specify them using the 'required_resources' argument instead.", + DeprecationWarning, + stacklevel=2, ) + if required_resources: + raise ValueError( + "If requires_columns, requires_values, or requires_streams" + " are provided, requirements must be empty." + ) + + if required_resources: + requires_columns = [] + requires_values = [] + requires_streams = [] + for required_resource in required_resources: + if isinstance(required_resource, str): + requires_columns.append(required_resource) + elif isinstance(required_resource, Pipeline): + requires_values.append(required_resource.name) + elif isinstance(required_resource, RandomnessStream): + requires_streams.append(required_resource.key) + else: + raise TypeError( + "requirements must be a sequence of strings, Pipelines," + f" and RandomnessStreams. Provided: '{type(required_resource)}'." + ) + + dependencies = ( + [f"column.{name}" for name in requires_columns] + + [f"value.{name}" for name in requires_values] + + [f"stream.{name}" for name in requires_streams] + ) return dependencies @staticmethod @@ -565,6 +601,7 @@ def register_value_producer( requires_columns: Iterable[str] = (), requires_values: Iterable[str] = (), requires_streams: Iterable[str] = (), + required_resources: Sequence[str | Pipeline | RandomnessStream] = (), preferred_combiner: ValueCombiner = replace_combiner, preferred_post_processor: PostProcessor | None = None, ) -> Pipeline: @@ -586,6 +623,10 @@ def register_value_producer( requires_streams A list of the randomness streams that need to be properly sourced before the pipeline source is called. + required_resources + A list of resources that need to be properly sourced before the + pipeline source is called. This is a list of strings, pipeline + names, or randomness streams. preferred_combiner A strategy for combining the source and the results of any calls to mutators in the pipeline. 
``vivarium`` provides the strategies @@ -609,6 +650,7 @@ def register_value_producer( requires_columns, requires_values, requires_streams, + required_resources, preferred_combiner, preferred_post_processor, ) @@ -620,6 +662,7 @@ def register_rate_producer( requires_columns: Iterable[str] = (), requires_values: Iterable[str] = (), requires_streams: Iterable[str] = (), + required_resources: Sequence[str | Pipeline | RandomnessStream] = (), ) -> Pipeline: """Marks a ``Callable`` as the producer of a named rate. @@ -646,6 +689,10 @@ def register_rate_producer( requires_streams A list of the randomness streams that need to be properly sourced before the pipeline source is called. + required_resources + A list of resources that need to be properly sourced before the + pipeline source is called. This is a list of strings, pipeline + names, or randomness streams. Returns ------- @@ -657,6 +704,7 @@ def register_rate_producer( requires_columns, requires_values, requires_streams, + required_resources, preferred_post_processor=rescale_post_processor, ) @@ -667,6 +715,7 @@ def register_value_modifier( requires_columns: Iterable[str] = (), requires_values: Iterable[str] = (), requires_streams: Iterable[str] = (), + required_resources: Sequence[str | Pipeline | RandomnessStream] = (), ) -> None: """Marks a ``Callable`` as the modifier of a named value. @@ -692,9 +741,18 @@ def register_value_modifier( requires_streams A list of the randomness streams that need to be properly sourced before the pipeline modifier is called. + required_resources + A list of resources that need to be properly sourced before the + pipeline modifier is called. This is a list of strings, pipeline + names, or randomness streams. 
""" self._manager.register_value_modifier( - value_name, modifier, requires_columns, requires_values, requires_streams + value_name, + modifier, + requires_columns, + requires_values, + requires_streams, + required_resources, ) def get_value(self, name: str) -> Pipeline: diff --git a/tests/framework/test_state_machine.py b/tests/framework/test_state_machine.py index eec4fa36..bd9906f3 100644 --- a/tests/framework/test_state_machine.py +++ b/tests/framework/test_state_machine.py @@ -71,15 +71,11 @@ def mock_load(key: str) -> pd.DataFrame: class TestMachine(Machine): @property - def initialization_requirements(self) -> dict[str, list[str]]: + def initialization_requirements(self) -> list[str | Pipeline | RandomnessStream]: # FIXME - MIC-5408: We shouldn't need to specify the columns in the # lookup tables here, since the component can't know what will be # specified by the states or the configuration. - return { - "requires_columns": ["test_column_1"], - "requires_values": [], - "requires_streams": [], - } + return ["test_column_1"] def initialization_weights(key: str): if use_artifact: diff --git a/tests/helpers.py b/tests/helpers.py index 7fafa88a..b79bf622 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -8,6 +8,8 @@ from vivarium.framework.engine import Builder from vivarium.framework.event import Event from vivarium.framework.population import SimulantData +from vivarium.framework.randomness import RandomnessStream +from vivarium.framework.values import Pipeline class MockComponentA(Observer): @@ -145,7 +147,6 @@ def columns_created(self) -> List[str]: def setup(self, builder: Builder) -> None: builder.value.register_value_producer("pipeline_1", lambda x: x) - builder.randomness.get_stream("stream_1") def on_initialize_simulants(self, pop_data: SimulantData) -> None: self.population_view.update(self.get_initial_state(pop_data.index)) @@ -245,12 +246,12 @@ def columns_created(self) -> List[str]: return ["test_column_4"] @property - def 
initialization_requirements(self) -> Dict[str, List[str]]: - return { - "requires_columns": ["test_column_2"], - "requires_values": ["pipeline_1"], - "requires_streams": ["stream_1"], - } + def initialization_requirements(self) -> list[str | Pipeline | RandomnessStream]: + return ["test_column_2", self.pipeline, self.randomness] + + def setup(self, builder: Builder) -> None: + self.pipeline = builder.value.get_value("pipeline_1") + self.randomness = builder.randomness.get_stream("stream_1") def on_initialize_simulants(self, pop_data: SimulantData) -> None: initialization_data = pd.DataFrame({"test_column_4": 8}, index=pop_data.index) From c8b9201202b067e9440aee779dbdc50eb32a7712 Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Thu, 17 Oct 2024 10:35:04 -0700 Subject: [PATCH 09/22] specify required resources in example models (#507) --- src/vivarium/examples/boids/forces.py | 3 +- src/vivarium/examples/boids/neighbors.py | 2 +- .../examples/disease_model/disease.py | 33 +++++++++++-------- .../examples/disease_model/intervention.py | 8 +++-- src/vivarium/examples/disease_model/risk.py | 25 +++++++++----- 5 files changed, 45 insertions(+), 26 deletions(-) diff --git a/src/vivarium/examples/boids/forces.py b/src/vivarium/examples/boids/forces.py index 45b48a23..cc71e1f0 100644 --- a/src/vivarium/examples/boids/forces.py +++ b/src/vivarium/examples/boids/forces.py @@ -21,7 +21,7 @@ def configuration_defaults(self) -> Dict[str, Any]: }, } - columns_required = ["x", "y", "vx", "vy"] + columns_required = [] ##################### # Lifecycle methods # @@ -36,6 +36,7 @@ def setup(self, builder: Builder) -> None: builder.value.register_value_modifier( "acceleration", modifier=self.apply_force, + required_resources=self.columns_required + [self.neighbors], ) ################################## diff --git a/src/vivarium/examples/boids/neighbors.py b/src/vivarium/examples/boids/neighbors.py index e79bee17..9004cc11 100644 --- 
a/src/vivarium/examples/boids/neighbors.py +++ b/src/vivarium/examples/boids/neighbors.py @@ -26,7 +26,7 @@ def setup(self, builder: Builder) -> None: self.neighbors_calculated = False self._neighbors = pd.Series() self.neighbors = builder.value.register_value_producer( - "neighbors", source=self.get_neighbors + "neighbors", source=self.get_neighbors, required_resources=self.columns_required ) ######################## diff --git a/src/vivarium/examples/disease_model/disease.py b/src/vivarium/examples/disease_model/disease.py index 3e62bdff..770619e4 100644 --- a/src/vivarium/examples/disease_model/disease.py +++ b/src/vivarium/examples/disease_model/disease.py @@ -36,15 +36,17 @@ def setup(self, builder: Builder) -> None: rate = builder.configuration[self.cause_key][self.measure] self.base_rate = lambda index: pd.Series(rate, index=index) - self.transition_rate = builder.value.register_rate_producer( - self.rate_name, source=self._risk_deleted_rate - ) self.joint_population_attributable_fraction = builder.value.register_value_producer( f"{self.rate_name}.population_attributable_fraction", source=lambda index: [pd.Series(0.0, index=index)], preferred_combiner=list_combiner, preferred_post_processor=union_post_processor, ) + self.transition_rate = builder.value.register_rate_producer( + self.rate_name, + source=self._risk_deleted_rate, + required_resources=[self.joint_population_attributable_fraction], + ) ################################## # Pipeline sources and modifiers # @@ -104,11 +106,6 @@ def setup(self, builder: Builder): self._excess_mortality_rate = 0 self.clock = builder.time.clock() - - self.excess_mortality_rate = builder.value.register_rate_producer( - f"{self.state_id}.excess_mortality_rate", - source=self.risk_deleted_excess_mortality_rate, - ) self.excess_mortality_rate_paf = builder.value.register_value_producer( f"{self.state_id}.excess_mortality_rate.population_attributable_fraction", source=lambda index: [pd.Series(0.0, index=index)], @@ -116,7 
+113,17 @@ def setup(self, builder: Builder): preferred_post_processor=union_post_processor, ) - builder.value.register_value_modifier("mortality_rate", self.add_in_excess_mortality) + self.excess_mortality_rate = builder.value.register_rate_producer( + f"{self.state_id}.excess_mortality_rate", + source=self.risk_deleted_excess_mortality_rate, + required_resources=[self.excess_mortality_rate_paf], + ) + + builder.value.register_value_modifier( + "mortality_rate", + self.add_in_excess_mortality, + required_resources=[self.excess_mortality_rate] + ) ################## # Public methods # @@ -141,9 +148,7 @@ def risk_deleted_excess_mortality_rate(self, index: pd.Index) -> pd.Series: def add_in_excess_mortality( self, index: pd.Index, mortality_rates: pd.Series ) -> pd.Series: - affected = self.population_view.get(index) - mortality_rates.loc[affected.index] += self.excess_mortality_rate(affected.index) - + mortality_rates.loc[index] += self.excess_mortality_rate(index) return mortality_rates @@ -168,7 +173,9 @@ def setup(self, builder: Builder) -> None: source=lambda index: pd.Series(cause_specific_mortality_rate, index=index), ) builder.value.register_value_modifier( - "mortality_rate", modifier=self.delete_cause_specific_mortality + "mortality_rate", + modifier=self.delete_cause_specific_mortality, + required_resources=[self.cause_specific_mortality_rate], ) ################################## diff --git a/src/vivarium/examples/disease_model/intervention.py b/src/vivarium/examples/disease_model/intervention.py index 4f7a7cc1..f09cc593 100644 --- a/src/vivarium/examples/disease_model/intervention.py +++ b/src/vivarium/examples/disease_model/intervention.py @@ -34,13 +34,15 @@ def __init__(self, intervention: str, affected_value: str): # noinspection PyAttributeOutsideInit def setup(self, builder: Builder) -> None: effect_size = builder.configuration[self.intervention].effect_size - builder.value.register_value_modifier( - self.affected_value, 
modifier=self.intervention_effect - ) self.effect_size = builder.value.register_value_producer( f"{self.intervention}.effect_size", source=lambda index: pd.Series(effect_size, index=index), ) + builder.value.register_value_modifier( + self.affected_value, + modifier=self.intervention_effect, + required_resources=[self.effect_size], + ) ################################## # Pipeline sources and modifiers # diff --git a/src/vivarium/examples/disease_model/risk.py b/src/vivarium/examples/disease_model/risk.py index 835e77df..419efc6e 100644 --- a/src/vivarium/examples/disease_model/risk.py +++ b/src/vivarium/examples/disease_model/risk.py @@ -30,7 +30,7 @@ def configuration_defaults(self) -> Dict[str, Any]: @property def columns_created(self) -> List[str]: - return [f"{self.risk}_propensity"] + return [self.propensity_column] @property def initialization_requirements(self) -> list[str | Pipeline | RandomnessStream]: @@ -43,6 +43,7 @@ def initialization_requirements(self) -> list[str | Pipeline | RandomnessStream] def __init__(self, risk: str): super().__init__() self.risk = risk + self.propensity_column = f"{risk}_propensity" # noinspection PyAttributeOutsideInit def setup(self, builder: Builder) -> None: @@ -56,7 +57,9 @@ def setup(self, builder: Builder) -> None: ) self.exposure = builder.value.register_value_producer( - f"{self.risk}.exposure", source=self._exposure + f"{self.risk}.exposure", + source=self._exposure, + required_resources=[self.propensity_column, self.exposure_threshold], ) self.randomness = builder.randomness.get_stream(self.risk) @@ -66,14 +69,14 @@ def setup(self, builder: Builder) -> None: def on_initialize_simulants(self, pop_data): draw = self.randomness.get_draw(pop_data.index) - self.population_view.update(pd.Series(draw, name=f"{self.risk}_propensity")) + self.population_view.update(pd.Series(draw, name=self.propensity_column)) ################################## # Pipeline sources and modifiers # ################################## def 
_exposure(self, index): - propensity = self.population_view.get(index)[f"{self.risk}_propensity"] + propensity = self.population_view.get(index)[self.propensity_column] return self.exposure_threshold(index) > propensity @@ -104,6 +107,11 @@ def __init__(self, risk_name: str, disease_rate: str): # noinspection PyAttributeOutsideInit def setup(self, builder: Builder) -> None: + self.base_risk_exposure = builder.value.get_value( + f"{self.risk_name}.base_proportion_exposed" + ) + self.actual_risk_exposure = builder.value.get_value(f"{self.risk_name}.exposure") + relative_risk = builder.configuration[self.risk].relative_risk self.relative_risk = builder.value.register_value_producer( f"{self.risk}.relative_risk", @@ -113,12 +121,13 @@ def setup(self, builder: Builder) -> None: builder.value.register_value_modifier( f"{self.disease_rate}.population_attributable_fraction", self.population_attributable_fraction, + required_resources=[self.base_risk_exposure, self.relative_risk], ) - builder.value.register_value_modifier(f"{self.disease_rate}", self.rate_adjustment) - self.base_risk_exposure = builder.value.get_value( - f"{self.risk_name}.base_proportion_exposed" + builder.value.register_value_modifier( + f"{self.disease_rate}", + self.rate_adjustment, + required_resources=[self.actual_risk_exposure, self.relative_risk], ) - self.actual_risk_exposure = builder.value.get_value(f"{self.risk_name}.exposure") ################################## # Pipeline sources and modifiers # From 4cae343f6a220af03f100c138d167950a333850f Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Fri, 18 Oct 2024 15:05:29 -0700 Subject: [PATCH 10/22] refactor the dependencies argument to add_resources (#510) --- .../framework/resource/resource.rst | 1 + pyproject.toml | 1 - src/vivarium/framework/population/manager.py | 50 ++--- src/vivarium/framework/randomness/manager.py | 5 +- src/vivarium/framework/resource/__init__.py | 1 + 
src/vivarium/framework/resource/group.py | 40 +++- src/vivarium/framework/resource/manager.py | 35 ++-- src/vivarium/framework/resource/resource.py | 31 +++ src/vivarium/framework/values.py | 43 ++--- tests/framework/resource/__init__.py | 0 tests/framework/resource/test_manager.py | 182 ++++++++++++++++++ tests/framework/resource/test_resource.py | 6 + .../framework/resource/test_resource_group.py | 60 ++++++ tests/framework/test_resource.py | 162 ---------------- 14 files changed, 367 insertions(+), 250 deletions(-) create mode 100644 docs/source/api_reference/framework/resource/resource.rst create mode 100644 src/vivarium/framework/resource/resource.py create mode 100644 tests/framework/resource/__init__.py create mode 100644 tests/framework/resource/test_manager.py create mode 100644 tests/framework/resource/test_resource.py create mode 100644 tests/framework/resource/test_resource_group.py delete mode 100644 tests/framework/test_resource.py diff --git a/docs/source/api_reference/framework/resource/resource.rst b/docs/source/api_reference/framework/resource/resource.rst new file mode 100644 index 00000000..2a066b8a --- /dev/null +++ b/docs/source/api_reference/framework/resource/resource.rst @@ -0,0 +1 @@ +.. 
automodule:: vivarium.framework.resource.resource \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index c82f4f91..3f89b0c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,6 @@ exclude = [ 'tests/framework/test_event.py', 'tests/framework/test_lifecycle.py', 'tests/framework/test_plugins.py', - 'tests/framework/test_resource.py', 'tests/framework/test_state_machine.py', 'tests/framework/test_time.py', 'tests/framework/test_utilities.py', diff --git a/src/vivarium/framework/population/manager.py b/src/vivarium/framework/population/manager.py index 7b69b459..3a0b345b 100644 --- a/src/vivarium/framework/population/manager.py +++ b/src/vivarium/framework/population/manager.py @@ -9,7 +9,7 @@ """ from __future__ import annotations -from collections.abc import Callable, Sequence +from collections.abc import Callable, Iterable, Sequence from dataclasses import dataclass from types import MethodType from typing import TYPE_CHECKING, Any @@ -19,6 +19,7 @@ from vivarium.framework.population.exceptions import PopulationError from vivarium.framework.population.population_view import PopulationView from vivarium.framework.randomness import RandomnessStream +from vivarium.framework.resource import Resource from vivarium.framework.values import Pipeline from vivarium.manager import Interface, Manager from vivarium.types import ClockStepSize, ClockTime @@ -265,7 +266,7 @@ def register_simulant_initializer( requires_columns: str | Sequence[str] = (), requires_values: str | Sequence[str] = (), requires_streams: str | Sequence[str] = (), - required_resources: Sequence[str | Pipeline | RandomnessStream] = (), + required_resources: Iterable[str | Pipeline | RandomnessStream] = (), ) -> None: """Marks a source of initial state information for new simulants. @@ -302,6 +303,8 @@ def register_simulant_initializer( "requirements must be empty." 
) + if isinstance(creates_columns, str): + creates_columns = [creates_columns] if isinstance(requires_columns, str): requires_columns = [requires_columns] if isinstance(requires_values, str): @@ -309,39 +312,26 @@ def register_simulant_initializer( if isinstance(requires_streams, str): requires_streams = [requires_streams] - if required_resources: - requires_columns = [] - requires_values = [] - requires_streams = [] - for required_resource in required_resources: - if isinstance(required_resource, str): - requires_columns.append(required_resource) - elif isinstance(required_resource, Pipeline): - requires_values.append(required_resource.name) - elif isinstance(required_resource, RandomnessStream): - requires_streams.append(required_resource.key) - else: - raise TypeError( - "requirements must be a sequence of strings, Pipelines," - f" and RandomnessStreams. Provided: '{type(required_resource)}'." - ) - - dependencies = ( - [f"column.{name}" for name in requires_columns] - + [f"value.{name}" for name in requires_values] - + [f"stream.{name}" for name in requires_streams] - ) - - if isinstance(creates_columns, str): - creates_columns = [creates_columns] - self._initializer_components.add(initializer, list(creates_columns)) + declared_dependencies: Iterable[str | Pipeline | RandomnessStream | Resource] + if has_individual_requires: + declared_dependencies = ( + list(requires_columns) + + [Resource("value", name) for name in requires_values] + + [Resource("stream", name) for name in requires_streams] + ) + else: + declared_dependencies = list(required_resources) if "tracked" not in creates_columns: # The population view itself uses the tracked column, so include # to be safe. 
- dependencies += ["column.tracked"] + all_dependencies = list(declared_dependencies) + ["tracked"] + else: + all_dependencies = list(declared_dependencies) + + self._initializer_components.add(initializer, list(creates_columns)) self.resources.add_resources( - "column", list(creates_columns), initializer, dependencies + "column", list(creates_columns), initializer, all_dependencies ) def get_simulant_creator(self) -> Callable[[int, dict[str, Any] | None], pd.Index[int]]: diff --git a/src/vivarium/framework/randomness/manager.py b/src/vivarium/framework/randomness/manager.py index c4fe4916..6fdcbb2e 100644 --- a/src/vivarium/framework/randomness/manager.py +++ b/src/vivarium/framework/randomness/manager.py @@ -119,10 +119,7 @@ def get_randomness_stream( if not initializes_crn_attributes: # We need the key columns to be created before this stream can be called. self.resources.add_resources( - "stream", - [decision_point], - stream, - [f"column.{name}" for name in self._key_columns], + "stream", [decision_point], stream, self._key_columns ) self._add_constraint( stream.get_draw, restrict_during=["initialization", "setup", "post_setup"] diff --git a/src/vivarium/framework/resource/__init__.py b/src/vivarium/framework/resource/__init__.py index d15753b5..8e57e20a 100644 --- a/src/vivarium/framework/resource/__init__.py +++ b/src/vivarium/framework/resource/__init__.py @@ -21,3 +21,4 @@ """ from vivarium.framework.resource.manager import ResourceInterface, ResourceManager +from vivarium.framework.resource.resource import Resource diff --git a/src/vivarium/framework/resource/group.py b/src/vivarium/framework/resource/group.py index 5e0218c2..6ab66d3d 100644 --- a/src/vivarium/framework/resource/group.py +++ b/src/vivarium/framework/resource/group.py @@ -1,7 +1,14 @@ from __future__ import annotations -from collections.abc import Iterator -from typing import Any +from collections.abc import Iterable, Iterator +from typing import TYPE_CHECKING, Any + +from 
vivarium.framework.resource.exceptions import ResourceError +from vivarium.framework.resource.resource import Resource + +if TYPE_CHECKING: + from vivarium.framework.randomness import RandomnessStream + from vivarium.framework.values import Pipeline class ResourceGroup: @@ -21,12 +28,14 @@ def __init__( resource_type: str, resource_names: list[str], producer: Any, - dependencies: list[str], + dependencies: Iterable[str | Pipeline | RandomnessStream | Resource], ): self._resource_type = resource_type self._resource_names = resource_names self._producer = producer - self._dependencies = dependencies + self._dependency_keys = [ + self._get_dependency_key(dependency) for dependency in dependencies + ] @property def type(self) -> str: @@ -49,7 +58,7 @@ def producer(self) -> Any: @property def dependencies(self) -> list[str]: """The long names (including type) of dependencies for this group.""" - return self._dependencies + return self._dependency_keys def __iter__(self) -> Iterator[str]: return iter(self.names) @@ -61,3 +70,24 @@ def __repr__(self) -> str: def __str__(self) -> str: resources = ", ".join(self) return f"({resources})" + + @staticmethod + def _get_dependency_key(dependency: str | Pipeline | RandomnessStream | Resource) -> str: + # local import to avoid circular dependency + from vivarium.framework.randomness import RandomnessStream + from vivarium.framework.values import Pipeline + + if isinstance(dependency, str): + return f"column.{dependency}" + elif isinstance(dependency, Pipeline): + return f"value.{dependency.name}" + elif isinstance(dependency, RandomnessStream): + return f"stream.{dependency.key}" + elif isinstance(dependency, Resource): + return str(dependency) + else: + raise ResourceError( + f"Dependency '{dependency}' of unknown type: {type(dependency)}." + " Dependencies must be strings, Pipelines, RandomnessStreams, or" + " Resources." 
+ ) diff --git a/src/vivarium/framework/resource/manager.py b/src/vivarium/framework/resource/manager.py index 02ce11d0..270c6e0f 100644 --- a/src/vivarium/framework/resource/manager.py +++ b/src/vivarium/framework/resource/manager.py @@ -7,26 +7,21 @@ from __future__ import annotations +from collections.abc import Iterable from typing import TYPE_CHECKING, Any import networkx as nx from vivarium.framework.resource.exceptions import ResourceError from vivarium.framework.resource.group import ResourceGroup +from vivarium.framework.resource.resource import RESOURCE_TYPES, Resource from vivarium.manager import Interface, Manager if TYPE_CHECKING: from vivarium.framework.engine import Builder + from vivarium.framework.randomness import RandomnessStream + from vivarium.framework.values import Pipeline - -RESOURCE_TYPES = { - "value", - "value_source", - "missing_value_source", - "value_modifier", - "column", - "stream", -} NULL_RESOURCE_TYPE = "null" @@ -88,7 +83,7 @@ def add_resources( resource_type: str, resource_names: list[str], producer: Any, - dependencies: list[str], + dependencies: Iterable[str | Pipeline | RandomnessStream | Resource], ) -> None: """Adds managed resources to the resource pool. @@ -102,15 +97,15 @@ def add_resources( producer A method or object that will produce the resources. dependencies - A list of resource names formatted as - ``resource_type.resource_name`` that the producer requires. + A list of resources that the producer requires. Raises ------ ResourceError If either the resource type is invalid, a component has multiple - resource producers for the ``column`` resource type, or - there are multiple producers of the same resource. + resource producers for the ``column`` resource type, + there are multiple producers of the same resource, or . + the dependencies are of an invalid type. 
""" if resource_type not in RESOURCE_TYPES: raise ResourceError( @@ -136,7 +131,7 @@ def _get_resource_group( resource_type: str, resource_names: list[str], producer: Any, - dependencies: list[str], + dependencies: Iterable[str | Pipeline | RandomnessStream | Resource], ) -> ResourceGroup: """Packages resource information into a resource group. @@ -233,7 +228,7 @@ def add_resources( resource_type: str, resource_names: list[str], producer: Any, - dependencies: list[str], + dependencies: Iterable[str | Pipeline | RandomnessStream | Resource], ) -> None: """Adds managed resources to the resource pool. @@ -247,15 +242,15 @@ def add_resources( producer A method or object that will produce the resources. dependencies - A list of resource names formatted as - ``resource_type.resource_name`` that the producer requires. + A list of resources that the producer requires. Raises ------ ResourceError If either the resource type is invalid, a component has multiple - resource producers for the ``column`` resource type, or - there are multiple producers of the same resource. + resource producers for the ``column`` resource type, + there are multiple producers of the same resource, or . + the dependencies are of an invalid type. """ self._manager.add_resources(resource_type, resource_names, producer, dependencies) diff --git a/src/vivarium/framework/resource/resource.py b/src/vivarium/framework/resource/resource.py new file mode 100644 index 00000000..5358202a --- /dev/null +++ b/src/vivarium/framework/resource/resource.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from vivarium.framework.resource.exceptions import ResourceError + +RESOURCE_TYPES = { + "value", + "value_source", + "missing_value_source", + "value_modifier", + "column", + "stream", +} + + +class Resource: + """A generic resource. + + These resources may be required to build up the dependency graph. 
+ """ + + def __init__(self, type: str, name: str): + if type not in RESOURCE_TYPES: + raise ResourceError(f"Unknown resource type: {type}") + + self.type = type + """The type of the resource.""" + self.name = name + """The name of the resource.""" + + def __str__(self) -> str: + return f"{self.type}.{self.name}" diff --git a/src/vivarium/framework/values.py b/src/vivarium/framework/values.py index 56c985e9..38730c9d 100644 --- a/src/vivarium/framework/values.py +++ b/src/vivarium/framework/values.py @@ -24,6 +24,7 @@ from vivarium.exceptions import VivariumError from vivarium.framework.event import Event from vivarium.framework.randomness import RandomnessStream +from vivarium.framework.resource import Resource from vivarium.framework.utilities import from_yearly from vivarium.manager import Interface, Manager from vivarium.types import NumberLike @@ -364,12 +365,14 @@ def on_post_setup(self, _event: Event) -> None: for name, pipe in self._pipelines.items(): dependencies = [] if pipe.source is not None: - dependencies += [f"value_source.{name}"] + dependencies.append(Resource("value_source", name)) else: - dependencies += [f"missing_value_source.{name}"] + dependencies.append(Resource("missing_value_source", name)) for i, m in enumerate(pipe.mutators): mutator_name = self._get_modifier_name(m) - dependencies.append(f"value_modifier.{name}.{i+1}.{mutator_name}") + dependencies.append( + Resource("value_modifier", f"{name}.{i+1}.{mutator_name}") + ) self.resources.add_resources("value", [name], pipe._call, dependencies) def register_value_producer( @@ -498,11 +501,11 @@ def _convert_dependencies( requires_values: Iterable[str], requires_streams: Iterable[str], required_resources: Iterable[str | Pipeline | RandomnessStream], - ) -> list[str]: + ) -> Iterable[str | Pipeline | RandomnessStream | Resource]: if isinstance(func, Pipeline): # The dependencies of the pipeline itself will have been declared # when the pipeline was registered. 
- return [f"value.{func.name}"] + return [Resource("value", func.name)] if requires_columns or requires_values or requires_streams: warnings.warn( @@ -517,29 +520,13 @@ def _convert_dependencies( " are provided, requirements must be empty." ) - if required_resources: - requires_columns = [] - requires_values = [] - requires_streams = [] - for required_resource in required_resources: - if isinstance(required_resource, str): - requires_columns.append(required_resource) - elif isinstance(required_resource, Pipeline): - requires_values.append(required_resource.name) - elif isinstance(required_resource, RandomnessStream): - requires_streams.append(required_resource.key) - else: - raise TypeError( - "requirements must be a sequence of strings, Pipelines," - f" and RandomnessStreams. Provided: '{type(required_resource)}'." - ) - - dependencies = ( - [f"column.{name}" for name in requires_columns] - + [f"value.{name}" for name in requires_values] - + [f"stream.{name}" for name in requires_streams] - ) - return dependencies + return ( + list(requires_columns) + + [Resource("value", name) for name in requires_values] + + [Resource("stream", name) for name in requires_streams] + ) + else: + return required_resources @staticmethod def _get_modifier_name(modifier: Callable[..., Any]) -> str: diff --git a/tests/framework/resource/__init__.py b/tests/framework/resource/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/framework/resource/test_manager.py b/tests/framework/resource/test_manager.py new file mode 100644 index 00000000..da909101 --- /dev/null +++ b/tests/framework/resource/test_manager.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +from collections.abc import Callable, Mapping + +import pytest +import pytest_mock + +from vivarium import Component +from vivarium.framework.population import SimulantData +from vivarium.framework.randomness import RandomnessStream +from vivarium.framework.randomness.index_map import IndexMap +from 
vivarium.framework.resource import Resource, ResourceManager +from vivarium.framework.resource.exceptions import ResourceError +from vivarium.framework.resource.manager import NULL_RESOURCE_TYPE, RESOURCE_TYPES +from vivarium.framework.values import Pipeline + + +@pytest.fixture +def manager(mocker: pytest_mock.MockFixture) -> ResourceManager: + manager = ResourceManager() + manager.logger = mocker.Mock() + return manager + + +@pytest.fixture +def randomness_stream() -> RandomnessStream: + return RandomnessStream("stream.1", lambda x: x, 1, IndexMap()) + + +class ResourceProducer(Component): + @property + def name(self) -> str: + return self._name + + def __init__(self, name: str): + super().__init__() + self._name = name + + def producer(self, _simulant_data: SimulantData) -> None: + pass + + +@pytest.mark.parametrize("r_type", RESOURCE_TYPES, ids=lambda x: f"r_type_{x}") +def test_resource_manager_get_resource_group(r_type: str, manager: ResourceManager) -> None: + component = ResourceProducer("base") + r_names = ["foo"] + r_producer = component.producer + r_dependencies: list[str | Resource] = [] + + group = manager._get_resource_group(r_type, r_names, r_producer, r_dependencies) + + assert group.type == r_type + assert group.names == [f"{r_type}.foo"] + assert group.producer == component.producer + assert not group.dependencies + + +def test_resource_manager_get_resource_group_null(manager: ResourceManager) -> None: + component = ResourceProducer("base") + r_names: list[str] = [] + r_producer = component.producer + r_dependencies: list[str | Resource] = [] + + group_1 = manager._get_resource_group("column", r_names, r_producer, r_dependencies) + group_2 = manager._get_resource_group("column", r_names, r_producer, r_dependencies) + + assert group_1.type == NULL_RESOURCE_TYPE + assert group_1.names == [f"{NULL_RESOURCE_TYPE}.0"] + assert group_1.producer == component.producer + assert not group_1.dependencies + + assert group_2.type == NULL_RESOURCE_TYPE + assert 
group_2.names == [f"{NULL_RESOURCE_TYPE}.1"] + assert group_2.producer == component.producer + assert not group_2.dependencies + + +def test_resource_manager_add_resources_bad_type(manager: ResourceManager) -> None: + c = ResourceProducer("base") + r_type = "unknown" + r_names = [str(i) for i in range(5)] + r_producer = c.producer + r_dependencies: list[str | Resource] = [] + + with pytest.raises(ResourceError, match="Unknown resource type"): + manager.add_resources(r_type, r_names, r_producer, r_dependencies) + + +def test_resource_manager_add_resources_multiple_producers(manager: ResourceManager) -> None: + c1 = ResourceProducer("1") + c2 = ResourceProducer("2") + r_type = "column" + r1_names = [str(i) for i in range(5)] + r2_names = [str(i) for i in range(5, 10)] + ["1"] + r1_producer = c1.producer + r2_producer = c2.producer + r_dependencies: list[str | Resource] = [] + + manager.add_resources(r_type, r1_names, r1_producer, r_dependencies) + with pytest.raises(ResourceError, match="producers for column.1"): + manager.add_resources(r_type, r2_names, r2_producer, r_dependencies) + + +def test_resource_manager_sorted_nodes_two_node_cycle( + manager: ResourceManager, randomness_stream: RandomnessStream +) -> None: + c = ResourceProducer("test") + + manager.add_resources("column", ["c_1"], c.producer, [randomness_stream]) + manager.add_resources("stream", [randomness_stream.key], c.producer, ["c_1"]) + + with pytest.raises(ResourceError, match="cycle"): + _ = manager.sorted_nodes + + +def test_resource_manager_sorted_nodes_three_node_cycle( + manager: ResourceManager, randomness_stream: RandomnessStream +) -> None: + c = ResourceProducer("test") + pipeline = Pipeline("some_pipeline") + + manager.add_resources("column", ["c_1"], c.producer, [randomness_stream]) + manager.add_resources("value", [pipeline.name], c.producer, ["c_1"]) + manager.add_resources("stream", [randomness_stream.key], c.producer, [pipeline]) + + with pytest.raises(ResourceError, match="cycle"): + 
_ = manager.sorted_nodes + + +def test_resource_manager_sorted_nodes_large_cycle(manager: ResourceManager) -> None: + c = ResourceProducer("test") + + for i in range(10): + manager.add_resources("column", [f"c_{i}"], c.producer, [f"c_{i%10}"]) + + with pytest.raises(ResourceError, match="cycle"): + _ = manager.sorted_nodes + + +def test_resource_manager_sorted_nodes_acyclic(manager: ResourceManager) -> None: + _add_resources(manager) + + n = [str(node) for node in manager.sorted_nodes] + + assert n.index("(column.A)") < n.index("(stream.B)") + assert n.index("(column.A)") < n.index("(value.C)") + assert n.index("(column.A)") < n.index("(column.D)") + + assert n.index("(stream.B)") < n.index("(column.D)") + assert n.index("(value.C)") < n.index("(column.D)") + + assert n.index("(stream.B)") < n.index(f"({NULL_RESOURCE_TYPE}.0)") + + +def test_get_population_initializers(manager: ResourceManager) -> None: + producers = _add_resources(manager) + initializers = manager.get_population_initializers() + + assert len(initializers) == 3 + assert initializers[0] == producers[0] + assert producers[3] in initializers + assert producers[4] in initializers + + +#################### +# Helper functions # +#################### + + +def _add_resources(manager: ResourceManager) -> Mapping[int, Callable[[SimulantData], None]]: + producers = {i: ResourceProducer(f"test_{i}").producer for i in range(5)} + + stream = RandomnessStream("B", lambda x: x, 1, IndexMap()) + pipeline = Pipeline("C") + + manager.add_resources("column", ["D"], producers[3], [stream, pipeline]) + manager.add_resources("stream", ["B"], producers[1], ["A"]) + manager.add_resources("value", ["C"], producers[2], ["A"]) + manager.add_resources("column", ["A"], producers[0], []) + manager.add_resources("column", [], producers[4], [stream]) + + return producers diff --git a/tests/framework/resource/test_resource.py b/tests/framework/resource/test_resource.py new file mode 100644 index 00000000..c4c625bc --- /dev/null 
+++ b/tests/framework/resource/test_resource.py @@ -0,0 +1,6 @@ +from vivarium.framework.resource import Resource + + +def test_to_string() -> None: + resource = Resource("value_source", "test") + assert str(resource) == "value_source.test" diff --git a/tests/framework/resource/test_resource_group.py b/tests/framework/resource/test_resource_group.py new file mode 100644 index 00000000..ef45e1b6 --- /dev/null +++ b/tests/framework/resource/test_resource_group.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import pytest + +from vivarium.framework.randomness import RandomnessStream +from vivarium.framework.randomness.index_map import IndexMap +from vivarium.framework.resource import Resource +from vivarium.framework.resource.exceptions import ResourceError +from vivarium.framework.resource.group import ResourceGroup +from vivarium.framework.values import Pipeline + + +def dummy_producer() -> str: + return "resources!" + + +def test_resource_group() -> None: + r_type = "column" + r_names = [str(i) for i in range(5)] + r_producer = dummy_producer + r_dependencies: list[str | Pipeline | RandomnessStream | Resource] = [ + "an_interesting_column", + Pipeline("baz"), + RandomnessStream("bar", lambda x: x, 1, IndexMap()), + Resource("value_source", "foo"), + ] + + rg = ResourceGroup(r_type, r_names, r_producer, r_dependencies) + + assert rg.type == r_type + assert rg.names == [f"{r_type}.{name}" for name in r_names] + assert rg.producer == dummy_producer + assert rg.dependencies == [ + "column.an_interesting_column", + "value.baz", + "stream.bar", + "value_source.foo", + ] + assert list(rg) == rg.names + + +@pytest.mark.parametrize( + "dependency, expected_key", + [ + ("an_interesting_column", "column.an_interesting_column"), + (Pipeline("baz"), "value.baz"), + (RandomnessStream("bar", lambda x: x, 1, IndexMap()), "stream.bar"), + (Resource("value_source", "foo"), "value_source.foo"), + ], +) +def test__get_dependency_key( + dependency: str | Pipeline | 
RandomnessStream | Resource, expected_key: str +) -> None: + key = ResourceGroup._get_dependency_key(dependency) + assert key == expected_key + + +def test__get_dependency_key_unknown_type() -> None: + with pytest.raises(ResourceError, match="unknown type"): + ResourceGroup._get_dependency_key(1) # type: ignore [arg-type] diff --git a/tests/framework/test_resource.py b/tests/framework/test_resource.py deleted file mode 100644 index 673abc15..00000000 --- a/tests/framework/test_resource.py +++ /dev/null @@ -1,162 +0,0 @@ -import pytest - -from vivarium import Component -from vivarium.framework.resource import ResourceManager -from vivarium.framework.resource.exceptions import ResourceError -from vivarium.framework.resource.group import ResourceGroup -from vivarium.framework.resource.manager import NULL_RESOURCE_TYPE, RESOURCE_TYPES - - -class ResourceProducer(Component): - @property - def name(self) -> str: - return self._name - - def __init__(self, name: str): - super().__init__() - self._name = name - - def producer(self): - return "resources!" 
- - -def test_resource_group(): - c = ResourceProducer("base") - r_type = "column" - r_names = [str(i) for i in range(5)] - r_producer = c.producer - r_dependencies = [] - - rg = ResourceGroup(r_type, r_names, r_producer, r_dependencies) - - assert rg.type == r_type - assert rg.names == [f"{r_type}.{name}" for name in r_names] - assert rg.producer == c.producer - assert not rg.dependencies - assert list(rg) == rg.names - - -def test_resource_manager_get_resource_group(): - rm = ResourceManager() - c = ResourceProducer("base") - r_type = "column" - r_names = [str(i) for i in range(5)] - r_producer = c.producer - r_dependencies = [] - - rg = rm._get_resource_group(r_type, r_names, r_producer, r_dependencies) - - assert rg.type == r_type - assert rg.names == [f"{r_type}.{name}" for name in r_names] - assert rg.producer == c.producer - assert not rg.dependencies - assert list(rg) == rg.names - - -def test_resource_manager_get_resource_group_null(): - rm = ResourceManager() - c = ResourceProducer("base") - r_type = "column" - r_names = [] - r_producer = c.producer - r_dependencies = [] - - rg = rm._get_resource_group(r_type, r_names, r_producer, r_dependencies) - - assert rg.type == NULL_RESOURCE_TYPE - assert rg.names == [f"{NULL_RESOURCE_TYPE}.0"] - assert rg.producer == c.producer - assert not rg.dependencies - assert list(rg) == rg.names - - -def test_resource_manager_add_resources_bad_type(): - rm = ResourceManager() - c = ResourceProducer("base") - r_type = "unknown" - r_names = [str(i) for i in range(5)] - r_producer = c.producer - r_dependencies = [] - - with pytest.raises(ResourceError, match="Unknown resource type"): - rm.add_resources(r_type, r_names, r_producer, r_dependencies) - - -def test_resource_manager_add_resources_multiple_producers(): - rm = ResourceManager() - c1 = ResourceProducer("1") - c2 = ResourceProducer("2") - r_type = "column" - r1_names = [str(i) for i in range(5)] - r2_names = [str(i) for i in range(5, 10)] + ["1"] - r1_producer = 
c1.producer - r2_producer = c2.producer - r_dependencies = [] - - rm.add_resources(r_type, r1_names, r1_producer, r_dependencies) - with pytest.raises(ResourceError, match="producers for column.1"): - rm.add_resources(r_type, r2_names, r2_producer, r_dependencies) - - -def test_resource_manager_add_resources(): - rm = ResourceManager() - for r_type in RESOURCE_TYPES: - old_names = [] - for i in range(5): - c = ResourceProducer(f"r_type_{i}") - names = [f"r_type_{i}_{j}" for j in range(5)] - rm.add_resources(r_type, names, c.producer, old_names) - old_names = names - - -def test_resource_manager_sorted_nodes_two_node_cycle(): - rm = ResourceManager() - c = ResourceProducer("test") - - rm.add_resources("column", ["1"], c.producer, ["stream.2"]) - rm.add_resources("stream", ["2"], c.producer, ["column.1"]) - - with pytest.raises(ResourceError, match="cycle"): - _ = rm.sorted_nodes - - -def test_resource_manager_sorted_nodes_three_node_cycle(): - rm = ResourceManager() - c = ResourceProducer("test") - - rm.add_resources("column", ["1"], c.producer, ["stream.3"]) - rm.add_resources("stream", ["2"], c.producer, ["column.1"]) - rm.add_resources("stream", ["3"], c.producer, ["stream.2"]) - - with pytest.raises(ResourceError, match="cycle"): - _ = rm.sorted_nodes - - -def test_resource_manager_sorted_nodes_large_cycle(): - rm = ResourceManager() - c = ResourceProducer("test") - - for i in range(10): - rm.add_resources("column", [f"{i}"], c.producer, [f"column.{i%10}"]) - - with pytest.raises(ResourceError, match="cycle"): - _ = rm.sorted_nodes - - -def test_resource_manager_sorted_nodes_diamond(): - rm = ResourceManager() - c = ResourceProducer("test") - - rm.add_resources("column", ["1"], c.producer, []) - rm.add_resources("column", ["2"], c.producer, ["column.1"]) - rm.add_resources("column", ["3"], c.producer, ["column.1"]) - rm.add_resources("column", ["4"], c.producer, ["column.2", "column.3"]) - - n = [str(node) for node in rm.sorted_nodes] - - assert 
n.index("(column.1)") < n.index("(column.2)") - assert n.index("(column.1)") < n.index("(column.3)") - assert n.index("(column.1)") < n.index("(column.4)") - - assert n.index("(column.2)") < n.index("(column.4)") - assert n.index("(column.3)") < n.index("(column.4)") From 83391a599511363f4e4de84c37a559e824806c29 Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Mon, 21 Oct 2024 13:59:47 -0700 Subject: [PATCH 11/22] create resource types (#511) --- src/vivarium/framework/population/manager.py | 47 ++++++++----------- src/vivarium/framework/randomness/stream.py | 26 ++++------ src/vivarium/framework/resource/group.py | 38 +++------------ src/vivarium/framework/resource/manager.py | 29 ++++++------ src/vivarium/framework/resource/resource.py | 36 +++++++++----- src/vivarium/framework/values.py | 45 ++++++++++++------ tests/framework/resource/test_manager.py | 16 ++++++- tests/framework/resource/test_resource.py | 4 +- .../framework/resource/test_resource_group.py | 36 +++----------- 9 files changed, 125 insertions(+), 152 deletions(-) diff --git a/src/vivarium/framework/population/manager.py b/src/vivarium/framework/population/manager.py index 3a0b345b..c738ee0b 100644 --- a/src/vivarium/framework/population/manager.py +++ b/src/vivarium/framework/population/manager.py @@ -266,7 +266,7 @@ def register_simulant_initializer( requires_columns: str | Sequence[str] = (), requires_values: str | Sequence[str] = (), requires_streams: str | Sequence[str] = (), - required_resources: Iterable[str | Pipeline | RandomnessStream] = (), + required_resources: Iterable[str | Resource] = (), ) -> None: """Marks a source of initial state information for new simulants. @@ -290,44 +290,37 @@ def register_simulant_initializer( attributes. required_resources The resources that the initializer requires to run. 
Strings are - interpreted as column names, and Pipelines and RandomnessStreams - are interpreted as value pipelines and randomness streams, - respectively. + interpreted as column names. """ + if requires_columns or requires_values or requires_streams: + if required_resources: + raise ValueError( + "If requires_columns, requires_values, or requires_streams are provided, " + "requirements must be empty." + ) - has_individual_requires = requires_columns or requires_values or requires_streams - - if has_individual_requires and required_resources: - raise ValueError( - "If requires_columns, requires_values, or requires_streams are provided, " - "requirements must be empty." - ) + if isinstance(requires_columns, str): + requires_columns = [requires_columns] + if isinstance(requires_values, str): + requires_values = [requires_values] + if isinstance(requires_streams, str): + requires_streams = [requires_streams] - if isinstance(creates_columns, str): - creates_columns = [creates_columns] - if isinstance(requires_columns, str): - requires_columns = [requires_columns] - if isinstance(requires_values, str): - requires_values = [requires_values] - if isinstance(requires_streams, str): - requires_streams = [requires_streams] - - declared_dependencies: Iterable[str | Pipeline | RandomnessStream | Resource] - if has_individual_requires: - declared_dependencies = ( + required_resources = ( list(requires_columns) + [Resource("value", name) for name in requires_values] + [Resource("stream", name) for name in requires_streams] ) - else: - declared_dependencies = list(required_resources) + + if isinstance(creates_columns, str): + creates_columns = [creates_columns] if "tracked" not in creates_columns: # The population view itself uses the tracked column, so include # to be safe. 
- all_dependencies = list(declared_dependencies) + ["tracked"] + all_dependencies = list(required_resources) + ["tracked"] else: - all_dependencies = list(declared_dependencies) + all_dependencies = list(required_resources) self._initializer_components.add(initializer, list(creates_columns)) self.resources.add_resources( diff --git a/src/vivarium/framework/randomness/stream.py b/src/vivarium/framework/randomness/stream.py index d8d2bac5..87236a9d 100644 --- a/src/vivarium/framework/randomness/stream.py +++ b/src/vivarium/framework/randomness/stream.py @@ -37,6 +37,7 @@ from vivarium.framework.randomness.exceptions import RandomnessError from vivarium.framework.randomness.index_map import IndexMap +from vivarium.framework.resource import Resource from vivarium.framework.utilities import rate_to_probability from vivarium.types import ClockTime, NumericArray @@ -62,7 +63,7 @@ def get_hash(key: str) -> int: return int(hashlib.sha1(key.encode("utf8")).hexdigest(), 16) % max_allowable_numpy_seed -class RandomnessStream: +class RandomnessStream(Resource): """A stream for producing common random numbers. `RandomnessStream` objects provide an interface to Vivarium's @@ -70,19 +71,6 @@ class RandomnessStream: for doing common simulation tasks that require random numbers like making decisions among a number of choices. - Attributes - ---------- - key - The name of the randomness stream. - clock - A way to get the current simulation time. - seed - An extra number used to seed the random number generation. - index_map - A key-index mapping with a fectorized hash and vectorized lookups. - initializes_crn_attributes - A boolean indicating whether the stram is used to initialize CRN attributes. - Notes ----- Should not be constructed by client code. 
@@ -105,15 +93,17 @@ def __init__( index_map: IndexMap, initializes_crn_attributes: bool = False, ): + super().__init__("stream", key) self.key = key + """The name of the randomness stream.""" self.clock = clock + """A way to get the current simulation time.""" self.seed = seed + """An extra number used to seed the random number generation.""" self.index_map = index_map + """A key-index mapping with a vectorized hash and vectorized lookups.""" self.initializes_crn_attributes = initializes_crn_attributes - - @property - def name(self) -> str: - return f"randomness_stream_{self.key}" + """A boolean indicating whether the stream is used to initialize CRN attributes.""" def _key(self, additional_key: Any = None) -> str: """Construct a hashable key from this object's state. diff --git a/src/vivarium/framework/resource/group.py b/src/vivarium/framework/resource/group.py index 6ab66d3d..c3a96d8a 100644 --- a/src/vivarium/framework/resource/group.py +++ b/src/vivarium/framework/resource/group.py @@ -1,15 +1,10 @@ from __future__ import annotations from collections.abc import Iterable, Iterator -from typing import TYPE_CHECKING, Any +from typing import Any -from vivarium.framework.resource.exceptions import ResourceError from vivarium.framework.resource.resource import Resource -if TYPE_CHECKING: - from vivarium.framework.randomness import RandomnessStream - from vivarium.framework.values import Pipeline - class ResourceGroup: """Resource groups are the nodes in the resource dependency graph. 
@@ -28,14 +23,14 @@ def __init__( resource_type: str, resource_names: list[str], producer: Any, - dependencies: Iterable[str | Pipeline | RandomnessStream | Resource], + dependencies: Iterable[Resource] = (), ): self._resource_type = resource_type self._resource_names = resource_names self._producer = producer - self._dependency_keys = [ - self._get_dependency_key(dependency) for dependency in dependencies - ] + """The method or object that produces this group of resources.""" + self._dependencies = dependencies + """The resources this resource group's producer depends on.""" @property def type(self) -> str: @@ -58,7 +53,7 @@ def producer(self) -> Any: @property def dependencies(self) -> list[str]: """The long names (including type) of dependencies for this group.""" - return self._dependency_keys + return [dependency.resource_id for dependency in self._dependencies] def __iter__(self) -> Iterator[str]: return iter(self.names) @@ -70,24 +65,3 @@ def __repr__(self) -> str: def __str__(self) -> str: resources = ", ".join(self) return f"({resources})" - - @staticmethod - def _get_dependency_key(dependency: str | Pipeline | RandomnessStream | Resource) -> str: - # local import to avoid circular dependency - from vivarium.framework.randomness import RandomnessStream - from vivarium.framework.values import Pipeline - - if isinstance(dependency, str): - return f"column.{dependency}" - elif isinstance(dependency, Pipeline): - return f"value.{dependency.name}" - elif isinstance(dependency, RandomnessStream): - return f"stream.{dependency.key}" - elif isinstance(dependency, Resource): - return str(dependency) - else: - raise ResourceError( - f"Dependency '{dependency}' of unknown type: {type(dependency)}." - " Dependencies must be strings, Pipelines, RandomnessStreams, or" - " Resources." 
- ) diff --git a/src/vivarium/framework/resource/manager.py b/src/vivarium/framework/resource/manager.py index 270c6e0f..53a03622 100644 --- a/src/vivarium/framework/resource/manager.py +++ b/src/vivarium/framework/resource/manager.py @@ -14,13 +14,11 @@ from vivarium.framework.resource.exceptions import ResourceError from vivarium.framework.resource.group import ResourceGroup -from vivarium.framework.resource.resource import RESOURCE_TYPES, Resource +from vivarium.framework.resource.resource import RESOURCE_TYPES, Column, Resource from vivarium.manager import Interface, Manager if TYPE_CHECKING: from vivarium.framework.engine import Builder - from vivarium.framework.randomness import RandomnessStream - from vivarium.framework.values import Pipeline NULL_RESOURCE_TYPE = "null" @@ -83,7 +81,7 @@ def add_resources( resource_type: str, resource_names: list[str], producer: Any, - dependencies: Iterable[str | Pipeline | RandomnessStream | Resource], + dependencies: Iterable[str | Resource], ) -> None: """Adds managed resources to the resource pool. @@ -97,15 +95,15 @@ def add_resources( producer A method or object that will produce the resources. dependencies - A list of resources that the producer requires. + A list of resources that the producer requires. A string represents + a column resource. Raises ------ ResourceError If either the resource type is invalid, a component has multiple - resource producers for the ``column`` resource type, - there are multiple producers of the same resource, or . - the dependencies are of an invalid type. + resource producers for the ``column`` resource type, or + there are multiple producers of the same resource. 
""" if resource_type not in RESOURCE_TYPES: raise ResourceError( @@ -131,7 +129,7 @@ def _get_resource_group( resource_type: str, resource_names: list[str], producer: Any, - dependencies: Iterable[str | Pipeline | RandomnessStream | Resource], + dependencies: Iterable[str | Resource], ) -> ResourceGroup: """Packages resource information into a resource group. @@ -139,6 +137,7 @@ def _get_resource_group( -------- :class:`ResourceGroup` """ + dependencies_ = [Column(d) if isinstance(d, str) else d for d in dependencies] if not resource_names: # We have a "producer" that doesn't produce anything, but # does have dependencies. This is necessary for components that @@ -147,7 +146,7 @@ def _get_resource_group( resource_names = [str(self._null_producer_count)] self._null_producer_count += 1 - return ResourceGroup(resource_type, resource_names, producer, dependencies) + return ResourceGroup(resource_type, resource_names, producer, dependencies_) def _to_graph(self) -> nx.DiGraph: """Constructs the full resource graph from information in the groups. @@ -228,7 +227,7 @@ def add_resources( resource_type: str, resource_names: list[str], producer: Any, - dependencies: Iterable[str | Pipeline | RandomnessStream | Resource], + dependencies: Iterable[str | Resource], ) -> None: """Adds managed resources to the resource pool. @@ -242,15 +241,15 @@ def add_resources( producer A method or object that will produce the resources. dependencies - A list of resources that the producer requires. + A list of resources that the producer requires. A string represents + a column resource. Raises ------ ResourceError If either the resource type is invalid, a component has multiple - resource producers for the ``column`` resource type, - there are multiple producers of the same resource, or . - the dependencies are of an invalid type. + resource producers for the ``column`` resource type, or + there are multiple producers of the same resource. 
""" self._manager.add_resources(resource_type, resource_names, producer, dependencies) diff --git a/src/vivarium/framework/resource/resource.py b/src/vivarium/framework/resource/resource.py index 5358202a..b8ca3fe5 100644 --- a/src/vivarium/framework/resource/resource.py +++ b/src/vivarium/framework/resource/resource.py @@ -1,6 +1,6 @@ from __future__ import annotations -from vivarium.framework.resource.exceptions import ResourceError +from dataclasses import dataclass RESOURCE_TYPES = { "value", @@ -12,20 +12,30 @@ } +@dataclass class Resource: - """A generic resource. + """A generic resource representing a node in the dependency graph.""" - These resources may be required to build up the dependency graph. - """ + resource_type: str + """The type of the resource.""" + name: str + """The name of the resource.""" - def __init__(self, type: str, name: str): - if type not in RESOURCE_TYPES: - raise ResourceError(f"Unknown resource type: {type}") + @property + def resource_id(self) -> str: + """The long name of the resource, including the type.""" + return f"{self.resource_type}.{self.name}" - self.type = type - """The type of the resource.""" - self.name = name - """The name of the resource.""" - def __str__(self) -> str: - return f"{self.type}.{self.name}" +class NullResource(Resource): + """A node in the dependency graph that does not produce any resources.""" + + def __init__(self, index: int): + super().__init__("null", f"{index}") + + +class Column(Resource): + """A resource representing a column in the state table.""" + + def __init__(self, name: str): + super().__init__("column", name) diff --git a/src/vivarium/framework/values.py b/src/vivarium/framework/values.py index 38730c9d..c441f163 100644 --- a/src/vivarium/framework/values.py +++ b/src/vivarium/framework/values.py @@ -191,7 +191,28 @@ def union_post_processor(values: list[NumberLike], _: Any) -> NumberLike: return joint_value -class Pipeline: +class ValueSource(Resource): + """A resource representing 
the source of a value pipeline.""" + + def __init__(self, name: str) -> None: + super().__init__("value_source", name) + + +class MissingValueSource(Resource): + """A resource representing an undefined source of a value pipeline.""" + + def __init__(self, name: str) -> None: + super().__init__("missing_value_source", name) + + +class ValueModifier(Resource): + """A resource representing a modifier of a value pipeline.""" + + def __init__(self, name: str) -> None: + super().__init__("value_modifier", name) + + +class Pipeline(Resource): """A tool for building up values across several components. Pipelines are lazily initialized so that we don't have to put constraints @@ -206,8 +227,8 @@ class Pipeline: """ def __init__(self, name: str) -> None: - self.name: str = name - """The name of the value represented by this pipeline.""" + super().__init__("value", name) + self.source: Callable[..., Any] | None = None """The callable source of the value represented by the pipeline.""" self.mutators: list[Callable[..., Any]] = [] @@ -363,16 +384,12 @@ def on_post_setup(self, _event: Event) -> None: # we say the pipeline value depends on its source and all its # modifiers. 
for name, pipe in self._pipelines.items(): - dependencies = [] - if pipe.source is not None: - dependencies.append(Resource("value_source", name)) - else: - dependencies.append(Resource("missing_value_source", name)) + dependencies: list[Resource] = [ + ValueSource(name) if pipe.source else MissingValueSource(name) + ] for i, m in enumerate(pipe.mutators): mutator_name = self._get_modifier_name(m) - dependencies.append( - Resource("value_modifier", f"{name}.{i+1}.{mutator_name}") - ) + dependencies.append(ValueModifier(f"{name}.{i + 1}.{mutator_name}")) self.resources.add_resources("value", [name], pipe._call, dependencies) def register_value_producer( @@ -500,12 +517,12 @@ def _convert_dependencies( requires_columns: Iterable[str], requires_values: Iterable[str], requires_streams: Iterable[str], - required_resources: Iterable[str | Pipeline | RandomnessStream], - ) -> Iterable[str | Pipeline | RandomnessStream | Resource]: + required_resources: Iterable[str | Resource], + ) -> Iterable[str | Resource]: if isinstance(func, Pipeline): # The dependencies of the pipeline itself will have been declared # when the pipeline was registered. 
- return [Resource("value", func.name)] + return [func] if requires_columns or requires_values or requires_streams: warnings.warn( diff --git a/tests/framework/resource/test_manager.py b/tests/framework/resource/test_manager.py index da909101..ff335252 100644 --- a/tests/framework/resource/test_manager.py +++ b/tests/framework/resource/test_manager.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Callable, Mapping +from datetime import datetime import pytest import pytest_mock @@ -24,7 +25,7 @@ def manager(mocker: pytest_mock.MockFixture) -> ResourceManager: @pytest.fixture def randomness_stream() -> RandomnessStream: - return RandomnessStream("stream.1", lambda x: x, 1, IndexMap()) + return RandomnessStream("stream.1", lambda: datetime.now(), 1, IndexMap()) class ResourceProducer(Component): @@ -137,6 +138,17 @@ def test_resource_manager_sorted_nodes_large_cycle(manager: ResourceManager) -> _ = manager.sorted_nodes +def test_large_dependency_chain(manager: ResourceManager) -> None: + for i in range(9, 0, -1): + manager.add_resources( + "column", [f"c_{i}"], ResourceProducer(f"p_{i}").producer, [f"c_{i - 1}"] + ) + manager.add_resources("column", ["c_0"], ResourceProducer("producer_0").producer, []) + + for i, resource in enumerate(manager.sorted_nodes): + assert str(resource) == f"(column.c_{i})" + + def test_resource_manager_sorted_nodes_acyclic(manager: ResourceManager) -> None: _add_resources(manager) @@ -170,7 +182,7 @@ def test_get_population_initializers(manager: ResourceManager) -> None: def _add_resources(manager: ResourceManager) -> Mapping[int, Callable[[SimulantData], None]]: producers = {i: ResourceProducer(f"test_{i}").producer for i in range(5)} - stream = RandomnessStream("B", lambda x: x, 1, IndexMap()) + stream = RandomnessStream("B", lambda: datetime.now(), 1, IndexMap()) pipeline = Pipeline("C") manager.add_resources("column", ["D"], producers[3], [stream, pipeline]) diff --git 
a/tests/framework/resource/test_resource.py b/tests/framework/resource/test_resource.py index c4c625bc..ee23641e 100644 --- a/tests/framework/resource/test_resource.py +++ b/tests/framework/resource/test_resource.py @@ -1,6 +1,6 @@ from vivarium.framework.resource import Resource -def test_to_string() -> None: +def test_resource_id() -> None: resource = Resource("value_source", "test") - assert str(resource) == "value_source.test" + assert resource.resource_id == "value_source.test" diff --git a/tests/framework/resource/test_resource_group.py b/tests/framework/resource/test_resource_group.py index ef45e1b6..cc9c344f 100644 --- a/tests/framework/resource/test_resource_group.py +++ b/tests/framework/resource/test_resource_group.py @@ -1,13 +1,12 @@ from __future__ import annotations -import pytest +from datetime import datetime from vivarium.framework.randomness import RandomnessStream from vivarium.framework.randomness.index_map import IndexMap -from vivarium.framework.resource import Resource -from vivarium.framework.resource.exceptions import ResourceError from vivarium.framework.resource.group import ResourceGroup -from vivarium.framework.values import Pipeline +from vivarium.framework.resource.resource import Column +from vivarium.framework.values import Pipeline, ValueSource def dummy_producer() -> str: @@ -18,11 +17,11 @@ def test_resource_group() -> None: r_type = "column" r_names = [str(i) for i in range(5)] r_producer = dummy_producer - r_dependencies: list[str | Pipeline | RandomnessStream | Resource] = [ - "an_interesting_column", + r_dependencies = [ + Column("an_interesting_column"), Pipeline("baz"), - RandomnessStream("bar", lambda x: x, 1, IndexMap()), - Resource("value_source", "foo"), + RandomnessStream("bar", lambda: datetime.now(), 1, IndexMap()), + ValueSource("foo"), ] rg = ResourceGroup(r_type, r_names, r_producer, r_dependencies) @@ -37,24 +36,3 @@ def test_resource_group() -> None: "value_source.foo", ] assert list(rg) == rg.names - - 
-@pytest.mark.parametrize( - "dependency, expected_key", - [ - ("an_interesting_column", "column.an_interesting_column"), - (Pipeline("baz"), "value.baz"), - (RandomnessStream("bar", lambda x: x, 1, IndexMap()), "stream.bar"), - (Resource("value_source", "foo"), "value_source.foo"), - ], -) -def test__get_dependency_key( - dependency: str | Pipeline | RandomnessStream | Resource, expected_key: str -) -> None: - key = ResourceGroup._get_dependency_key(dependency) - assert key == expected_key - - -def test__get_dependency_key_unknown_type() -> None: - with pytest.raises(ResourceError, match="unknown type"): - ResourceGroup._get_dependency_key(1) # type: ignore [arg-type] From 16478ab33c76043d6275881d722ceaa9a40780a5 Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Mon, 21 Oct 2024 14:31:10 -0700 Subject: [PATCH 12/22] refactor resource arguments of add_resources (#512) --- src/vivarium/framework/population/manager.py | 4 +- src/vivarium/framework/randomness/manager.py | 4 +- src/vivarium/framework/resource/group.py | 23 ++-- src/vivarium/framework/resource/manager.py | 57 +++------ src/vivarium/framework/resource/resource.py | 9 -- src/vivarium/framework/values.py | 8 +- tests/framework/resource/test_manager.py | 119 ++++++++---------- .../framework/resource/test_resource_group.py | 27 ++-- 8 files changed, 108 insertions(+), 143 deletions(-) diff --git a/src/vivarium/framework/population/manager.py b/src/vivarium/framework/population/manager.py index c738ee0b..3b8e8ecd 100644 --- a/src/vivarium/framework/population/manager.py +++ b/src/vivarium/framework/population/manager.py @@ -323,9 +323,7 @@ def register_simulant_initializer( all_dependencies = list(required_resources) self._initializer_components.add(initializer, list(creates_columns)) - self.resources.add_resources( - "column", list(creates_columns), initializer, all_dependencies - ) + self.resources.add_resources(creates_columns, initializer, all_dependencies) 
def get_simulant_creator(self) -> Callable[[int, dict[str, Any] | None], pd.Index[int]]: """Gets a function that can generate new simulants. diff --git a/src/vivarium/framework/randomness/manager.py b/src/vivarium/framework/randomness/manager.py index 6fdcbb2e..8742633c 100644 --- a/src/vivarium/framework/randomness/manager.py +++ b/src/vivarium/framework/randomness/manager.py @@ -118,9 +118,7 @@ def get_randomness_stream( stream = self._get_randomness_stream(decision_point, initializes_crn_attributes) if not initializes_crn_attributes: # We need the key columns to be created before this stream can be called. - self.resources.add_resources( - "stream", [decision_point], stream, self._key_columns - ) + self.resources.add_resources([stream], stream, self._key_columns) self._add_constraint( stream.get_draw, restrict_during=["initialization", "setup", "post_setup"] ) diff --git a/src/vivarium/framework/resource/group.py b/src/vivarium/framework/resource/group.py index c3a96d8a..7c036473 100644 --- a/src/vivarium/framework/resource/group.py +++ b/src/vivarium/framework/resource/group.py @@ -3,6 +3,7 @@ from collections.abc import Iterable, Iterator from typing import Any +from vivarium.framework.resource.exceptions import ResourceError from vivarium.framework.resource.resource import Resource @@ -20,13 +21,18 @@ class ResourceGroup: def __init__( self, - resource_type: str, - resource_names: list[str], + produced_resources: Iterable[Resource], producer: Any, dependencies: Iterable[Resource] = (), ): - self._resource_type = resource_type - self._resource_names = resource_names + if not produced_resources: + raise ResourceError("Resource groups must have at least one resource.") + + if len(set(r.resource_type for r in produced_resources)) != 1: + raise ResourceError("All produced resources must be of the same type.") + + self._resources = list(produced_resources) + """The resources produced by this resource group's producer.""" self._producer = producer """The method or 
object that produces this group of resources.""" self._dependencies = dependencies @@ -34,16 +40,13 @@ def __init__( @property def type(self) -> str: - """The type of resource produced by this resource group's producer. - - Must be one of `RESOURCE_TYPES`. - """ - return self._resource_type + """The type of resource produced by this resource group's producer.""" + return self._resources[0].resource_type @property def names(self) -> list[str]: """The long names (including type) of all resources in this group.""" - return [f"{self._resource_type}.{name}" for name in self._resource_names] + return [resource.resource_id for resource in self._resources] @property def producer(self) -> Any: diff --git a/src/vivarium/framework/resource/manager.py b/src/vivarium/framework/resource/manager.py index 53a03622..bcd8629e 100644 --- a/src/vivarium/framework/resource/manager.py +++ b/src/vivarium/framework/resource/manager.py @@ -14,14 +14,12 @@ from vivarium.framework.resource.exceptions import ResourceError from vivarium.framework.resource.group import ResourceGroup -from vivarium.framework.resource.resource import RESOURCE_TYPES, Column, Resource +from vivarium.framework.resource.resource import Column, NullResource, Resource from vivarium.manager import Interface, Manager if TYPE_CHECKING: from vivarium.framework.engine import Builder -NULL_RESOURCE_TYPE = "null" - class ResourceManager(Manager): """Manages all the resources needed for population initialization.""" @@ -78,8 +76,7 @@ def setup(self, builder: Builder) -> None: # TODO [MIC-5380]: Refactor add_resources for better type hinting def add_resources( self, - resource_type: str, - resource_names: list[str], + resources: Iterable[str | Resource], producer: Any, dependencies: Iterable[str | Resource], ) -> None: @@ -87,11 +84,8 @@ def add_resources( Parameters ---------- - resource_type - The type of the resources being added. Must be one of - `RESOURCE_TYPES`. 
- resource_names - A list of names of the resources being added. + resources + The resources being added. A string represents a column resource. producer A method or object that will produce the resources. dependencies @@ -101,19 +95,10 @@ def add_resources( Raises ------ ResourceError - If either the resource type is invalid, a component has multiple - resource producers for the ``column`` resource type, or - there are multiple producers of the same resource. + If a component has multiple resource producers for the ``column`` + resource type or there are multiple producers of the same resource. """ - if resource_type not in RESOURCE_TYPES: - raise ResourceError( - f"Unknown resource type {resource_type}. " - f"Permitted types are {RESOURCE_TYPES}." - ) - - resource_group = self._get_resource_group( - resource_type, resource_names, producer, dependencies - ) + resource_group = self._get_resource_group(resources, producer, dependencies) for resource in resource_group: if resource in self._resource_group_map: @@ -126,8 +111,7 @@ def add_resources( def _get_resource_group( self, - resource_type: str, - resource_names: list[str], + resources: Iterable[str | Resource], producer: Any, dependencies: Iterable[str | Resource], ) -> ResourceGroup: @@ -137,16 +121,17 @@ def _get_resource_group( -------- :class:`ResourceGroup` """ + resources_ = [Column(r) if isinstance(r, str) else r for r in resources] dependencies_ = [Column(d) if isinstance(d, str) else d for d in dependencies] - if not resource_names: + + if not resources_: # We have a "producer" that doesn't produce anything, but # does have dependencies. This is necessary for components that # want to track private state information. 
- resource_type = NULL_RESOURCE_TYPE - resource_names = [str(self._null_producer_count)] + resources_ = [NullResource(self._null_producer_count)] self._null_producer_count += 1 - return ResourceGroup(resource_type, resource_names, producer, dependencies_) + return ResourceGroup(resources_, producer, dependencies_) def _to_graph(self) -> nx.DiGraph: """Constructs the full resource graph from information in the groups. @@ -189,9 +174,7 @@ def get_population_initializers(self) -> list[Any]: graph construction, but we only need the column producers at population creation time. """ - return [ - r.producer for r in self.sorted_nodes if r.type in {"column", NULL_RESOURCE_TYPE} - ] + return [r.producer for r in self.sorted_nodes if r.type in {"column", "null"}] def __repr__(self) -> str: out = {} @@ -224,8 +207,7 @@ def __init__(self, manager: ResourceManager): def add_resources( self, - resource_type: str, - resource_names: list[str], + resources: Iterable[str | Resource], producer: Any, dependencies: Iterable[str | Resource], ) -> None: @@ -233,11 +215,8 @@ def add_resources( Parameters ---------- - resource_type - The type of the resources being added. Must be one of - `RESOURCE_TYPES`. - resource_names - A list of names of the resources being added. + resources + The resources being added. A string represents a column resource. producer A method or object that will produce the resources. dependencies @@ -251,7 +230,7 @@ def add_resources( resource producers for the ``column`` resource type, or there are multiple producers of the same resource. """ - self._manager.add_resources(resource_type, resource_names, producer, dependencies) + self._manager.add_resources(resources, producer, dependencies) def get_population_initializers(self) -> list[Any]: """Returns a dependency-sorted list of population initializers. 
diff --git a/src/vivarium/framework/resource/resource.py b/src/vivarium/framework/resource/resource.py index b8ca3fe5..530d5b91 100644 --- a/src/vivarium/framework/resource/resource.py +++ b/src/vivarium/framework/resource/resource.py @@ -2,15 +2,6 @@ from dataclasses import dataclass -RESOURCE_TYPES = { - "value", - "value_source", - "missing_value_source", - "value_modifier", - "column", - "stream", -} - @dataclass class Resource: diff --git a/src/vivarium/framework/values.py b/src/vivarium/framework/values.py index c441f163..fdd21ef7 100644 --- a/src/vivarium/framework/values.py +++ b/src/vivarium/framework/values.py @@ -390,7 +390,7 @@ def on_post_setup(self, _event: Event) -> None: for i, m in enumerate(pipe.mutators): mutator_name = self._get_modifier_name(m) dependencies.append(ValueModifier(f"{name}.{i + 1}.{mutator_name}")) - self.resources.add_resources("value", [name], pipe._call, dependencies) + self.resources.add_resources([pipe], pipe._call, dependencies) def register_value_producer( self, @@ -420,7 +420,7 @@ def register_value_producer( dependencies = self._convert_dependencies( source, requires_columns, requires_values, requires_streams, required_resources ) - self.resources.add_resources("value_source", [value_name], source, dependencies) + self.resources.add_resources([ValueSource(value_name)], source, dependencies) self.add_constraint( pipeline._call, restrict_during=["initialization", "setup", "post_setup"] ) @@ -490,7 +490,7 @@ def register_value_modifier( dependencies = self._convert_dependencies( modifier, requires_columns, requires_values, requires_streams, required_resources ) - self.resources.add_resources("value_modifier", [name], modifier, dependencies) + self.resources.add_resources([ValueModifier(name)], modifier, dependencies) def get_value(self, name: str) -> Pipeline: """Retrieve the pipeline representing the named value. 
@@ -511,8 +511,8 @@ def get_value(self, name: str) -> Pipeline: self._pipelines[name] = pipeline return pipeline + @staticmethod def _convert_dependencies( - self, func: Callable[..., Any], requires_columns: Iterable[str], requires_values: Iterable[str], diff --git a/tests/framework/resource/test_manager.py b/tests/framework/resource/test_manager.py index ff335252..e4158fc8 100644 --- a/tests/framework/resource/test_manager.py +++ b/tests/framework/resource/test_manager.py @@ -10,10 +10,10 @@ from vivarium.framework.population import SimulantData from vivarium.framework.randomness import RandomnessStream from vivarium.framework.randomness.index_map import IndexMap -from vivarium.framework.resource import Resource, ResourceManager +from vivarium.framework.resource import ResourceManager from vivarium.framework.resource.exceptions import ResourceError -from vivarium.framework.resource.manager import NULL_RESOURCE_TYPE, RESOURCE_TYPES -from vivarium.framework.values import Pipeline +from vivarium.framework.resource.resource import Column, NullResource +from vivarium.framework.values import MissingValueSource, Pipeline, ValueModifier, ValueSource @pytest.fixture @@ -41,74 +41,62 @@ def producer(self, _simulant_data: SimulantData) -> None: pass -@pytest.mark.parametrize("r_type", RESOURCE_TYPES, ids=lambda x: f"r_type_{x}") -def test_resource_manager_get_resource_group(r_type: str, manager: ResourceManager) -> None: - component = ResourceProducer("base") - r_names = ["foo"] - r_producer = component.producer - r_dependencies: list[str | Resource] = [] +@pytest.mark.parametrize( + "resource_class, type_string", + [ + (Pipeline, "value"), + (ValueSource, "value_source"), + (MissingValueSource, "missing_value_source"), + (ValueModifier, "value_modifier"), + (Column, "column"), + (NullResource, "null"), + ], + ids=lambda x: {x.__name__ if isinstance(x, type) else x}, +) +def test_resource_manager_get_resource_group( + resource_class: type, type_string: str, manager: 
ResourceManager +) -> None: + producer = ResourceProducer("base").producer - group = manager._get_resource_group(r_type, r_names, r_producer, r_dependencies) + group = manager._get_resource_group([resource_class("foo")], producer, []) - assert group.type == r_type - assert group.names == [f"{r_type}.foo"] - assert group.producer == component.producer + assert group.type == type_string + assert group.names == [f"{type_string}.foo"] + assert group.producer == producer assert not group.dependencies def test_resource_manager_get_resource_group_null(manager: ResourceManager) -> None: - component = ResourceProducer("base") - r_names: list[str] = [] - r_producer = component.producer - r_dependencies: list[str | Resource] = [] + producer = ResourceProducer("base").producer - group_1 = manager._get_resource_group("column", r_names, r_producer, r_dependencies) - group_2 = manager._get_resource_group("column", r_names, r_producer, r_dependencies) + group_1 = manager._get_resource_group([], producer, []) + group_2 = manager._get_resource_group([], producer, []) - assert group_1.type == NULL_RESOURCE_TYPE - assert group_1.names == [f"{NULL_RESOURCE_TYPE}.0"] - assert group_1.producer == component.producer + assert group_1.type == "null" + assert group_1.names == ["null.0"] + assert group_1.producer == producer assert not group_1.dependencies - assert group_2.type == NULL_RESOURCE_TYPE - assert group_2.names == [f"{NULL_RESOURCE_TYPE}.1"] - assert group_2.producer == component.producer + assert group_2.type == "null" + assert group_2.names == ["null.1"] + assert group_2.producer == producer assert not group_2.dependencies -def test_resource_manager_add_resources_bad_type(manager: ResourceManager) -> None: - c = ResourceProducer("base") - r_type = "unknown" - r_names = [str(i) for i in range(5)] - r_producer = c.producer - r_dependencies: list[str | Resource] = [] - - with pytest.raises(ResourceError, match="Unknown resource type"): - manager.add_resources(r_type, r_names, 
r_producer, r_dependencies) - - def test_resource_manager_add_resources_multiple_producers(manager: ResourceManager) -> None: - c1 = ResourceProducer("1") - c2 = ResourceProducer("2") - r_type = "column" - r1_names = [str(i) for i in range(5)] - r2_names = [str(i) for i in range(5, 10)] + ["1"] - r1_producer = c1.producer - r2_producer = c2.producer - r_dependencies: list[str | Resource] = [] - - manager.add_resources(r_type, r1_names, r1_producer, r_dependencies) + r1 = [str(i) for i in range(5)] + r2 = [str(i) for i in range(5, 10)] + ["1"] + + manager.add_resources(r1, ResourceProducer("1").producer, []) with pytest.raises(ResourceError, match="producers for column.1"): - manager.add_resources(r_type, r2_names, r2_producer, r_dependencies) + manager.add_resources(r2, ResourceProducer("2").producer, []) def test_resource_manager_sorted_nodes_two_node_cycle( manager: ResourceManager, randomness_stream: RandomnessStream ) -> None: - c = ResourceProducer("test") - - manager.add_resources("column", ["c_1"], c.producer, [randomness_stream]) - manager.add_resources("stream", [randomness_stream.key], c.producer, ["c_1"]) + manager.add_resources(["c_1"], ResourceProducer("1").producer, [randomness_stream]) + manager.add_resources([randomness_stream], ResourceProducer("2").producer, ["c_1"]) with pytest.raises(ResourceError, match="cycle"): _ = manager.sorted_nodes @@ -117,22 +105,19 @@ def test_resource_manager_sorted_nodes_two_node_cycle( def test_resource_manager_sorted_nodes_three_node_cycle( manager: ResourceManager, randomness_stream: RandomnessStream ) -> None: - c = ResourceProducer("test") pipeline = Pipeline("some_pipeline") - manager.add_resources("column", ["c_1"], c.producer, [randomness_stream]) - manager.add_resources("value", [pipeline.name], c.producer, ["c_1"]) - manager.add_resources("stream", [randomness_stream.key], c.producer, [pipeline]) + manager.add_resources(["c_1"], ResourceProducer("1").producer, [randomness_stream]) + 
manager.add_resources([pipeline], ResourceProducer("2").producer, ["c_1"]) + manager.add_resources([randomness_stream], ResourceProducer("3").producer, [pipeline]) with pytest.raises(ResourceError, match="cycle"): _ = manager.sorted_nodes def test_resource_manager_sorted_nodes_large_cycle(manager: ResourceManager) -> None: - c = ResourceProducer("test") - for i in range(10): - manager.add_resources("column", [f"c_{i}"], c.producer, [f"c_{i%10}"]) + manager.add_resources([f"c_{i}"], ResourceProducer("1").producer, [f"c_{i % 10}"]) with pytest.raises(ResourceError, match="cycle"): _ = manager.sorted_nodes @@ -140,10 +125,8 @@ def test_resource_manager_sorted_nodes_large_cycle(manager: ResourceManager) -> def test_large_dependency_chain(manager: ResourceManager) -> None: for i in range(9, 0, -1): - manager.add_resources( - "column", [f"c_{i}"], ResourceProducer(f"p_{i}").producer, [f"c_{i - 1}"] - ) - manager.add_resources("column", ["c_0"], ResourceProducer("producer_0").producer, []) + manager.add_resources([f"c_{i}"], ResourceProducer(f"p_{i}").producer, [f"c_{i - 1}"]) + manager.add_resources(["c_0"], ResourceProducer("producer_0").producer, []) for i, resource in enumerate(manager.sorted_nodes): assert str(resource) == f"(column.c_{i})" @@ -161,7 +144,7 @@ def test_resource_manager_sorted_nodes_acyclic(manager: ResourceManager) -> None assert n.index("(stream.B)") < n.index("(column.D)") assert n.index("(value.C)") < n.index("(column.D)") - assert n.index("(stream.B)") < n.index(f"({NULL_RESOURCE_TYPE}.0)") + assert n.index("(stream.B)") < n.index(f"(null.0)") def test_get_population_initializers(manager: ResourceManager) -> None: @@ -185,10 +168,10 @@ def _add_resources(manager: ResourceManager) -> Mapping[int, Callable[[SimulantD stream = RandomnessStream("B", lambda: datetime.now(), 1, IndexMap()) pipeline = Pipeline("C") - manager.add_resources("column", ["D"], producers[3], [stream, pipeline]) - manager.add_resources("stream", ["B"], producers[1], ["A"]) - 
manager.add_resources("value", ["C"], producers[2], ["A"]) - manager.add_resources("column", ["A"], producers[0], []) - manager.add_resources("column", [], producers[4], [stream]) + manager.add_resources(["D"], producers[3], [stream, pipeline]) + manager.add_resources([stream], producers[1], ["A"]) + manager.add_resources([pipeline], producers[2], ["A"]) + manager.add_resources(["A"], producers[0], []) + manager.add_resources([], producers[4], [stream]) return producers diff --git a/tests/framework/resource/test_resource_group.py b/tests/framework/resource/test_resource_group.py index cc9c344f..4a5352d1 100644 --- a/tests/framework/resource/test_resource_group.py +++ b/tests/framework/resource/test_resource_group.py @@ -2,11 +2,14 @@ from datetime import datetime +import pytest + from vivarium.framework.randomness import RandomnessStream from vivarium.framework.randomness.index_map import IndexMap +from vivarium.framework.resource.exceptions import ResourceError from vivarium.framework.resource.group import ResourceGroup from vivarium.framework.resource.resource import Column -from vivarium.framework.values import Pipeline, ValueSource +from vivarium.framework.values import Pipeline, ValueModifier, ValueSource def dummy_producer() -> str: @@ -14,9 +17,7 @@ def dummy_producer() -> str: def test_resource_group() -> None: - r_type = "column" - r_names = [str(i) for i in range(5)] - r_producer = dummy_producer + resources = [ValueModifier(str(i)) for i in range(5)] r_dependencies = [ Column("an_interesting_column"), Pipeline("baz"), @@ -24,10 +25,10 @@ def test_resource_group() -> None: ValueSource("foo"), ] - rg = ResourceGroup(r_type, r_names, r_producer, r_dependencies) + rg = ResourceGroup(resources, dummy_producer, r_dependencies) - assert rg.type == r_type - assert rg.names == [f"{r_type}.{name}" for name in r_names] + assert rg.type == "value_modifier" + assert rg.names == [f"value_modifier.{i}" for i in range(5)] assert rg.producer == dummy_producer assert 
rg.dependencies == [ "column.an_interesting_column", @@ -36,3 +37,15 @@ def test_resource_group() -> None: "value_source.foo", ] assert list(rg) == rg.names + + +def test_resource_group_with_no_resources() -> None: + with pytest.raises(ResourceError, match="must have at least one resource"): + _ = ResourceGroup([], dummy_producer, [Column("foo")]) + + +def test_resource_group_with_multiple_resource_types() -> None: + resources = [ValueModifier("foo"), ValueSource("bar")] + + with pytest.raises(ResourceError, match="resources must be of the same type"): + _ = ResourceGroup(resources, dummy_producer) From 8cce3a3135b3b6d80012d7ceed2f7042da7c688d Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Mon, 21 Oct 2024 14:44:35 -0700 Subject: [PATCH 13/22] improve distinction between initializers and non-initializers (#513) --- src/vivarium/framework/population/manager.py | 2 +- src/vivarium/framework/randomness/manager.py | 2 +- src/vivarium/framework/resource/group.py | 64 ++++++++++---- src/vivarium/framework/resource/manager.py | 47 +++++----- src/vivarium/framework/resource/resource.py | 17 ++++ src/vivarium/framework/values.py | 6 +- tests/framework/components/test_component.py | 13 +-- tests/framework/resource/test_manager.py | 88 +++++++++++-------- .../framework/resource/test_resource_group.py | 27 ++++-- 9 files changed, 175 insertions(+), 91 deletions(-) diff --git a/src/vivarium/framework/population/manager.py b/src/vivarium/framework/population/manager.py index 3b8e8ecd..a1c0da9a 100644 --- a/src/vivarium/framework/population/manager.py +++ b/src/vivarium/framework/population/manager.py @@ -323,7 +323,7 @@ def register_simulant_initializer( all_dependencies = list(required_resources) self._initializer_components.add(initializer, list(creates_columns)) - self.resources.add_resources(creates_columns, initializer, all_dependencies) + self.resources.add_resources(creates_columns, all_dependencies, initializer) def 
get_simulant_creator(self) -> Callable[[int, dict[str, Any] | None], pd.Index[int]]: """Gets a function that can generate new simulants. diff --git a/src/vivarium/framework/randomness/manager.py b/src/vivarium/framework/randomness/manager.py index 8742633c..eae1921f 100644 --- a/src/vivarium/framework/randomness/manager.py +++ b/src/vivarium/framework/randomness/manager.py @@ -118,7 +118,7 @@ def get_randomness_stream( stream = self._get_randomness_stream(decision_point, initializes_crn_attributes) if not initializes_crn_attributes: # We need the key columns to be created before this stream can be called. - self.resources.add_resources([stream], stream, self._key_columns) + self.resources.add_resources([stream], self._key_columns) self._add_constraint( stream.get_draw, restrict_during=["initialization", "setup", "post_setup"] ) diff --git a/src/vivarium/framework/resource/group.py b/src/vivarium/framework/resource/group.py index 7c036473..5981a636 100644 --- a/src/vivarium/framework/resource/group.py +++ b/src/vivarium/framework/resource/group.py @@ -1,11 +1,14 @@ from __future__ import annotations -from collections.abc import Iterable, Iterator -from typing import Any +from collections.abc import Callable, Iterable, Iterator +from typing import TYPE_CHECKING from vivarium.framework.resource.exceptions import ResourceError from vivarium.framework.resource.resource import Resource +if TYPE_CHECKING: + from vivarium.framework.population import SimulantData + class ResourceGroup: """Resource groups are the nodes in the resource dependency graph. @@ -22,42 +25,65 @@ class ResourceGroup: def __init__( self, produced_resources: Iterable[Resource], - producer: Any, - dependencies: Iterable[Resource] = (), + dependencies: Iterable[Resource], + initializer: Callable[[SimulantData], None] | None = None, ): + """Create a new resource group. + + Parameters + ---------- + produced_resources + The resources produced by this resource group's producer. 
+ dependencies + The resources this resource group's producer depends on. + initializer + The method that initializes this group of resources. If this is + None, the resources don't need to be initialized. + + Raises + ------ + ResourceError + If the resource group is not well-formed. + """ if not produced_resources: raise ResourceError("Resource groups must have at least one resource.") if len(set(r.resource_type for r in produced_resources)) != 1: raise ResourceError("All produced resources must be of the same type.") - self._resources = list(produced_resources) - """The resources produced by this resource group's producer.""" - self._producer = producer - """The method or object that produces this group of resources.""" - self._dependencies = dependencies - """The resources this resource group's producer depends on.""" + if list(produced_resources)[0].is_initialized != (initializer is not None): + raise ResourceError( + "Resource groups with an initializer must have initialized resources." 
+ ) - @property - def type(self) -> str: + self.type = list(produced_resources)[0].resource_type """The type of resource produced by this resource group's producer.""" - return self._resources[0].resource_type + self._resources = {resource.resource_id: resource for resource in produced_resources} + self._initializer = initializer + self._dependencies = dependencies @property def names(self) -> list[str]: """The long names (including type) of all resources in this group.""" - return [resource.resource_id for resource in self._resources] + return list(self._resources) @property - def producer(self) -> Any: - """The method or object that produces this group of resources.""" - return self._producer + def initializer(self) -> Callable[[SimulantData], None]: + """The method that initializes this group of resources.""" + if self._initializer is None: + raise ResourceError("This resource group does not have an initializer.") + return self._initializer @property def dependencies(self) -> list[str]: """The long names (including type) of dependencies for this group.""" return [dependency.resource_id for dependency in self._dependencies] + @property + def is_initializer(self) -> bool: + """Return True if this resource group's producer is an initializer.""" + return self._initializer is not None + def __iter__(self) -> Iterator[str]: return iter(self.names) @@ -68,3 +94,7 @@ def __repr__(self) -> str: def __str__(self) -> str: resources = ", ".join(self) return f"({resources})" + + def get_resource(self, resource_id: str) -> Resource: + """Get a resource by its resource_id.""" + return self._resources[resource_id] diff --git a/src/vivarium/framework/resource/manager.py b/src/vivarium/framework/resource/manager.py index bcd8629e..09165d8e 100644 --- a/src/vivarium/framework/resource/manager.py +++ b/src/vivarium/framework/resource/manager.py @@ -7,7 +7,7 @@ from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Callable, Iterable 
from typing import TYPE_CHECKING, Any import networkx as nx @@ -19,6 +19,7 @@ if TYPE_CHECKING: from vivarium.framework.engine import Builder + from vivarium.framework.population import SimulantData class ResourceManager(Manager): @@ -73,12 +74,11 @@ def sorted_nodes(self) -> list[ResourceGroup]: def setup(self, builder: Builder) -> None: self.logger = builder.logging.get_logger(self.name) - # TODO [MIC-5380]: Refactor add_resources for better type hinting def add_resources( self, resources: Iterable[str | Resource], - producer: Any, dependencies: Iterable[str | Resource], + initializer: Callable[[SimulantData], None] | None, ) -> None: """Adds managed resources to the resource pool. @@ -86,11 +86,12 @@ def add_resources( ---------- resources The resources being added. A string represents a column resource. - producer - A method or object that will produce the resources. dependencies A list of resources that the producer requires. A string represents a column resource. + initializer + A method that will be called to initialize the resources. This is + called during population initialization. Raises ------ @@ -98,22 +99,27 @@ def add_resources( If a component has multiple resource producers for the ``column`` resource type or there are multiple producers of the same resource. """ - resource_group = self._get_resource_group(resources, producer, dependencies) - - for resource in resource_group: - if resource in self._resource_group_map: - other_producer = self._resource_group_map[resource].producer + resource_group = self._get_resource_group(resources, dependencies, initializer) + + for resource_id in resource_group: + if resource_id in self._resource_group_map: + other_resource = self._resource_group_map[resource_id] + if resource_group.is_initializer: + raise ResourceError( + f"Both {initializer} and {other_resource.initializer}" + f" are registered as initializers for {resource_id}." 
+ ) + resource = resource_group.get_resource(resource_id) raise ResourceError( - f"Both {producer} and {other_producer} are registered as " - f"producers for {resource}." + f"Resource {resource} is not allowed to be registered more" " than once." ) - self._resource_group_map[resource] = resource_group + self._resource_group_map[resource_id] = resource_group def _get_resource_group( self, resources: Iterable[str | Resource], - producer: Any, dependencies: Iterable[str | Resource], + initializer: Callable[[SimulantData], None] | None, ) -> ResourceGroup: """Packages resource information into a resource group. @@ -131,7 +137,7 @@ def _get_resource_group( resources_ = [NullResource(self._null_producer_count)] self._null_producer_count += 1 - return ResourceGroup(resources_, producer, dependencies_) + return ResourceGroup(resources_, dependencies_, initializer) def _to_graph(self) -> nx.DiGraph: """Constructs the full resource graph from information in the groups. @@ -174,7 +180,7 @@ def get_population_initializers(self) -> list[Any]: graph construction, but we only need the column producers at population creation time. """ - return [r.producer for r in self.sorted_nodes if r.type in {"column", "null"}] + return [r.initializer for r in self.sorted_nodes if r.is_initializer] def __repr__(self) -> str: out = {} @@ -208,8 +214,8 @@ def __init__(self, manager: ResourceManager): def add_resources( self, resources: Iterable[str | Resource], - producer: Any, dependencies: Iterable[str | Resource], + initializer: Callable[[SimulantData], None] | None = None, ) -> None: """Adds managed resources to the resource pool. @@ -217,11 +223,12 @@ def add_resources( ---------- resources The resources being added. A string represents a column resource. - producer - A method or object that will produce the resources. dependencies A list of resources that the producer requires. A string represents a column resource. + initializer + A method that will be called to initialize the resources. 
This is + called during population initialization. Raises ------ @@ -230,7 +237,7 @@ def add_resources( resource producers for the ``column`` resource type, or there are multiple producers of the same resource. """ - self._manager.add_resources(resources, producer, dependencies) + self._manager.add_resources(resources, dependencies, initializer) def get_population_initializers(self) -> list[Any]: """Returns a dependency-sorted list of population initializers. diff --git a/src/vivarium/framework/resource/resource.py b/src/vivarium/framework/resource/resource.py index 530d5b91..bfcda9e0 100644 --- a/src/vivarium/framework/resource/resource.py +++ b/src/vivarium/framework/resource/resource.py @@ -17,6 +17,13 @@ def resource_id(self) -> str: """The long name of the resource, including the type.""" return f"{self.resource_type}.{self.name}" + # TODO [MIC-5452]: make this an abstract method when support for old + # requirements specification is dropped + @property + def is_initialized(self) -> bool: + """Return True if the resource needs to be initialized.""" + return False + class NullResource(Resource): """A node in the dependency graph that does not produce any resources.""" @@ -24,9 +31,19 @@ class NullResource(Resource): def __init__(self, index: int): super().__init__("null", f"{index}") + @property + def is_initialized(self) -> bool: + """Return True if the resource needs to be initialized.""" + return True + class Column(Resource): """A resource representing a column in the state table.""" def __init__(self, name: str): super().__init__("column", name) + + @property + def is_initialized(self) -> bool: + """Return True if the resource needs to be initialized.""" + return True diff --git a/src/vivarium/framework/values.py b/src/vivarium/framework/values.py index fdd21ef7..13436dd1 100644 --- a/src/vivarium/framework/values.py +++ b/src/vivarium/framework/values.py @@ -390,7 +390,7 @@ def on_post_setup(self, _event: Event) -> None: for i, m in 
enumerate(pipe.mutators): mutator_name = self._get_modifier_name(m) dependencies.append(ValueModifier(f"{name}.{i + 1}.{mutator_name}")) - self.resources.add_resources([pipe], pipe._call, dependencies) + self.resources.add_resources([pipe], dependencies) def register_value_producer( self, @@ -420,7 +420,7 @@ def register_value_producer( dependencies = self._convert_dependencies( source, requires_columns, requires_values, requires_streams, required_resources ) - self.resources.add_resources([ValueSource(value_name)], source, dependencies) + self.resources.add_resources([ValueSource(value_name)], dependencies) self.add_constraint( pipeline._call, restrict_during=["initialization", "setup", "post_setup"] ) @@ -490,7 +490,7 @@ def register_value_modifier( dependencies = self._convert_dependencies( modifier, requires_columns, requires_values, requires_streams, required_resources ) - self.resources.add_resources([ValueModifier(name)], modifier, dependencies) + self.resources.add_resources([ValueModifier(name)], dependencies) def get_value(self, name: str) -> Pipeline: """Retrieve the pipeline representing the named value. 
diff --git a/tests/framework/components/test_component.py b/tests/framework/components/test_component.py index 86ce4df5..89a9dce4 100644 --- a/tests/framework/components/test_component.py +++ b/tests/framework/components/test_component.py @@ -103,18 +103,21 @@ def test_component_that_creates_and_requires_columns_population_view(): def test_component_with_initialization_requirements(): - component = ColumnCreatorAndRequirer() - simulation = InteractiveContext(components=[ColumnCreator(), component]) + simulation = InteractiveContext( + components=[ColumnCreator(), ColumnCreatorAndRequirer()], + ) # Assert required resources have been recorded by the ResourceManager component_dependencies_list = [ r.dependencies # get all resources in the dependency graph for r in simulation._resource.sorted_nodes - # if the producer is an instance method - if hasattr(r.producer, "__self__") + # if the resource is an initializer + if r.is_initializer + # its initializer is an instance method + and hasattr(r.initializer, "__self__") # and is a method of ColumnCreatorAndRequirer - and isinstance(r.producer.__self__, ColumnCreatorAndRequirer) + and isinstance(r.initializer.__self__, ColumnCreatorAndRequirer) ] assert len(component_dependencies_list) == 1 component_dependencies = component_dependencies_list[0] diff --git a/tests/framework/resource/test_manager.py b/tests/framework/resource/test_manager.py index e4158fc8..e8a9a6e3 100644 --- a/tests/framework/resource/test_manager.py +++ b/tests/framework/resource/test_manager.py @@ -37,66 +37,82 @@ def __init__(self, name: str): super().__init__() self._name = name - def producer(self, _simulant_data: SimulantData) -> None: + def initializer(self, _simulant_data: SimulantData) -> None: pass @pytest.mark.parametrize( - "resource_class, type_string", + "resource_class, type_string, is_initializer", [ - (Pipeline, "value"), - (ValueSource, "value_source"), - (MissingValueSource, "missing_value_source"), - (ValueModifier, "value_modifier"), 
- (Column, "column"), - (NullResource, "null"), + (Pipeline, "value", False), + (ValueSource, "value_source", False), + (MissingValueSource, "missing_value_source", False), + (ValueModifier, "value_modifier", False), + (Column, "column", True), + (NullResource, "null", True), ], ids=lambda x: {x.__name__ if isinstance(x, type) else x}, ) def test_resource_manager_get_resource_group( - resource_class: type, type_string: str, manager: ResourceManager + resource_class: type, type_string: str, is_initializer: bool, manager: ResourceManager ) -> None: - producer = ResourceProducer("base").producer + initializer = ResourceProducer("base").initializer - group = manager._get_resource_group([resource_class("foo")], producer, []) + group = manager._get_resource_group( + [resource_class("foo")], [], initializer if is_initializer else None + ) assert group.type == type_string assert group.names == [f"{type_string}.foo"] - assert group.producer == producer assert not group.dependencies + assert group.is_initializer == is_initializer + if is_initializer: + assert group.initializer == initializer + else: + with pytest.raises(ResourceError, match="does not have an initializer"): + _ = group.initializer def test_resource_manager_get_resource_group_null(manager: ResourceManager) -> None: - producer = ResourceProducer("base").producer + initializer = ResourceProducer("base").initializer - group_1 = manager._get_resource_group([], producer, []) - group_2 = manager._get_resource_group([], producer, []) + group_1 = manager._get_resource_group([], [], initializer) + group_2 = manager._get_resource_group([], [], initializer) assert group_1.type == "null" assert group_1.names == ["null.0"] - assert group_1.producer == producer + assert group_1.initializer == initializer assert not group_1.dependencies assert group_2.type == "null" assert group_2.names == ["null.1"] - assert group_2.producer == producer + assert group_2.initializer == initializer assert not group_2.dependencies -def 
test_resource_manager_add_resources_multiple_producers(manager: ResourceManager) -> None: +def test_resource_manager_add_same_column_twice(manager: ResourceManager) -> None: r1 = [str(i) for i in range(5)] r2 = [str(i) for i in range(5, 10)] + ["1"] - manager.add_resources(r1, ResourceProducer("1").producer, []) - with pytest.raises(ResourceError, match="producers for column.1"): - manager.add_resources(r2, ResourceProducer("2").producer, []) + manager.add_resources(r1, [], ResourceProducer("1").initializer) + with pytest.raises(ResourceError, match="initializers for column.1"): + manager.add_resources(r2, [], ResourceProducer("2").initializer) + + +def test_resource_manager_add_same_pipeline_twice(manager: ResourceManager) -> None: + r1 = [Pipeline(str(i)) for i in range(5)] + r2 = [Pipeline(str(i)) for i in range(5, 10)] + [Pipeline("1")] + + manager.add_resources(r1, [], None) + with pytest.raises(ResourceError, match="registered more than once"): + manager.add_resources(r2, [], None) def test_resource_manager_sorted_nodes_two_node_cycle( manager: ResourceManager, randomness_stream: RandomnessStream ) -> None: - manager.add_resources(["c_1"], ResourceProducer("1").producer, [randomness_stream]) - manager.add_resources([randomness_stream], ResourceProducer("2").producer, ["c_1"]) + manager.add_resources(["c_1"], [randomness_stream], ResourceProducer("1").initializer) + manager.add_resources([randomness_stream], ["c_1"], None) with pytest.raises(ResourceError, match="cycle"): _ = manager.sorted_nodes @@ -107,9 +123,9 @@ def test_resource_manager_sorted_nodes_three_node_cycle( ) -> None: pipeline = Pipeline("some_pipeline") - manager.add_resources(["c_1"], ResourceProducer("1").producer, [randomness_stream]) - manager.add_resources([pipeline], ResourceProducer("2").producer, ["c_1"]) - manager.add_resources([randomness_stream], ResourceProducer("3").producer, [pipeline]) + manager.add_resources(["c_1"], [randomness_stream], ResourceProducer("1").initializer) + 
manager.add_resources([pipeline], ["c_1"], None) + manager.add_resources([randomness_stream], [pipeline], None) with pytest.raises(ResourceError, match="cycle"): _ = manager.sorted_nodes @@ -117,7 +133,7 @@ def test_resource_manager_sorted_nodes_three_node_cycle( def test_resource_manager_sorted_nodes_large_cycle(manager: ResourceManager) -> None: for i in range(10): - manager.add_resources([f"c_{i}"], ResourceProducer("1").producer, [f"c_{i % 10}"]) + manager.add_resources([f"c_{i}"], [f"c_{i % 10}"], ResourceProducer("1").initializer) with pytest.raises(ResourceError, match="cycle"): _ = manager.sorted_nodes @@ -125,8 +141,10 @@ def test_resource_manager_sorted_nodes_large_cycle(manager: ResourceManager) -> def test_large_dependency_chain(manager: ResourceManager) -> None: for i in range(9, 0, -1): - manager.add_resources([f"c_{i}"], ResourceProducer(f"p_{i}").producer, [f"c_{i - 1}"]) - manager.add_resources(["c_0"], ResourceProducer("producer_0").producer, []) + manager.add_resources( + [f"c_{i}"], [f"c_{i - 1}"], ResourceProducer(f"p_{i}").initializer + ) + manager.add_resources(["c_0"], [], ResourceProducer("producer_0").initializer) for i, resource in enumerate(manager.sorted_nodes): assert str(resource) == f"(column.c_{i})" @@ -163,15 +181,15 @@ def test_get_population_initializers(manager: ResourceManager) -> None: def _add_resources(manager: ResourceManager) -> Mapping[int, Callable[[SimulantData], None]]: - producers = {i: ResourceProducer(f"test_{i}").producer for i in range(5)} + producers = {i: ResourceProducer(f"test_{i}").initializer for i in range(5)} stream = RandomnessStream("B", lambda: datetime.now(), 1, IndexMap()) pipeline = Pipeline("C") - manager.add_resources(["D"], producers[3], [stream, pipeline]) - manager.add_resources([stream], producers[1], ["A"]) - manager.add_resources([pipeline], producers[2], ["A"]) - manager.add_resources(["A"], producers[0], []) - manager.add_resources([], producers[4], [stream]) + manager.add_resources(["D"], 
[stream, pipeline], producers[3]) + manager.add_resources([stream], ["A"], None) + manager.add_resources([pipeline], ["A"], None) + manager.add_resources(["A"], [], producers[0]) + manager.add_resources([], [stream], producers[4]) return producers diff --git a/tests/framework/resource/test_resource_group.py b/tests/framework/resource/test_resource_group.py index 4a5352d1..d3c2bd84 100644 --- a/tests/framework/resource/test_resource_group.py +++ b/tests/framework/resource/test_resource_group.py @@ -4,6 +4,7 @@ import pytest +from vivarium.framework.population import SimulantData from vivarium.framework.randomness import RandomnessStream from vivarium.framework.randomness.index_map import IndexMap from vivarium.framework.resource.exceptions import ResourceError @@ -12,12 +13,12 @@ from vivarium.framework.values import Pipeline, ValueModifier, ValueSource -def dummy_producer() -> str: - return "resources!" +def dummy_initializer(_simulant_data: SimulantData) -> None: + pass def test_resource_group() -> None: - resources = [ValueModifier(str(i)) for i in range(5)] + resources = [Column(str(i)) for i in range(5)] r_dependencies = [ Column("an_interesting_column"), Pipeline("baz"), @@ -25,11 +26,11 @@ def test_resource_group() -> None: ValueSource("foo"), ] - rg = ResourceGroup(resources, dummy_producer, r_dependencies) + rg = ResourceGroup(resources, r_dependencies, dummy_initializer) - assert rg.type == "value_modifier" - assert rg.names == [f"value_modifier.{i}" for i in range(5)] - assert rg.producer == dummy_producer + assert rg.type == "column" + assert rg.names == [f"column.{i}" for i in range(5)] + assert rg.initializer == dummy_initializer assert rg.dependencies == [ "column.an_interesting_column", "value.baz", @@ -39,13 +40,21 @@ def test_resource_group() -> None: assert list(rg) == rg.names +def test_resource_group_is_initializer() -> None: + resources = [ValueModifier("foo")] + rg = ResourceGroup(resources, [Column("bar")]) + + with 
pytest.raises(ResourceError, match="does not have an initializer"): + _ = rg.initializer + + def test_resource_group_with_no_resources() -> None: with pytest.raises(ResourceError, match="must have at least one resource"): - _ = ResourceGroup([], dummy_producer, [Column("foo")]) + _ = ResourceGroup([], [Column("foo")]) def test_resource_group_with_multiple_resource_types() -> None: resources = [ValueModifier("foo"), ValueSource("bar")] with pytest.raises(ResourceError, match="resources must be of the same type"): - _ = ResourceGroup(resources, dummy_producer) + _ = ResourceGroup(resources, []) From 5449d06b038c1130e733634afa1ac1967f852258 Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Tue, 29 Oct 2024 12:31:32 -0700 Subject: [PATCH 14/22] refactor to create values package (#514) --- docs/nitpick-exceptions | 1 + .../source/api_reference/framework/values.rst | 1 - .../framework/values/combiners.rst | 1 + .../framework/values/exceptions.rst | 1 + .../api_reference/framework/values/index.rst | 11 + .../framework/values/manager.rst | 1 + .../framework/values/pipeline.rst | 1 + .../framework/values/post_processors.rst | 1 + docs/source/concepts/builder.rst | 2 +- docs/source/concepts/values.rst | 20 +- docs/source/tutorials/boids.rst | 4 +- docs/source/tutorials/disease_model.rst | 2 +- src/vivarium/framework/results/interface.py | 2 +- src/vivarium/framework/values/__init__.py | 28 ++ src/vivarium/framework/values/combiners.py | 64 ++++ src/vivarium/framework/values/exceptions.py | 7 + .../{values.py => values/manager.py} | 341 +----------------- src/vivarium/framework/values/pipeline.py | 171 +++++++++ .../framework/values/post_processors.py | 101 ++++++ 19 files changed, 412 insertions(+), 348 deletions(-) delete mode 100644 docs/source/api_reference/framework/values.rst create mode 100644 docs/source/api_reference/framework/values/combiners.rst create mode 100644 
docs/source/api_reference/framework/values/exceptions.rst create mode 100644 docs/source/api_reference/framework/values/index.rst create mode 100644 docs/source/api_reference/framework/values/manager.rst create mode 100644 docs/source/api_reference/framework/values/pipeline.rst create mode 100644 docs/source/api_reference/framework/values/post_processors.rst create mode 100644 src/vivarium/framework/values/__init__.py create mode 100644 src/vivarium/framework/values/combiners.py create mode 100644 src/vivarium/framework/values/exceptions.py rename src/vivarium/framework/{values.py => values/manager.py} (61%) create mode 100644 src/vivarium/framework/values/pipeline.py create mode 100644 src/vivarium/framework/values/post_processors.py diff --git a/docs/nitpick-exceptions b/docs/nitpick-exceptions index ec57a7ef..82e46d47 100644 --- a/docs/nitpick-exceptions +++ b/docs/nitpick-exceptions @@ -29,6 +29,7 @@ py:class loguru.Logger # elsewhere. Works fine for static type checker though. I think this # is because sphinx does runtime checks. py:class ScalarValue +py:class NumberLike py:class NumericArray py:class ClockTime py:class Time diff --git a/docs/source/api_reference/framework/values.rst b/docs/source/api_reference/framework/values.rst deleted file mode 100644 index 14068579..00000000 --- a/docs/source/api_reference/framework/values.rst +++ /dev/null @@ -1 +0,0 @@ -.. automodule:: vivarium.framework.values diff --git a/docs/source/api_reference/framework/values/combiners.rst b/docs/source/api_reference/framework/values/combiners.rst new file mode 100644 index 00000000..660486b1 --- /dev/null +++ b/docs/source/api_reference/framework/values/combiners.rst @@ -0,0 +1 @@ +.. 
automodule:: vivarium.framework.values.combiners \ No newline at end of file diff --git a/docs/source/api_reference/framework/values/exceptions.rst b/docs/source/api_reference/framework/values/exceptions.rst new file mode 100644 index 00000000..ab08cfb1 --- /dev/null +++ b/docs/source/api_reference/framework/values/exceptions.rst @@ -0,0 +1 @@ +.. automodule:: vivarium.framework.values.exceptions \ No newline at end of file diff --git a/docs/source/api_reference/framework/values/index.rst b/docs/source/api_reference/framework/values/index.rst new file mode 100644 index 00000000..b8aa71a5 --- /dev/null +++ b/docs/source/api_reference/framework/values/index.rst @@ -0,0 +1,11 @@ +================ +Value Management +================ + +.. automodule:: vivarium.framework.values + +.. toctree:: + :maxdepth: 1 + :glob: + + * \ No newline at end of file diff --git a/docs/source/api_reference/framework/values/manager.rst b/docs/source/api_reference/framework/values/manager.rst new file mode 100644 index 00000000..93c391e7 --- /dev/null +++ b/docs/source/api_reference/framework/values/manager.rst @@ -0,0 +1 @@ +.. automodule:: vivarium.framework.values.manager \ No newline at end of file diff --git a/docs/source/api_reference/framework/values/pipeline.rst b/docs/source/api_reference/framework/values/pipeline.rst new file mode 100644 index 00000000..f2f2bc18 --- /dev/null +++ b/docs/source/api_reference/framework/values/pipeline.rst @@ -0,0 +1 @@ +.. automodule:: vivarium.framework.values.pipeline \ No newline at end of file diff --git a/docs/source/api_reference/framework/values/post_processors.rst b/docs/source/api_reference/framework/values/post_processors.rst new file mode 100644 index 00000000..a68b4f4f --- /dev/null +++ b/docs/source/api_reference/framework/values/post_processors.rst @@ -0,0 +1 @@ +.. 
automodule:: vivarium.framework.values.post_processors \ No newline at end of file diff --git a/docs/source/concepts/builder.rst b/docs/source/concepts/builder.rst index 2b20ab27..8e1ed65b 100644 --- a/docs/source/concepts/builder.rst +++ b/docs/source/concepts/builder.rst @@ -24,7 +24,7 @@ they register for services and provide information about their structure. For ex a component needing to leverage the simulation clock and step size to determine a numerical effect to apply on each time step, will get the simulation clock and step size though the Builder and will register -method(s) to apply the effect (e.g., via :meth:`vivarium.framework.values.ValuesInterface.register_value_modifier`). +method(s) to apply the effect (e.g., via :meth:`vivarium.framework.values.manager.ValuesInterface.register_value_modifier`). Another component, needing to initialize state for simulants at before the simulation begin, might call :meth:`vivarium.framework.population.manager.PopulationInterface.initializes_simulants` in its setup method to register method(s) that setup the additional state. diff --git a/docs/source/concepts/values.rst b/docs/source/concepts/values.rst index 31c81439..9def2fa3 100644 --- a/docs/source/concepts/values.rst +++ b/docs/source/concepts/values.rst @@ -6,7 +6,7 @@ The Values System The values system provides an interface to an alternative representation of :term:`state ` in the simulation: pipelines. -:class:`Pipelines ` are dynamically +:class:`Pipelines ` are dynamically calculated values that can be constructed across multiple :ref:`components `. This ability for multiple components to together compose a single value is the biggest advantage @@ -49,13 +49,13 @@ three options for combiners, detailed in the following table. * - Combiner - Description - Modifier Signature - * - | :func:`Replace ` + * - | :func:`Replace ` - | Replaces the output of the source or modifier with the output of the | next modifier. 
This is the default combiner if none is specified on | pipeline registration. - | Arguments for the modifiers should be the same as the source with an | additional last argument of the results of the previous modifier. - * - | :func:`List ` + * - | :func:`List ` - | The output of the source should be a list to which the results of the | modifiers are appended. - | Modifiers should have the same signature as the source. @@ -70,11 +70,11 @@ combiner to do some postprocessing. * - Post-processor - Description - * - | :func:`Rescale ` + * - | :func:`Rescale ` - | Used for pipelines that produce rates. Rescales the rates to the | size of the time step. Rates provided by source and modifiers are | presumed to be annual. - * - | :func:`Union ` + * - | :func:`Union ` - | Used for pipelines that produce independent proportions or | probabilities. Combines values in a way that is consistent with a | union of the underlying sample space @@ -99,19 +99,19 @@ The values system provides four interface methods, available off the * - Method - Description - * - | :meth:`register_value_producer ` + * - | :meth:`register_value_producer ` - | Register a new pipeline with the values system. Provide a name for the | pipeline and a source. Optionally provide a combiner (defaults to | the replace combiner) and a postprocessor. Provide dependencies (see note). - * - | :meth:`register_rate_producer ` - - | A special case of :meth:`register_value_producer ` + * - | :meth:`register_rate_producer ` + - | A special case of :meth:`register_value_producer ` | for rates specifically. | Provide a name for the pipeline and a source and the values system will | automatically use the rescale postprocessor. Provide dependencies (see note). - * - | :meth:`register_value_modifier ` + * - | :meth:`register_value_modifier ` - | Register a modifier to a pipeline. Provide a name for the pipeline to | modify and a modifier callable. Provide dependencies (see note). 
- * - | :meth:`get_value ` + * - | :meth:`get_value ` - | Retrieve a reference to the pipeline with the given name. .. note:: diff --git a/docs/source/tutorials/boids.rst b/docs/source/tutorials/boids.rst index 4bb2dc83..cbb12635 100644 --- a/docs/source/tutorials/boids.rst +++ b/docs/source/tutorials/boids.rst @@ -256,7 +256,7 @@ You can find an overview of the values system :ref:`here `. The Builder class exposes an additional property for working with value pipelines: :meth:`vivarium.framework.engine.Builder.value`. -We call the :meth:`vivarium.framework.values.ValuesInterface.register_value_producer` +We call the :meth:`vivarium.framework.values.manager.ValuesInterface.register_value_producer` method to register a new pipeline. .. literalinclude:: ../../../src/vivarium/examples/boids/movement.py @@ -474,7 +474,7 @@ To access the value pipeline we created in the Neighbors component, we use pipeline, we simply call that pipeline as a function inside ``on_time_step`` to retrieve its values for a specified index. The major new Vivarium feature seen here is that of the **value modifier**, -which we register with :meth:`vivarium.framework.values.ValuesInterface.register_value_modifier`. +which we register with :meth:`vivarium.framework.values.manager.ValuesInterface.register_value_modifier`. As the name suggests, this allows us to modify the values in a pipeline, in this case adding the effect of a force to the values in the ``acceleration`` pipeline. We register that the ``apply_force`` method will modify the acceleration values like so: diff --git a/docs/source/tutorials/disease_model.rst b/docs/source/tutorials/disease_model.rst index 5ba5f3de..822f76ce 100644 --- a/docs/source/tutorials/disease_model.rst +++ b/docs/source/tutorials/disease_model.rst @@ -599,7 +599,7 @@ configuration and the mortality randomness stream (which is used to answer the question "which simulants died at this time step?"). 
The main feature of note is the introduction of the -:class:`values system `. +:class:`values system `. The values system provides a way of distributing the computation of a value over multiple components. This can be a bit difficult to grasp, but is vital to the way we think about components in Vivarium. The best diff --git a/src/vivarium/framework/results/interface.py b/src/vivarium/framework/results/interface.py index fac5baff..16d4fb92 100644 --- a/src/vivarium/framework/results/interface.py +++ b/src/vivarium/framework/results/interface.py @@ -42,7 +42,7 @@ class ResultsInterface(Interface): DataFrames. The representation of state in the simulation is complex, however, as it includes information both in the population state table and dynamically generated information available from the - :class:`value pipelines `. + :class:`value pipelines `. Additionally, good encapsulation of simulation logic typically has results production separated from the modeling code into specialized `Observer` components. This often highlights the need for transformations diff --git a/src/vivarium/framework/values/__init__.py b/src/vivarium/framework/values/__init__.py new file mode 100644 index 00000000..2c85a66e --- /dev/null +++ b/src/vivarium/framework/values/__init__.py @@ -0,0 +1,28 @@ +""" +========================= +The Value Pipeline System +========================= + +The value pipeline system is a vital part of the :mod:`vivarium` +infrastructure. It allows for values that determine the behavior of individual +:term:`simulants ` to be constructed across multiple +:ref:`components `. + +For more information about when and how you should use pipelines in your +simulations, see the value system :ref:`concept note `. 
+ +""" +from vivarium.framework.values.combiners import ValueCombiner, list_combiner, replace_combiner +from vivarium.framework.values.exceptions import DynamicValueError +from vivarium.framework.values.manager import ValuesInterface, ValuesManager +from vivarium.framework.values.pipeline import ( + MissingValueSource, + Pipeline, + ValueModifier, + ValueSource, +) +from vivarium.framework.values.post_processors import ( + PostProcessor, + rescale_post_processor, + union_post_processor, +) diff --git a/src/vivarium/framework/values/combiners.py b/src/vivarium/framework/values/combiners.py new file mode 100644 index 00000000..ba2ad529 --- /dev/null +++ b/src/vivarium/framework/values/combiners.py @@ -0,0 +1,64 @@ +from collections.abc import Callable +from typing import Any, Protocol + + +class ValueCombiner(Protocol): + def __call__( + self, value: Any, mutator: Callable[..., Any], *args: Any, **kwargs: Any + ) -> Any: + ... + + +def replace_combiner( + value: Any, mutator: Callable[..., Any], *args: Any, **kwargs: Any +) -> Any: + """Replace the previous pipeline output with the output of the mutator. + + This is the default combiner. + + Parameters + ---------- + value + The value from the previous step in the pipeline. + mutator + A callable that takes in all arguments that the pipeline source takes + in plus an additional last positional argument for the value from + the previous stage in the pipeline. + args, kwargs + The same args and kwargs provided during the invocation of the + pipeline. + + Returns + ------- + A modified version of the input value. + """ + expanded_args = list(args) + [value] + return mutator(*expanded_args, **kwargs) + + +def list_combiner( + value: list[Any], mutator: Callable[..., Any], *args: Any, **kwargs: Any +) -> list[Any]: + """Aggregates source and mutator output into a list. + + This combiner is meant to be used with a post-processor that does some + kind of reduce operation like summing all values in the list. 
+ + Parameters + ---------- + value + A list of all values provided by the source and prior mutators in the + pipeline. + mutator + A callable that returns some portion of this pipeline's final value. + args, kwargs + The same args and kwargs provided during the invocation of the + pipeline. + + Returns + ------- + The input list with new mutator portion of the pipeline value + appended to it. + """ + value.append(mutator(*args, **kwargs)) + return value diff --git a/src/vivarium/framework/values/exceptions.py b/src/vivarium/framework/values/exceptions.py new file mode 100644 index 00000000..ae37b6be --- /dev/null +++ b/src/vivarium/framework/values/exceptions.py @@ -0,0 +1,7 @@ +from vivarium.exceptions import VivariumError + + +class DynamicValueError(VivariumError): + """Indicates an improperly configured value was invoked.""" + + pass diff --git a/src/vivarium/framework/values.py b/src/vivarium/framework/values/manager.py similarity index 61% rename from src/vivarium/framework/values.py rename to src/vivarium/framework/values/manager.py index 13436dd1..052214be 100644 --- a/src/vivarium/framework/values.py +++ b/src/vivarium/framework/values/manager.py @@ -1,33 +1,21 @@ -""" -========================= -The Value Pipeline System -========================= - -The value pipeline system is a vital part of the :mod:`vivarium` -infrastructure. It allows for values that determine the behavior of individual -:term:`simulants ` to be constructed across multiple -:ref:`components `. - -For more information about when and how you should use pipelines in your -simulations, see the value system :ref:`concept note `. 
- -""" from __future__ import annotations import warnings from collections.abc import Callable, Iterable, Sequence -from datetime import timedelta -from typing import TYPE_CHECKING, Any, Protocol, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar -import pandas as pd - -from vivarium.exceptions import VivariumError from vivarium.framework.event import Event from vivarium.framework.randomness import RandomnessStream from vivarium.framework.resource import Resource -from vivarium.framework.utilities import from_yearly +from vivarium.framework.values.combiners import ValueCombiner, replace_combiner +from vivarium.framework.values.pipeline import ( + MissingValueSource, + Pipeline, + ValueModifier, + ValueSource, +) +from vivarium.framework.values.post_processors import PostProcessor, rescale_post_processor from vivarium.manager import Interface, Manager -from vivarium.types import NumberLike if TYPE_CHECKING: from vivarium.framework.engine import Builder @@ -35,317 +23,6 @@ T = TypeVar("T") -class ValueCombiner(Protocol): - def __call__( - self, value: Any, mutator: Callable[..., Any], *args: Any, **kwargs: Any - ) -> Any: - ... - - -class PostProcessor(Protocol): - def __call__(self, value: Any, manager: ValuesManager) -> Any: - ... - - -class DynamicValueError(VivariumError): - """Indicates an improperly configured value was invoked.""" - - pass - - -def replace_combiner( - value: Any, mutator: Callable[..., Any], *args: Any, **kwargs: Any -) -> Any: - """Replace the previous pipeline output with the output of the mutator. - - This is the default combiner. - - Parameters - ---------- - value - The value from the previous step in the pipeline. - mutator - A callable that takes in all arguments that the pipeline source takes - in plus an additional last positional argument for the value from - the previous stage in the pipeline. - args, kwargs - The same args and kwargs provided during the invocation of the - pipeline. 
- - Returns - ------- - A modified version of the input value. - """ - expanded_args = list(args) + [value] - return mutator(*expanded_args, **kwargs) - - -def list_combiner( - value: list[Any], mutator: Callable[..., Any], *args: Any, **kwargs: Any -) -> list[Any]: - """Aggregates source and mutator output into a list. - - This combiner is meant to be used with a post-processor that does some - kind of reduce operation like summing all values in the list. - - Parameters - ---------- - value - A list of all values provided by the source and prior mutators in the - pipeline. - mutator - A callable that returns some portion of this pipeline's final value. - args, kwargs - The same args and kwargs provided during the invocation of the - pipeline. - - Returns - ------- - The input list with new mutator portion of the pipeline value - appended to it. - """ - value.append(mutator(*args, **kwargs)) - return value - - -def rescale_post_processor(value: NumberLike, manager: ValuesManager) -> NumberLike: - """Rescales annual rates to time-step appropriate rates. - - This should only be used with a simulation using a - :class:`~vivarium.framework.time.DateTimeClock` or another implementation - of a clock that traffics in pandas date-time objects. - - Parameters - ---------- - value - Annual rates, either as a number or something we can broadcast - multiplication over like a :mod:`numpy` array or :mod:`pandas` - data frame. - manager - The ValuesManager for this simulation. - - Returns - ------- - The annual rates rescaled to the size of the current time step size. 
- """ - if isinstance(value, (pd.Series, pd.DataFrame)): - return value.mul( - manager.simulant_step_sizes(value.index) - .astype("timedelta64[ns]") - .dt.total_seconds() - / (60 * 60 * 24 * 365.0), - axis=0, - ) - else: - time_step = manager.step_size() - if not isinstance(time_step, (pd.Timedelta, timedelta)): - raise DynamicValueError( - "The rescale post processor requires a time step size that is a " - "datetime timedelta or pandas Timedelta object." - ) - return from_yearly(value, time_step) - - -def union_post_processor(values: list[NumberLike], _: Any) -> NumberLike: - """Computes a probability on the union of the sample spaces in the values. - - Given a list of values where each value is a probability of an independent - event, this post processor computes the probability of the union of the - events. - - .. list-table:: - :width: 100% - :widths: 1 3 - - * - :math:`p_x` - - Probability of event x - * - :math:`1 - p_x` - - Probability of not event x - * - :math:`\prod_x(1 - p_x)` - - Probability of not any events x - * - :math:`1 - \prod_x(1 - p_x)` - - Probability of any event x - - Parameters - ---------- - values - A list of independent proportions or probabilities, either - as numbers or as a something we can broadcast addition and - multiplication over. - - Returns - ------- - The probability over the union of the sample spaces represented - by the original probabilities. 
- """ - # if there is only one value, return the value - if len(values) == 1: - return values[0] - - # if there are multiple values, calculate the joint value - product: NumberLike = 1 - for v in values: - new_value = 1 - v - product = product * new_value - joint_value = 1 - product - return joint_value - - -class ValueSource(Resource): - """A resource representing the source of a value pipeline.""" - - def __init__(self, name: str) -> None: - super().__init__("value_source", name) - - -class MissingValueSource(Resource): - """A resource representing an undefined source of a value pipeline.""" - - def __init__(self, name: str) -> None: - super().__init__("missing_value_source", name) - - -class ValueModifier(Resource): - """A resource representing a modifier of a value pipeline.""" - - def __init__(self, name: str) -> None: - super().__init__("value_modifier", name) - - -class Pipeline(Resource): - """A tool for building up values across several components. - - Pipelines are lazily initialized so that we don't have to put constraints - on the order in which components are created and set up. The values manager - will configure a pipeline (set all of its attributes) when the pipeline - source is created. - - As long as a pipeline is not actually called in a simulation, it does not - need a source or to be configured. This might occur when writing - generic components that create a set of pipeline modifiers for - values that won't be used in the particular simulation. 
- """ - - def __init__(self, name: str) -> None: - super().__init__("value", name) - - self.source: Callable[..., Any] | None = None - """The callable source of the value represented by the pipeline.""" - self.mutators: list[Callable[..., Any]] = [] - """A list of callables that directly modify the pipeline source or - contribute portions of the value.""" - self._combiner: ValueCombiner | None = None - self.post_processor: PostProcessor | None = None - """An optional final transformation to perform on the combined output of - the source and mutators.""" - self._manager: ValuesManager | None = None - - def _get_attr_error(self, attribute: str) -> str: - return ( - f"The pipeline for {self.name} has no {attribute}. This likely means " - f"you are attempting to modify a value that hasn't been created." - ) - - def _set_attr_error(self, attribute: str, new_value: Any) -> str: - current_value = getattr(self, f"_{attribute}") - return ( - f"A second component is attempting to set the {attribute} for pipeline {self.name} " - f"with {new_value}, but it already has a {attribute}: {current_value}." - ) - - def _get_property(self, property: T | None, property_name: str) -> T: - if property is None: - raise DynamicValueError(self._get_attr_error(property_name)) - return property - - @property - def combiner(self) -> ValueCombiner: - """A strategy for combining the source and mutator values into the - final value represented by the pipeline.""" - return self._get_property(self._combiner, "combiner") - - @property - def manager(self) -> ValuesManager: - """A reference to the simulation values manager.""" - return self._get_property(self._manager, "manager") - - def __call__(self, *args: Any, skip_post_processor: bool = False, **kwargs: Any) -> Any: - """Generates the value represented by this pipeline. - - Arguments - --------- - skip_post_processor - Whether we should invoke the post-processor on the combined - source and mutator output or return without post-processing. 
- This is useful when the post-processor acts as some sort of final - unit conversion (e.g. the rescale post processor). - args, kwargs - Pipeline arguments. These should be the arguments to the - callable source of the pipeline. - - Returns - ------- - The value represented by the pipeline. - - Raises - ------ - DynamicValueError - If the pipeline is invoked without a source set. - """ - return self._call(*args, skip_post_processor=skip_post_processor, **kwargs) - - def _call(self, *args: Any, skip_post_processor: bool = False, **kwargs: Any) -> Any: - if not self.source: - raise DynamicValueError( - f"The dynamic value pipeline for {self.name} has no source. This likely means " - f"you are attempting to modify a value that hasn't been created." - ) - value = self.source(*args, **kwargs) - for mutator in self.mutators: - value = self.combiner(value, mutator, *args, **kwargs) - if self.post_processor and not skip_post_processor: - return self.post_processor(value, self.manager) - if isinstance(value, pd.Series): - value.name = self.name - - return value - - def __repr__(self) -> str: - return f"_Pipeline({self.name})" - - @classmethod - def setup_pipeline( - cls, - pipeline: Pipeline, - source: Callable[..., Any], - combiner: ValueCombiner, - post_processor: PostProcessor | None, - manager: ValuesManager, - ) -> None: - """ - Add a source, combiner, and post-processor to a pipeline. - - Parameters - ---------- - pipeline - The pipeline to configure. - source - The callable source of the value represented by the pipeline. - combiner - A strategy for combining the source and mutator values into the - final value represented by the pipeline. - post_processor - An optional final transformation to perform on the combined output - of the source and mutators. - manager - The simulation values manager. 
- """ - pipeline.source = source - pipeline._combiner = combiner - pipeline.post_processor = post_processor - pipeline._manager = manager - - class ValuesManager(Manager): """Manager for the dynamic value system.""" diff --git a/src/vivarium/framework/values/pipeline.py b/src/vivarium/framework/values/pipeline.py new file mode 100644 index 00000000..78e8c25a --- /dev/null +++ b/src/vivarium/framework/values/pipeline.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, TypeVar + +import pandas as pd + +from vivarium.framework.resource import Resource +from vivarium.framework.values.exceptions import DynamicValueError + +if TYPE_CHECKING: + from vivarium.framework.values.combiners import ValueCombiner + from vivarium.framework.values.manager import ValuesManager + from vivarium.framework.values.post_processors import PostProcessor + +T = TypeVar("T") + + +class ValueSource(Resource): + """A resource representing the source of a value pipeline.""" + + def __init__(self, name: str) -> None: + super().__init__("value_source", name) + + +class MissingValueSource(Resource): + """A resource representing an undefined source of a value pipeline.""" + + def __init__(self, name: str) -> None: + super().__init__("missing_value_source", name) + + +class ValueModifier(Resource): + """A resource representing a modifier of a value pipeline.""" + + def __init__(self, name: str) -> None: + super().__init__("value_modifier", name) + + +class Pipeline(Resource): + """A tool for building up values across several components. + + Pipelines are lazily initialized so that we don't have to put constraints + on the order in which components are created and set up. The values manager + will configure a pipeline (set all of its attributes) when the pipeline + source is created. + + As long as a pipeline is not actually called in a simulation, it does not + need a source or to be configured. 
This might occur when writing + generic components that create a set of pipeline modifiers for + values that won't be used in the particular simulation. + """ + + def __init__(self, name: str) -> None: + super().__init__("value", name) + + self.source: Callable[..., Any] | None = None + """The callable source of the value represented by the pipeline.""" + self.mutators: list[Callable[..., Any]] = [] + """A list of callables that directly modify the pipeline source or + contribute portions of the value.""" + self._combiner: ValueCombiner | None = None + self.post_processor: PostProcessor | None = None + """An optional final transformation to perform on the combined output of + the source and mutators.""" + self._manager: ValuesManager | None = None + + def _get_attr_error(self, attribute: str) -> str: + return ( + f"The pipeline for {self.name} has no {attribute}. This likely means " + f"you are attempting to modify a value that hasn't been created." + ) + + def _set_attr_error(self, attribute: str, new_value: Any) -> str: + current_value = getattr(self, f"_{attribute}") + return ( + f"A second component is attempting to set the {attribute} for pipeline {self.name} " + f"with {new_value}, but it already has a {attribute}: {current_value}." + ) + + def _get_property(self, property: T | None, property_name: str) -> T: + if property is None: + raise DynamicValueError(self._get_attr_error(property_name)) + return property + + @property + def combiner(self) -> ValueCombiner: + """A strategy for combining the source and mutator values into the + final value represented by the pipeline.""" + return self._get_property(self._combiner, "combiner") + + @property + def manager(self) -> ValuesManager: + """A reference to the simulation values manager.""" + return self._get_property(self._manager, "manager") + + def __call__(self, *args: Any, skip_post_processor: bool = False, **kwargs: Any) -> Any: + """Generates the value represented by this pipeline. 
+ + Arguments + --------- + skip_post_processor + Whether we should invoke the post-processor on the combined + source and mutator output or return without post-processing. + This is useful when the post-processor acts as some sort of final + unit conversion (e.g. the rescale post processor). + args, kwargs + Pipeline arguments. These should be the arguments to the + callable source of the pipeline. + + Returns + ------- + The value represented by the pipeline. + + Raises + ------ + DynamicValueError + If the pipeline is invoked without a source set. + """ + return self._call(*args, skip_post_processor=skip_post_processor, **kwargs) + + def _call(self, *args: Any, skip_post_processor: bool = False, **kwargs: Any) -> Any: + if not self.source: + raise DynamicValueError( + f"The dynamic value pipeline for {self.name} has no source. This likely means " + f"you are attempting to modify a value that hasn't been created." + ) + value = self.source(*args, **kwargs) + for mutator in self.mutators: + value = self.combiner(value, mutator, *args, **kwargs) + if self.post_processor and not skip_post_processor: + return self.post_processor(value, self.manager) + if isinstance(value, pd.Series): + value.name = self.name + + return value + + def __repr__(self) -> str: + return f"_Pipeline({self.name})" + + @classmethod + def setup_pipeline( + cls, + pipeline: Pipeline, + source: Callable[..., Any], + combiner: ValueCombiner, + post_processor: PostProcessor | None, + manager: ValuesManager, + ) -> None: + """ + Add a source, combiner, and post-processor to a pipeline. + + Parameters + ---------- + pipeline + The pipeline to configure. + source + The callable source of the value represented by the pipeline. + combiner + A strategy for combining the source and mutator values into the + final value represented by the pipeline. + post_processor + An optional final transformation to perform on the combined output + of the source and mutators. + manager + The simulation values manager. 
+ """ + pipeline.source = source + pipeline._combiner = combiner + pipeline.post_processor = post_processor + pipeline._manager = manager diff --git a/src/vivarium/framework/values/post_processors.py b/src/vivarium/framework/values/post_processors.py new file mode 100644 index 00000000..811372fe --- /dev/null +++ b/src/vivarium/framework/values/post_processors.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Protocol + +import pandas as pd + +from vivarium.framework.utilities import from_yearly +from vivarium.framework.values.exceptions import DynamicValueError +from vivarium.types import NumberLike + +if TYPE_CHECKING: + from vivarium.framework.values.manager import ValuesManager + + +class PostProcessor(Protocol): + def __call__(self, value: Any, manager: ValuesManager) -> Any: + ... + + +def rescale_post_processor(value: NumberLike, manager: ValuesManager) -> NumberLike: + """Rescales annual rates to time-step appropriate rates. + + This should only be used with a simulation using a + :class:`~vivarium.framework.time.DateTimeClock` or another implementation + of a clock that traffics in pandas date-time objects. + + Parameters + ---------- + value + Annual rates, either as a number or something we can broadcast + multiplication over like a :mod:`numpy` array or :mod:`pandas` + data frame. + manager + The ValuesManager for this simulation. + + Returns + ------- + The annual rates rescaled to the size of the current time step size. + """ + if isinstance(value, (pd.Series, pd.DataFrame)): + return value.mul( + manager.simulant_step_sizes(value.index) + .astype("timedelta64[ns]") + .dt.total_seconds() + / (60 * 60 * 24 * 365.0), + axis=0, + ) + else: + time_step = manager.step_size() + if not isinstance(time_step, (pd.Timedelta, timedelta)): + raise DynamicValueError( + "The rescale post processor requires a time step size that is a " + "datetime timedelta or pandas Timedelta object." 
+ ) + return from_yearly(value, time_step) + + +def union_post_processor(values: list[NumberLike], _: Any) -> NumberLike: + """Computes a probability on the union of the sample spaces in the values. + + Given a list of values where each value is a probability of an independent + event, this post processor computes the probability of the union of the + events. + + .. list-table:: + :width: 100% + :widths: 1 3 + + * - :math:`p_x` + - Probability of event x + * - :math:`1 - p_x` + - Probability of not event x + * - :math:`\prod_x(1 - p_x)` + - Probability of not any events x + * - :math:`1 - \prod_x(1 - p_x)` + - Probability of any event x + + Parameters + ---------- + values + A list of independent proportions or probabilities, either + as numbers or as a something we can broadcast addition and + multiplication over. + + Returns + ------- + The probability over the union of the sample spaces represented + by the original probabilities. + """ + # if there is only one value, return the value + if len(values) == 1: + return values[0] + + # if there are multiple values, calculate the joint value + product: NumberLike = 1 + for v in values: + new_value = 1 - v + product = product * new_value + joint_value = 1 - product + return joint_value From bb247df85ea7c0d0bbffdd04a80128f30f47cb24 Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Tue, 29 Oct 2024 15:01:48 -0700 Subject: [PATCH 15/22] attach-pipeline-resources-to-pipeline (#515) --- src/vivarium/framework/values/__init__.py | 7 +-- src/vivarium/framework/values/manager.py | 58 +++++-------------- src/vivarium/framework/values/pipeline.py | 57 ++++++++++++++---- tests/framework/resource/test_manager.py | 28 +++++---- .../framework/resource/test_resource_group.py | 9 ++- tests/framework/test_values.py | 2 +- 6 files changed, 83 insertions(+), 78 deletions(-) diff --git a/src/vivarium/framework/values/__init__.py b/src/vivarium/framework/values/__init__.py index 
2c85a66e..27dd71bf 100644 --- a/src/vivarium/framework/values/__init__.py +++ b/src/vivarium/framework/values/__init__.py @@ -15,12 +15,7 @@ from vivarium.framework.values.combiners import ValueCombiner, list_combiner, replace_combiner from vivarium.framework.values.exceptions import DynamicValueError from vivarium.framework.values.manager import ValuesInterface, ValuesManager -from vivarium.framework.values.pipeline import ( - MissingValueSource, - Pipeline, - ValueModifier, - ValueSource, -) +from vivarium.framework.values.pipeline import Pipeline, ValueModifier, ValueSource from vivarium.framework.values.post_processors import ( PostProcessor, rescale_post_processor, diff --git a/src/vivarium/framework/values/manager.py b/src/vivarium/framework/values/manager.py index 052214be..98d76411 100644 --- a/src/vivarium/framework/values/manager.py +++ b/src/vivarium/framework/values/manager.py @@ -8,12 +8,7 @@ from vivarium.framework.randomness import RandomnessStream from vivarium.framework.resource import Resource from vivarium.framework.values.combiners import ValueCombiner, replace_combiner -from vivarium.framework.values.pipeline import ( - MissingValueSource, - Pipeline, - ValueModifier, - ValueSource, -) +from vivarium.framework.values.pipeline import Pipeline, ValueModifier, ValueSource from vivarium.framework.values.post_processors import PostProcessor, rescale_post_processor from vivarium.manager import Interface, Manager @@ -50,24 +45,18 @@ def on_post_setup(self, _event: Event) -> None: """Finalizes dependency structure for the pipelines.""" # Unsourced pipelines might occur when generic components register # modifiers to values that aren't required in a simulation. 
- unsourced_pipelines = [p for p, v in self._pipelines.items() if v.source is None] + unsourced_pipelines = [p for p, v in self._pipelines.items() if not v.source] if unsourced_pipelines: self.logger.warning(f"Unsourced pipelines: {unsourced_pipelines}") # register_value_producer and register_value_modifier record the # dependency structure for the pipeline source and pipeline modifiers, - # respectively. We don't have enough information to record the + # respectively. We don't have enough information to record the # dependency structure for the pipeline itself until now, where # we say the pipeline value depends on its source and all its # modifiers. for name, pipe in self._pipelines.items(): - dependencies: list[Resource] = [ - ValueSource(name) if pipe.source else MissingValueSource(name) - ] - for i, m in enumerate(pipe.mutators): - mutator_name = self._get_modifier_name(m) - dependencies.append(ValueModifier(f"{name}.{i + 1}.{mutator_name}")) - self.resources.add_resources([pipe], dependencies) + self.resources.add_resources([pipe], [pipe.source] + list(pipe.mutators)) def register_value_producer( self, @@ -97,7 +86,7 @@ def register_value_producer( dependencies = self._convert_dependencies( source, requires_columns, requires_values, requires_streams, required_resources ) - self.resources.add_resources([ValueSource(value_name)], dependencies) + self.resources.add_resources([pipeline.source], dependencies) self.add_constraint( pipeline._call, restrict_during=["initialization", "setup", "post_setup"] ) @@ -108,15 +97,14 @@ def _register_value_producer( self, value_name: str, source: Callable[..., Any], - preferred_combiner: ValueCombiner, - preferred_post_processor: PostProcessor | None, + combiner: ValueCombiner, + post_processor: PostProcessor | None, ) -> Pipeline: """Configure the named value pipeline with a source, combiner, and post-processor.""" self.logger.debug(f"Registering value pipeline {value_name}") pipeline = self.get_value(value_name) - 
Pipeline.setup_pipeline( - pipeline, source, preferred_combiner, preferred_post_processor, self - ) + value_source = ValueSource(pipeline, source) + Pipeline.setup_pipeline(pipeline, value_source, combiner, post_processor, self) return pipeline def register_value_modifier( @@ -157,17 +145,16 @@ def register_value_modifier( pipeline modifier is called. This is a list of strings, pipeline names, or randomness streams. """ - modifier_name = self._get_modifier_name(modifier) pipeline = self.get_value(value_name) - pipeline.mutators.append(modifier) + value_modifier = ValueModifier(pipeline, modifier) + self.logger.debug(f"Registering {value_modifier.name} as modifier to {value_name}") + pipeline.mutators.append(value_modifier) - name = f"{value_name}.{len(pipeline.mutators)}.{modifier_name}" - self.logger.debug(f"Registering {name} as modifier to {value_name}") dependencies = self._convert_dependencies( modifier, requires_columns, requires_values, requires_streams, required_resources ) - self.resources.add_resources([ValueModifier(name)], dependencies) + self.resources.add_resources([value_modifier], dependencies) def get_value(self, name: str) -> Pipeline: """Retrieve the pipeline representing the named value. 
@@ -222,25 +209,6 @@ def _convert_dependencies( else: return required_resources - @staticmethod - def _get_modifier_name(modifier: Callable[..., Any]) -> str: - """Get reproducible modifier names based on the modifier type.""" - if hasattr(modifier, "name"): # This is Pipeline or lookup table or something similar - modifier_name: str = modifier.name - elif hasattr(modifier, "__self__") and hasattr( - modifier, "__name__" - ): # This is a bound method of a component or other object - owner = modifier.__self__ - owner_name = owner.name if hasattr(owner, "name") else owner.__class__.__name__ - modifier_name = f"{owner_name}.{modifier.__name__}" - elif hasattr(modifier, "__name__"): # Some unbound function - modifier_name = modifier.__name__ - elif hasattr(modifier, "__call__"): # Some anonymous callable - modifier_name = f"{modifier.__class__.__name__}.__call__" - else: # I don't know what this is. - raise ValueError(f"Unknown modifier type: {type(modifier)}") - return modifier_name - def keys(self) -> Iterable[str]: """Get an iterable of pipeline names.""" return self._pipelines.keys() diff --git a/src/vivarium/framework/values/pipeline.py b/src/vivarium/framework/values/pipeline.py index 78e8c25a..32b72cc8 100644 --- a/src/vivarium/framework/values/pipeline.py +++ b/src/vivarium/framework/values/pipeline.py @@ -19,23 +19,58 @@ class ValueSource(Resource): """A resource representing the source of a value pipeline.""" - def __init__(self, name: str) -> None: - super().__init__("value_source", name) - + def __init__(self, pipeline: Pipeline, source: Callable[..., Any] | None = None) -> None: + super().__init__("value_source" if source else "missing_value_source", pipeline.name) + self._pipeline = pipeline + self._source = source -class MissingValueSource(Resource): - """A resource representing an undefined source of a value pipeline.""" + def __bool__(self) -> bool: + return self._source is not None - def __init__(self, name: str) -> None: - 
super().__init__("missing_value_source", name) + def __call__(self, *args: Any, **kwargs: Any) -> Any: + if not self._source: + raise DynamicValueError( + f"The dynamic value pipeline for {self.name} has no source." + " This likely means you are attempting to modify a value that" + " hasn't been created." + ) + return self._source(*args, **kwargs) class ValueModifier(Resource): """A resource representing a modifier of a value pipeline.""" - def __init__(self, name: str) -> None: + def __init__(self, pipeline: Pipeline, modifier: Callable[..., Any]) -> None: + mutator_name = self._get_modifier_name(modifier) + mutator_index = len(pipeline.mutators) + 1 + name = f"{pipeline.name}.{mutator_index}.{mutator_name}" super().__init__("value_modifier", name) + self._pipeline = pipeline + self._source = modifier + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self._source(*args, **kwargs) + + @staticmethod + def _get_modifier_name(modifier: Callable[..., Any]) -> str: + """Get reproducible modifier names based on the modifier type.""" + if hasattr(modifier, "name"): # This is Pipeline or lookup table or something similar + modifier_name: str = modifier.name + elif hasattr(modifier, "__self__") and hasattr( + modifier, "__name__" + ): # This is a bound method of a component or other object + owner = modifier.__self__ + owner_name = owner.name if hasattr(owner, "name") else owner.__class__.__name__ + modifier_name = f"{owner_name}.{modifier.__name__}" + elif hasattr(modifier, "__name__"): # Some unbound function + modifier_name = modifier.__name__ + elif hasattr(modifier, "__call__"): # Some anonymous callable + modifier_name = f"{modifier.__class__.__name__}.__call__" + else: # I don't know what this is. + raise ValueError(f"Unknown modifier type: {type(modifier)}") + return modifier_name + class Pipeline(Resource): """A tool for building up values across several components. 
@@ -54,9 +89,9 @@ class Pipeline(Resource): def __init__(self, name: str) -> None: super().__init__("value", name) - self.source: Callable[..., Any] | None = None + self.source: ValueSource = ValueSource(self) """The callable source of the value represented by the pipeline.""" - self.mutators: list[Callable[..., Any]] = [] + self.mutators: list[ValueModifier] = [] """A list of callables that directly modify the pipeline source or contribute portions of the value.""" self._combiner: ValueCombiner | None = None @@ -142,7 +177,7 @@ def __repr__(self) -> str: def setup_pipeline( cls, pipeline: Pipeline, - source: Callable[..., Any], + source: ValueSource, combiner: ValueCombiner, post_processor: PostProcessor | None, manager: ValuesManager, diff --git a/tests/framework/resource/test_manager.py b/tests/framework/resource/test_manager.py index e8a9a6e3..a63b37ca 100644 --- a/tests/framework/resource/test_manager.py +++ b/tests/framework/resource/test_manager.py @@ -2,6 +2,7 @@ from collections.abc import Callable, Mapping from datetime import datetime +from typing import Any import pytest import pytest_mock @@ -13,7 +14,7 @@ from vivarium.framework.resource import ResourceManager from vivarium.framework.resource.exceptions import ResourceError from vivarium.framework.resource.resource import Column, NullResource -from vivarium.framework.values import MissingValueSource, Pipeline, ValueModifier, ValueSource +from vivarium.framework.values import Pipeline, ValueModifier, ValueSource @pytest.fixture @@ -42,28 +43,31 @@ def initializer(self, _simulant_data: SimulantData) -> None: @pytest.mark.parametrize( - "resource_class, type_string, is_initializer", + "resource_class, init_args, type_string, is_initializer", [ - (Pipeline, "value", False), - (ValueSource, "value_source", False), - (MissingValueSource, "missing_value_source", False), - (ValueModifier, "value_modifier", False), - (Column, "column", True), - (NullResource, "null", True), + (Pipeline, ["foo"], "value", 
False), + (ValueSource, [Pipeline("foo"), lambda: 1], "value_source", False), + (ValueModifier, [Pipeline("foo"), lambda: 1], "value_modifier", False), + (Column, ["foo"], "column", True), + (NullResource, ["foo"], "null", True), ], - ids=lambda x: {x.__name__ if isinstance(x, type) else x}, + ids=lambda x: [x.__name__ if isinstance(x, type) else x], ) def test_resource_manager_get_resource_group( - resource_class: type, type_string: str, is_initializer: bool, manager: ResourceManager + resource_class: type, + init_args: list[Any], + type_string: str, + is_initializer: bool, + manager: ResourceManager, ) -> None: initializer = ResourceProducer("base").initializer group = manager._get_resource_group( - [resource_class("foo")], [], initializer if is_initializer else None + [resource_class(*init_args)], [], initializer if is_initializer else None ) assert group.type == type_string - assert group.names == [f"{type_string}.foo"] + assert group.names == [r.resource_id for r in group._resources.values()] assert not group.dependencies assert group.is_initializer == is_initializer if is_initializer: diff --git a/tests/framework/resource/test_resource_group.py b/tests/framework/resource/test_resource_group.py index d3c2bd84..5f3d6b6c 100644 --- a/tests/framework/resource/test_resource_group.py +++ b/tests/framework/resource/test_resource_group.py @@ -23,7 +23,7 @@ def test_resource_group() -> None: Column("an_interesting_column"), Pipeline("baz"), RandomnessStream("bar", lambda: datetime.now(), 1, IndexMap()), - ValueSource("foo"), + ValueSource(Pipeline("foo"), lambda: 1), ] rg = ResourceGroup(resources, r_dependencies, dummy_initializer) @@ -41,7 +41,7 @@ def test_resource_group() -> None: def test_resource_group_is_initializer() -> None: - resources = [ValueModifier("foo")] + resources = [ValueModifier(Pipeline("foo"), lambda: 1)] rg = ResourceGroup(resources, [Column("bar")]) with pytest.raises(ResourceError, match="does not have an initializer"): @@ -54,7 +54,10 @@ def 
test_resource_group_with_no_resources() -> None: def test_resource_group_with_multiple_resource_types() -> None: - resources = [ValueModifier("foo"), ValueSource("bar")] + resources = [ + ValueModifier(Pipeline("foo"), lambda: 1), + ValueSource(Pipeline("bar"), lambda: 2), + ] with pytest.raises(ResourceError, match="resources must be of the same type"): _ = ResourceGroup(resources, []) diff --git a/tests/framework/test_values.py b/tests/framework/test_values.py index f4276b0a..da610668 100644 --- a/tests/framework/test_values.py +++ b/tests/framework/test_values.py @@ -124,7 +124,7 @@ def test_rescale_post_processor_variable(manager_with_step_size): def test_unsourced_pipeline(): pipeline = Pipeline("some_name") - assert pipeline.source is None + assert pipeline.source.resource_id == "missing_value_source.some_name" with pytest.raises( DynamicValueError, match=f"The dynamic value pipeline for {pipeline.name} has no source.", From a03cc2c280250d3ddd0f799b10da5df104eda402 Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Mon, 4 Nov 2024 16:16:33 -0800 Subject: [PATCH 16/22] wire up managers to be able to create columns (#527) --- src/vivarium/framework/plugins.py | 4 ++ src/vivarium/framework/population/manager.py | 7 ++- src/vivarium/manager.py | 55 ++++++++++++++++---- 3 files changed, 54 insertions(+), 12 deletions(-) diff --git a/src/vivarium/framework/plugins.py b/src/vivarium/framework/plugins.py index 01bf814e..9edcd5f3 100644 --- a/src/vivarium/framework/plugins.py +++ b/src/vivarium/framework/plugins.py @@ -95,6 +95,10 @@ class PluginConfigurationError(VivariumError): class PluginManager(Manager): + @property + def name(self) -> str: + return "plugin_manager" + def __init__( self, plugin_configuration: ( diff --git a/src/vivarium/framework/population/manager.py b/src/vivarium/framework/population/manager.py index a1c0da9a..0318d336 100644 --- a/src/vivarium/framework/population/manager.py +++ 
b/src/vivarium/framework/population/manager.py @@ -87,7 +87,6 @@ def add( f"You provided {initializer} which is of type {type(initializer)}." ) component = initializer.__self__ - # TODO: consider if we can initialize the tracked column with a component instead # TODO: raise error once all active Component implementations have been refactored # if not (isinstance(component, Component) or isinstance(component, PopulationManager)): # raise AttributeError( @@ -161,6 +160,10 @@ def name(self) -> str: """The name of this component.""" return "population_manager" + @property + def columns_created(self) -> list[str]: + return ["tracked"] + def setup(self, builder: Builder) -> None: """Registers the population manager with other vivarium systems.""" self.clock = builder.time.clock() @@ -184,7 +187,7 @@ def setup(self, builder: Builder) -> None: ) self.register_simulant_initializer( - self.on_initialize_simulants, creates_columns="tracked" + self.on_initialize_simulants, creates_columns=self.columns_created ) self._view = self.get_view("tracked") diff --git a/src/vivarium/manager.py b/src/vivarium/manager.py index e797c073..17b073aa 100644 --- a/src/vivarium/manager.py +++ b/src/vivarium/manager.py @@ -7,15 +7,18 @@ simulations. """ +from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from vivarium.framework.engine import Builder + from vivarium.framework.population import SimulantData -class Manager: - CONFIGURATION_DEFAULTS: Dict[str, Any] = {} +class Manager(ABC): + CONFIGURATION_DEFAULTS: dict[str, Any] = {} """A dictionary containing the defaults for any configurations managed by this manager. An empty dictionary indicates no managed configurations. 
@@ -26,25 +29,57 @@ class Manager: ############## @property - def configuration_defaults(self) -> Dict[str, Any]: + @abstractmethod + def name(self) -> str: + pass + + @property + def configuration_defaults(self) -> dict[str, Any]: """Provides a dictionary containing the defaults for any configurations managed by this manager. These default values will be stored at the `component_configs` layer of the simulation's LayeredConfigTree. - - Returns - ------- - A dictionary containing the defaults for any configurations managed by - this manager. """ return self.CONFIGURATION_DEFAULTS + @property + def columns_created(self) -> list[str]: + """Provides names of columns created by the manager.""" + return [] + ##################### # Lifecycle methods # ##################### - def setup(self, builder: "Builder") -> None: + def setup(self, builder: Builder) -> None: + """Defines custom actions this manager needs to run during the setup + lifecycle phase. + + This method is intended to be overridden by subclasses to perform any + necessary setup operations specific to the manager. By default, it + does nothing. + + Parameters + ---------- + builder + The builder object used to set up the manager. + """ + pass + + def on_initialize_simulants(self, pop_data: SimulantData) -> None: + """ + Method that vivarium will run during simulant initialization. + + This method is intended to be overridden by subclasses if there are + operations they need to perform specifically during the simulant + initialization phase. + + Parameters + ---------- + pop_data : SimulantData + The data associated with the simulants being initialized. 
+ """ pass From ffa03062c2c168280de5b9f160769db715d9f3a9 Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Mon, 4 Nov 2024 16:20:24 -0800 Subject: [PATCH 17/22] specify the component that creates resources and resource groups (#528) --- src/vivarium/component.py | 12 +- src/vivarium/framework/population/manager.py | 41 ++--- src/vivarium/framework/randomness/manager.py | 36 +++- src/vivarium/framework/randomness/stream.py | 9 +- src/vivarium/framework/resource/group.py | 56 +++--- src/vivarium/framework/resource/manager.py | 68 ++++--- src/vivarium/framework/resource/resource.py | 20 ++- src/vivarium/framework/time.py | 4 +- src/vivarium/framework/values/manager.py | 76 ++++---- src/vivarium/framework/values/pipeline.py | 67 +++++-- tests/framework/components/test_component.py | 2 +- tests/framework/randomness/test_manager.py | 7 +- tests/framework/resource/test_manager.py | 167 ++++++++++-------- tests/framework/resource/test_resource.py | 26 ++- .../framework/resource/test_resource_group.py | 57 +++--- tests/helpers.py | 1 + 16 files changed, 393 insertions(+), 256 deletions(-) diff --git a/src/vivarium/component.py b/src/vivarium/component.py index a773986d..5cbade2c 100644 --- a/src/vivarium/component.py +++ b/src/vivarium/component.py @@ -717,10 +717,10 @@ def get_data( data = data_source(builder) else: raise ConfigurationError( - f"Data source '{data_source}' is not a valid data source. It " - f"must be a LookupTableData instance, a string corresponding to " - f"an artifact key, a callable that returns a LookupTableData " - f"instance, or a string defining such a callable." + f"Data source is of type '{type(data_source)}'. It must be a " + "LookupTableData instance, a string corresponding to an " + "artifact key, a callable that returns a LookupTableData " + "instance, or a string defining such a callable." 
) if not isinstance(data, valid_data_types): @@ -806,9 +806,7 @@ def _register_simulant_initializer(self, builder: Builder) -> None: if type(self).on_initialize_simulants != Component.on_initialize_simulants: builder.population.initializes_simulants( - self.on_initialize_simulants, - creates_columns=self.columns_created, - **initialization_requirements, + self, creates_columns=self.columns_created, **initialization_requirements ) def _register_time_step_prepare_listener(self, builder: "Builder") -> None: diff --git a/src/vivarium/framework/population/manager.py b/src/vivarium/framework/population/manager.py index 0318d336..690ee24a 100644 --- a/src/vivarium/framework/population/manager.py +++ b/src/vivarium/framework/population/manager.py @@ -18,14 +18,13 @@ from vivarium.framework.population.exceptions import PopulationError from vivarium.framework.population.population_view import PopulationView -from vivarium.framework.randomness import RandomnessStream from vivarium.framework.resource import Resource -from vivarium.framework.values import Pipeline from vivarium.manager import Interface, Manager -from vivarium.types import ClockStepSize, ClockTime if TYPE_CHECKING: + from vivarium import Component from vivarium.framework.engine import Builder + from vivarium.types import ClockStepSize, ClockTime @dataclass @@ -58,7 +57,7 @@ def __init__(self) -> None: self._columns_produced: dict[str, str] = {} def add( - self, initializer: Callable[[SimulantData], None], columns_produced: list[str] + self, initializer: Callable[[SimulantData], None], columns_produced: Sequence[str] ) -> None: """Adds an initializer and columns to the set, enforcing uniqueness. @@ -117,7 +116,7 @@ def add( f"for column {column}." 
) self._columns_produced[column] = component_name - self._components[component_name] = columns_produced + self._components[component_name] = list(columns_produced) def __repr__(self) -> str: return repr(self._components) @@ -186,9 +185,7 @@ def setup(self, builder: Builder) -> None: self.register_simulant_initializer, allow_during=["setup"] ) - self.register_simulant_initializer( - self.on_initialize_simulants, creates_columns=self.columns_created - ) + self.register_simulant_initializer(self, creates_columns=self.columns_created) self._view = self.get_view("tracked") def on_initialize_simulants(self, pop_data: SimulantData) -> None: @@ -264,7 +261,7 @@ def _get_view(self, columns: str | Sequence[str], query: str) -> PopulationView: def register_simulant_initializer( self, - initializer: Callable[[SimulantData], None], + component: Component | Manager, creates_columns: str | Sequence[str] = (), requires_columns: str | Sequence[str] = (), requires_values: str | Sequence[str] = (), @@ -275,9 +272,9 @@ def register_simulant_initializer( Parameters ---------- - initializer - A callable that adds or updates initial state information about - new simulants. + component + The component or manager that will add or update initial state + information about new simulants. creates_columns The state table columns that the given initializer provides the initial state information for. 
@@ -311,8 +308,8 @@ def register_simulant_initializer( required_resources = ( list(requires_columns) - + [Resource("value", name) for name in requires_values] - + [Resource("stream", name) for name in requires_streams] + + [Resource("value", name, component) for name in requires_values] + + [Resource("stream", name, component) for name in requires_streams] ) if isinstance(creates_columns, str): @@ -325,8 +322,8 @@ def register_simulant_initializer( else: all_dependencies = list(required_resources) - self._initializer_components.add(initializer, list(creates_columns)) - self.resources.add_resources(creates_columns, all_dependencies, initializer) + self._initializer_components.add(component.on_initialize_simulants, creates_columns) + self.resources.add_resources(component, creates_columns, all_dependencies) def get_simulant_creator(self) -> Callable[[int, dict[str, Any] | None], pd.Index[int]]: """Gets a function that can generate new simulants. @@ -470,20 +467,20 @@ def get_simulant_creator(self) -> Callable[[int, dict[str, Any] | None], pd.Inde def initializes_simulants( self, - initializer: Callable[[SimulantData], None], + component: Component | Manager, creates_columns: str | Sequence[str] = (), requires_columns: str | Sequence[str] = (), requires_values: str | Sequence[str] = (), requires_streams: str | Sequence[str] = (), - required_resources: Sequence[str | Pipeline | RandomnessStream] = (), + required_resources: Sequence[str | Resource] = (), ) -> None: """Marks a source of initial state information for new simulants. Parameters ---------- - initializer - A callable that adds or updates initial state information about - new simulants. + component + The component or manager that will add or update initial state + information about new simulants. creates_columns The state table columns that the given initializer provides the initial state information for. 
@@ -503,7 +500,7 @@ def initializes_simulants( are interpreted as value pipelines and randomness streams, """ self._manager.register_simulant_initializer( - initializer, + component, creates_columns, requires_columns, requires_values, diff --git a/src/vivarium/framework/randomness/manager.py b/src/vivarium/framework/randomness/manager.py index eae1921f..c1539a62 100644 --- a/src/vivarium/framework/randomness/manager.py +++ b/src/vivarium/framework/randomness/manager.py @@ -20,6 +20,9 @@ if TYPE_CHECKING: from vivarium.framework.engine import Builder +if TYPE_CHECKING: + from vivarium import Component + class RandomnessManager(Manager): """Access point for common random number generation.""" @@ -86,7 +89,10 @@ def setup(self, builder: Builder) -> None: ) def get_randomness_stream( - self, decision_point: str, initializes_crn_attributes: bool = False + self, + decision_point: str, + component: Component | None, + initializes_crn_attributes: bool = False, ) -> RandomnessStream: """Provides a new source of random numbers for the given decision point. @@ -96,6 +102,8 @@ def get_randomness_stream( A unique identifier for a stream of random numbers. Typically represents a decision that needs to be made each time step like 'moves_left' or 'gets_disease'. + component + The component that is requesting the randomness stream. initializes_crn_attributes A flag indicating whether this stream is used to generate key initialization information that will be used to identify simulants @@ -115,10 +123,12 @@ def get_randomness_stream( If another location in the simulation has already created a randomness stream with the same identifier. """ - stream = self._get_randomness_stream(decision_point, initializes_crn_attributes) + stream = self._get_randomness_stream( + decision_point, component, initializes_crn_attributes + ) if not initializes_crn_attributes: # We need the key columns to be created before this stream can be called. 
- self.resources.add_resources([stream], self._key_columns) + self.resources.add_resources(component, [stream], self._key_columns) self._add_constraint( stream.get_draw, restrict_during=["initialization", "setup", "post_setup"] ) @@ -136,7 +146,10 @@ def get_randomness_stream( return stream def _get_randomness_stream( - self, decision_point: str, initializes_crn_attributes: bool = False + self, + decision_point: str, + component: Component | None, + initializes_crn_attributes: bool = False, ) -> RandomnessStream: if decision_point in self._decision_points: raise RandomnessError( @@ -148,6 +161,7 @@ def _get_randomness_stream( clock=self._clock, seed=self._seed, index_map=self._key_mapping, + component=component, initializes_crn_attributes=initializes_crn_attributes, ) self._decision_points[decision_point] = stream @@ -203,7 +217,11 @@ def __init__(self, manager: RandomnessManager): self._manager = manager def get_stream( - self, decision_point: str, initializes_crn_attributes: bool = False + self, + decision_point: str, + # TODO [MIC-5452]: all calls should have a component + component: Component | None = None, + initializes_crn_attributes: bool = False, ) -> RandomnessStream: """Provides a new source of random numbers for the given decision point. @@ -216,9 +234,11 @@ def get_stream( Parameters ---------- decision_point - A unique identifier for a stream of random numbers. Typically + A unique identifier for a stream of random numbers. Typically, this represents a decision that needs to be made each time step like 'moves_left' or 'gets_disease'. + component + The component that is requesting the randomness stream. initializes_crn_attributes A flag indicating whether this stream is used to generate key initialization information that will be used to identify simulants @@ -232,7 +252,9 @@ def get_stream( The stream provides vectorized access to random numbers and a few other utilities. 
""" - return self._manager.get_randomness_stream(decision_point, initializes_crn_attributes) + return self._manager.get_randomness_stream( + decision_point, component, initializes_crn_attributes + ) def get_seed(self, decision_point: str) -> int: """Get a randomly generated seed for use with external randomness tools. diff --git a/src/vivarium/framework/randomness/stream.py b/src/vivarium/framework/randomness/stream.py index 87236a9d..5199414d 100644 --- a/src/vivarium/framework/randomness/stream.py +++ b/src/vivarium/framework/randomness/stream.py @@ -28,7 +28,7 @@ import hashlib from collections.abc import Callable -from typing import Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar import numpy as np import numpy.typing as npt @@ -41,6 +41,9 @@ from vivarium.framework.utilities import rate_to_probability from vivarium.types import ClockTime, NumericArray +if TYPE_CHECKING: + from vivarium import Component + RESIDUAL_CHOICE = object() # TODO: Parameterizing pandas objects fails below python 3.12 @@ -91,9 +94,11 @@ def __init__( clock: Callable[[], ClockTime], seed: Any, index_map: IndexMap, + # TODO [MIC-5452]: all resources should have a component + component: Component | None = None, initializes_crn_attributes: bool = False, ): - super().__init__("stream", key) + super().__init__("stream", key, component) self.key = key """The name of the randomness stream.""" self.clock = clock diff --git a/src/vivarium/framework/resource/group.py b/src/vivarium/framework/resource/group.py index 5981a636..8fa21fe5 100644 --- a/src/vivarium/framework/resource/group.py +++ b/src/vivarium/framework/resource/group.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Callable, Iterable, Iterator +from collections.abc import Callable, Iterator, Sequence from typing import TYPE_CHECKING from vivarium.framework.resource.exceptions import ResourceError @@ -23,67 +23,59 @@ class ResourceGroup: """ def __init__( - self, - produced_resources: 
Iterable[Resource], - dependencies: Iterable[Resource], - initializer: Callable[[SimulantData], None] | None = None, + self, initialized_resources: Sequence[Resource], dependencies: Sequence[Resource] ): """Create a new resource group. Parameters ---------- - produced_resources - The resources produced by this resource group's producer. + initialized_resources + The resources initialized by this resource group's initializer. dependencies - The resources this resource group's producer depends on. - initializer - The method that initializes this group of resources. If this is - None, the resources don't need to be initialized. + The resources this resource group's initializer depends on. Raises ------ ResourceError If the resource group is not well-formed. """ - if not produced_resources: + if not initialized_resources: raise ResourceError("Resource groups must have at least one resource.") - if len(set(r.resource_type for r in produced_resources)) != 1: - raise ResourceError("All produced resources must be of the same type.") + if len(set(r.component for r in initialized_resources)) != 1: + raise ResourceError("All initialized resources must have the same component.") - if list(produced_resources)[0].is_initialized != (initializer is not None): - raise ResourceError( - "Resource groups with an initializer must have initialized resources." 
- ) + if len(set(r.resource_type for r in initialized_resources)) != 1: + raise ResourceError("All initialized resources must be of the same type.") - self.type = list(produced_resources)[0].resource_type - """The type of resource produced by this resource group's producer.""" - self._resources = {resource.resource_id: resource for resource in produced_resources} - self._initializer = initializer + self.component = initialized_resources[0].component + """The component or manager that produces the resources in this group.""" + self.type = initialized_resources[0].resource_type + """The type of resource in this group.""" + self.is_initialized = initialized_resources[0].is_initialized + """Whether this resource group contains initialized resources.""" self._dependencies = dependencies + self.resources = {r.resource_id: r for r in initialized_resources} + """A dictionary of resources produced by this group, keyed by resource_id.""" @property def names(self) -> list[str]: """The long names (including type) of all resources in this group.""" - return list(self._resources) + return list(self.resources) @property def initializer(self) -> Callable[[SimulantData], None]: """The method that initializes this group of resources.""" - if self._initializer is None: - raise ResourceError("This resource group does not have an initializer.") - return self._initializer + # TODO [MIC-5452]: all resource groups should have a component + if not self.component: + raise ResourceError(f"Resource group {self} does not have an initializer.") + return self.component.on_initialize_simulants @property def dependencies(self) -> list[str]: """The long names (including type) of dependencies for this group.""" return [dependency.resource_id for dependency in self._dependencies] - @property - def is_initializer(self) -> bool: - """Return True if this resource group's producer is an initializer.""" - return self._initializer is not None - def __iter__(self) -> Iterator[str]: return iter(self.names) 
@@ -97,4 +89,4 @@ def __str__(self) -> str: def get_resource(self, resource_id: str) -> Resource: """Get a resource by its resource_id.""" - return self._resources[resource_id] + return self.resources[resource_id] diff --git a/src/vivarium/framework/resource/manager.py b/src/vivarium/framework/resource/manager.py index 09165d8e..2589a57f 100644 --- a/src/vivarium/framework/resource/manager.py +++ b/src/vivarium/framework/resource/manager.py @@ -7,7 +7,7 @@ from __future__ import annotations -from collections.abc import Callable, Iterable +from collections.abc import Iterable from typing import TYPE_CHECKING, Any import networkx as nx @@ -18,8 +18,8 @@ from vivarium.manager import Interface, Manager if TYPE_CHECKING: + from vivarium import Component from vivarium.framework.engine import Builder - from vivarium.framework.population import SimulantData class ResourceManager(Manager): @@ -66,7 +66,7 @@ def sorted_nodes(self) -> list[ResourceGroup]: self._sorted_nodes = list(nx.algorithms.topological_sort(self.graph)) # type: ignore[func-returns-value] except nx.NetworkXUnfeasible: raise ResourceError( - f"The resource pool contains at least one cycle: " + "The resource pool contains at least one cycle: " f"{nx.find_cycle(self.graph)}." ) return self._sorted_nodes @@ -76,22 +76,22 @@ def setup(self, builder: Builder) -> None: def add_resources( self, + # TODO [MIC-5452]: all resource groups should have a component + component: Component | Manager | None, resources: Iterable[str | Resource], dependencies: Iterable[str | Resource], - initializer: Callable[[SimulantData], None] | None, ) -> None: """Adds managed resources to the resource pool. Parameters ---------- + component + The component or manager adding the resources. resources The resources being added. A string represents a column resource. dependencies A list of resources that the producer requires. A string represents a column resource. - initializer - A method that will be called to initialize the resources. 
This is - called during population initialization. Raises ------ @@ -99,27 +99,28 @@ def add_resources( If a component has multiple resource producers for the ``column`` resource type or there are multiple producers of the same resource. """ - resource_group = self._get_resource_group(resources, dependencies, initializer) + resource_group = self._get_resource_group(component, resources, dependencies) - for resource_id in resource_group: + for resource_id, resource in resource_group.resources.items(): if resource_id in self._resource_group_map: other_resource = self._resource_group_map[resource_id] - if resource_group.is_initializer: - raise ResourceError( - f"Both {initializer} and {other_resource.initializer}" - f" are registered as initializers for {resource_id}." - ) - resource = resource_group.get_resource(resource_id) + # TODO [MIC-5452]: all resource groups should have a component + resource_component = resource.component.name if resource.component else None + other_resource_component = ( + other_resource.component.name if other_resource.component else None + ) raise ResourceError( - f"Resource {resource} is not allowed to be registered more" " than once." + f"Component '{resource_component}' is attempting to register" + f" resource '{resource_id}' but it is already registered by" + f" '{other_resource_component}'." ) self._resource_group_map[resource_id] = resource_group def _get_resource_group( self, + component: Component | Manager | None, resources: Iterable[str | Resource], dependencies: Iterable[str | Resource], - initializer: Callable[[SimulantData], None] | None, ) -> ResourceGroup: """Packages resource information into a resource group. 
@@ -127,17 +128,26 @@ def _get_resource_group( -------- :class:`ResourceGroup` """ - resources_ = [Column(r) if isinstance(r, str) else r for r in resources] - dependencies_ = [Column(d) if isinstance(d, str) else d for d in dependencies] + resources_ = [Column(r, component) if isinstance(r, str) else r for r in resources] + dependencies_ = [Column(d, None) if isinstance(d, str) else d for d in dependencies] if not resources_: # We have a "producer" that doesn't produce anything, but # does have dependencies. This is necessary for components that # want to track private state information. - resources_ = [NullResource(self._null_producer_count)] + resources_ = [NullResource(self._null_producer_count, component)] self._null_producer_count += 1 - return ResourceGroup(resources_, dependencies_, initializer) + # TODO [MIC-5452]: all resource groups should have a component + if component and ( + have_other_component := [r for r in resources_ if r.component != component] + ): + raise ResourceError( + f"All initialized resources must have the component '{component.name}'." + f" The following resources have a different component: {have_other_component}" + ) + + return ResourceGroup(resources_, dependencies_) def _to_graph(self) -> nx.DiGraph: """Constructs the full resource graph from information in the groups. @@ -164,8 +174,8 @@ def _to_graph(self) -> nx.DiGraph: # Warn here because this sometimes happens naturally # if observer components are missing from a simulation. self.logger.warning( - f"Resource {dependency} is not provided by any component but is needed to " - f"compute {resource_group}." + f"Resource {dependency} is not produced by any" + f" component but is needed to compute {resource_group}." ) continue dependency_group = self._resource_group_map[dependency] @@ -180,7 +190,7 @@ def get_population_initializers(self) -> list[Any]: graph construction, but we only need the column producers at population creation time. 
""" - return [r.initializer for r in self.sorted_nodes if r.is_initializer] + return [r.initializer for r in self.sorted_nodes if r.is_initialized] def __repr__(self) -> str: out = {} @@ -213,22 +223,22 @@ def __init__(self, manager: ResourceManager): def add_resources( self, + # TODO [MIC-5452]: all resource groups should have a component + component: Component | Manager | None, resources: Iterable[str | Resource], dependencies: Iterable[str | Resource], - initializer: Callable[[SimulantData], None] | None = None, ) -> None: """Adds managed resources to the resource pool. Parameters ---------- + component + The component or manager adding the resources. resources The resources being added. A string represents a column resource. dependencies A list of resources that the producer requires. A string represents a column resource. - initializer - A method that will be called to initialize the resources. This is - called during population initialization. Raises ------ @@ -237,7 +247,7 @@ def add_resources( resource producers for the ``column`` resource type, or there are multiple producers of the same resource. """ - self._manager.add_resources(resources, dependencies, initializer) + self._manager.add_resources(component, resources, dependencies) def get_population_initializers(self) -> list[Any]: """Returns a dependency-sorted list of population initializers. 
diff --git a/src/vivarium/framework/resource/resource.py b/src/vivarium/framework/resource/resource.py index bfcda9e0..343368d4 100644 --- a/src/vivarium/framework/resource/resource.py +++ b/src/vivarium/framework/resource/resource.py @@ -1,6 +1,11 @@ from __future__ import annotations from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from vivarium import Component + from vivarium.manager import Manager @dataclass @@ -11,14 +16,15 @@ class Resource: """The type of the resource.""" name: str """The name of the resource.""" + # TODO [MIC-5452]: all resources should have a component + component: Component | Manager | None + """The component that creates the resource.""" @property def resource_id(self) -> str: """The long name of the resource, including the type.""" return f"{self.resource_type}.{self.name}" - # TODO [MIC-5452]: make this an abstract method when support for old - # requirements specification is dropped @property def is_initialized(self) -> bool: """Return True if the resource needs to be initialized.""" @@ -28,8 +34,9 @@ def is_initialized(self) -> bool: class NullResource(Resource): """A node in the dependency graph that does not produce any resources.""" - def __init__(self, index: int): - super().__init__("null", f"{index}") + # TODO [MIC-5452]: all resources should have a component + def __init__(self, index: int, component: Component | Manager | None): + super().__init__("null", f"{index}", component) @property def is_initialized(self) -> bool: @@ -40,8 +47,9 @@ def is_initialized(self) -> bool: class Column(Resource): """A resource representing a column in the state table.""" - def __init__(self, name: str): - super().__init__("column", name) + # TODO [MIC-5452]: all resources should have a component + def __init__(self, name: str, component: Component | Manager | None): + super().__init__("column", name, component) @property def is_initialized(self) -> bool: diff --git a/src/vivarium/framework/time.py 
b/src/vivarium/framework/time.py index c92eda98..6e894f31 100644 --- a/src/vivarium/framework/time.py +++ b/src/vivarium/framework/time.py @@ -115,9 +115,7 @@ def setup(self, builder: "Builder"): self.register_step_modifier = partial( builder.value.register_value_modifier, self._pipeline_name ) - builder.population.initializes_simulants( - self.on_initialize_simulants, creates_columns=self.columns_created - ) + builder.population.initializes_simulants(self, creates_columns=self.columns_created) builder.event.register_listener("post_setup", self.on_post_setup) self._individual_clocks = builder.population.get_view( columns=self.columns_created + self.columns_required diff --git a/src/vivarium/framework/values/manager.py b/src/vivarium/framework/values/manager.py index 98d76411..cb4bf3f1 100644 --- a/src/vivarium/framework/values/manager.py +++ b/src/vivarium/framework/values/manager.py @@ -5,14 +5,14 @@ from typing import TYPE_CHECKING, Any, TypeVar from vivarium.framework.event import Event -from vivarium.framework.randomness import RandomnessStream from vivarium.framework.resource import Resource from vivarium.framework.values.combiners import ValueCombiner, replace_combiner -from vivarium.framework.values.pipeline import Pipeline, ValueModifier, ValueSource +from vivarium.framework.values.pipeline import Pipeline from vivarium.framework.values.post_processors import PostProcessor, rescale_post_processor from vivarium.manager import Interface, Manager if TYPE_CHECKING: + from vivarium import Component from vivarium.framework.engine import Builder T = TypeVar("T") @@ -56,16 +56,20 @@ def on_post_setup(self, _event: Event) -> None: # we say the pipeline value depends on its source and all its # modifiers. 
for name, pipe in self._pipelines.items(): - self.resources.add_resources([pipe], [pipe.source] + list(pipe.mutators)) + self.resources.add_resources( + pipe.component, [pipe], [pipe.source] + list(pipe.mutators) + ) def register_value_producer( self, value_name: str, source: Callable[..., Any], + # TODO [MIC-5452]: all calls should have a component + component: Component | None = None, requires_columns: Iterable[str] = (), requires_values: Iterable[str] = (), requires_streams: Iterable[str] = (), - required_resources: Sequence[str | Pipeline | RandomnessStream] = (), + required_resources: Sequence[str | Resource] = (), preferred_combiner: ValueCombiner = replace_combiner, preferred_post_processor: PostProcessor | None = None, ) -> Pipeline: @@ -75,8 +79,14 @@ def register_value_producer( -------- :meth:`ValuesInterface.register_value_producer` """ - pipeline = self._register_value_producer( - value_name, source, preferred_combiner, preferred_post_processor + self.logger.debug(f"Registering value pipeline {value_name}") + pipeline = self.get_value(value_name) + pipeline.set_attributes( + component, + source, + preferred_combiner, + preferred_post_processor, + self, ) # The resource we add here is just the pipeline source. 
@@ -86,35 +96,24 @@ def register_value_producer( dependencies = self._convert_dependencies( source, requires_columns, requires_values, requires_streams, required_resources ) - self.resources.add_resources([pipeline.source], dependencies) + self.resources.add_resources(pipeline.component, [pipeline.source], dependencies) + self.add_constraint( pipeline._call, restrict_during=["initialization", "setup", "post_setup"] ) return pipeline - def _register_value_producer( - self, - value_name: str, - source: Callable[..., Any], - combiner: ValueCombiner, - post_processor: PostProcessor | None, - ) -> Pipeline: - """Configure the named value pipeline with a source, combiner, and post-processor.""" - self.logger.debug(f"Registering value pipeline {value_name}") - pipeline = self.get_value(value_name) - value_source = ValueSource(pipeline, source) - Pipeline.setup_pipeline(pipeline, value_source, combiner, post_processor, self) - return pipeline - def register_value_modifier( self, value_name: str, modifier: Callable[..., Any], + # TODO [MIC-5452]: all calls should have a component + component: Component | None = None, requires_columns: Iterable[str] = (), requires_values: Iterable[str] = (), requires_streams: Iterable[str] = (), - required_resources: Sequence[str | Pipeline | RandomnessStream] = (), + required_resources: Sequence[str | Resource] = (), ) -> None: """Marks a ``Callable`` as the modifier of a named value. @@ -130,6 +129,8 @@ def register_value_modifier( previous stage in the pipeline. For the ``list_combiner`` strategy, the pipeline modifiers should have the same signature as the pipeline source. + component + The component that is registering the value modifier. requires_columns A list of the state table columns that already need to be present and populated in the state table before the pipeline modifier @@ -145,16 +146,14 @@ def register_value_modifier( pipeline modifier is called. This is a list of strings, pipeline names, or randomness streams. 
""" - pipeline = self.get_value(value_name) - value_modifier = ValueModifier(pipeline, modifier) + value_modifier = pipeline.get_value_modifier(modifier, component) self.logger.debug(f"Registering {value_modifier.name} as modifier to {value_name}") - pipeline.mutators.append(value_modifier) dependencies = self._convert_dependencies( modifier, requires_columns, requires_values, requires_streams, required_resources ) - self.resources.add_resources([value_modifier], dependencies) + self.resources.add_resources(component, [value_modifier], dependencies) def get_value(self, name: str) -> Pipeline: """Retrieve the pipeline representing the named value. @@ -203,8 +202,8 @@ def _convert_dependencies( return ( list(requires_columns) - + [Resource("value", name) for name in requires_values] - + [Resource("stream", name) for name in requires_streams] + + [Resource("value", name, None) for name in requires_values] + + [Resource("stream", name, None) for name in requires_streams] ) else: return required_resources @@ -247,10 +246,12 @@ def register_value_producer( self, value_name: str, source: Callable[..., Any], + # TODO [MIC-5452]: all calls should have a component + component: Component | None = None, requires_columns: Iterable[str] = (), requires_values: Iterable[str] = (), requires_streams: Iterable[str] = (), - required_resources: Sequence[str | Pipeline | RandomnessStream] = (), + required_resources: Sequence[str | Resource] = (), preferred_combiner: ValueCombiner = replace_combiner, preferred_post_processor: PostProcessor | None = None, ) -> Pipeline: @@ -262,6 +263,8 @@ def register_value_producer( The name of the new dynamic value pipeline. source A callable source for the dynamic value pipeline. + component + The component that is registering the value producer. 
requires_columns A list of the state table columns that already need to be present and populated in the state table before the pipeline source @@ -296,6 +299,7 @@ def register_value_producer( return self._manager.register_value_producer( value_name, source, + component, requires_columns, requires_values, requires_streams, @@ -308,10 +312,12 @@ def register_rate_producer( self, rate_name: str, source: Callable[..., Any], + # TODO [MIC-5452]: all calls should have a component + component: Component | None = None, requires_columns: Iterable[str] = (), requires_values: Iterable[str] = (), requires_streams: Iterable[str] = (), - required_resources: Sequence[str | Pipeline | RandomnessStream] = (), + required_resources: Sequence[str | Resource] = (), ) -> Pipeline: """Marks a ``Callable`` as the producer of a named rate. @@ -328,6 +334,8 @@ def register_rate_producer( The name of the new dynamic rate pipeline. source A callable source for the dynamic rate pipeline. + component + The component that is registering the rate producer. requires_columns A list of the state table columns that already need to be present and populated in the state table before the pipeline source @@ -350,6 +358,7 @@ def register_rate_producer( return self.register_value_producer( rate_name, source, + component, requires_columns, requires_values, requires_streams, @@ -361,10 +370,12 @@ def register_value_modifier( self, value_name: str, modifier: Callable[..., Any], + # TODO [MIC-5452]: all calls should have a component + component: Component | None = None, requires_columns: Iterable[str] = (), requires_values: Iterable[str] = (), requires_streams: Iterable[str] = (), - required_resources: Sequence[str | Pipeline | RandomnessStream] = (), + required_resources: Sequence[str | Resource] = (), ) -> None: """Marks a ``Callable`` as the modifier of a named value. @@ -380,6 +391,8 @@ def register_value_modifier( previous stage in the pipeline. 
For the ``list_combiner`` strategy, the pipeline modifiers should have the same signature as the pipeline source. + component + The component that is registering the value modifier. requires_columns A list of the state table columns that already need to be present and populated in the state table before the pipeline modifier @@ -398,6 +411,7 @@ def register_value_modifier( self._manager.register_value_modifier( value_name, modifier, + component, requires_columns, requires_values, requires_streams, diff --git a/src/vivarium/framework/values/pipeline.py b/src/vivarium/framework/values/pipeline.py index 32b72cc8..894a32bc 100644 --- a/src/vivarium/framework/values/pipeline.py +++ b/src/vivarium/framework/values/pipeline.py @@ -5,6 +5,7 @@ import pandas as pd +from vivarium import Component from vivarium.framework.resource import Resource from vivarium.framework.values.exceptions import DynamicValueError @@ -19,8 +20,15 @@ class ValueSource(Resource): """A resource representing the source of a value pipeline.""" - def __init__(self, pipeline: Pipeline, source: Callable[..., Any] | None = None) -> None: - super().__init__("value_source" if source else "missing_value_source", pipeline.name) + def __init__( + self, + pipeline: Pipeline, + source: Callable[..., Any] | None, + component: Component | None, + ) -> None: + super().__init__( + "value_source" if source else "missing_value_source", pipeline.name, component + ) self._pipeline = pipeline self._source = source @@ -40,11 +48,16 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: class ValueModifier(Resource): """A resource representing a modifier of a value pipeline.""" - def __init__(self, pipeline: Pipeline, modifier: Callable[..., Any]) -> None: + def __init__( + self, + pipeline: Pipeline, + modifier: Callable[..., Any], + component: Component | None, + ) -> None: mutator_name = self._get_modifier_name(modifier) mutator_index = len(pipeline.mutators) + 1 name = 
f"{pipeline.name}.{mutator_index}.{mutator_name}" - super().__init__("value_modifier", name) + super().__init__("value_modifier", name, component) self._pipeline = pipeline self._source = modifier @@ -86,10 +99,10 @@ class Pipeline(Resource): values that won't be used in the particular simulation. """ - def __init__(self, name: str) -> None: - super().__init__("value", name) + def __init__(self, name: str, component: Component | None = None) -> None: + super().__init__("value", name, component=component) - self.source: ValueSource = ValueSource(self) + self.source: ValueSource = ValueSource(self, source=None, component=None) """The callable source of the value represented by the pipeline.""" self.mutators: list[ValueModifier] = [] """A list of callables that directly modify the pipeline source or @@ -173,22 +186,37 @@ def _call(self, *args: Any, skip_post_processor: bool = False, **kwargs: Any) -> def __repr__(self) -> str: return f"_Pipeline({self.name})" - @classmethod - def setup_pipeline( - cls, - pipeline: Pipeline, - source: ValueSource, + def get_value_modifier( + self, modifier: Callable[..., Any], component: Component | None + ) -> ValueModifier: + """Add a value modifier to the pipeline and return it. + + Parameters + ---------- + modifier + The value modifier callable for the ValueModifier. + component + The component that creates the value modifier. + """ + value_modifier = ValueModifier(self, modifier, component) + self.mutators.append(value_modifier) + return value_modifier + + def set_attributes( + self, + component: Component | None, + source: Callable[..., Any], combiner: ValueCombiner, post_processor: PostProcessor | None, manager: ValuesManager, ) -> None: """ - Add a source, combiner, and post-processor to a pipeline. + Add a source, combiner, post-processor, and manager to a pipeline. Parameters ---------- - pipeline - The pipeline to configure. + component + The component that creates the pipeline. 
source The callable source of the value represented by the pipeline. combiner @@ -200,7 +228,8 @@ def setup_pipeline( manager The simulation values manager. """ - pipeline.source = source - pipeline._combiner = combiner - pipeline.post_processor = post_processor - pipeline._manager = manager + self.component = component + self.source = ValueSource(self, source, component) + self._combiner = combiner + self.post_processor = post_processor + self._manager = manager diff --git a/tests/framework/components/test_component.py b/tests/framework/components/test_component.py index 89a9dce4..6a988786 100644 --- a/tests/framework/components/test_component.py +++ b/tests/framework/components/test_component.py @@ -113,7 +113,7 @@ def test_component_with_initialization_requirements(): # get all resources in the dependency graph for r in simulation._resource.sorted_nodes # if the resource is an initializer - if r.is_initializer + if r.is_initialized # its initializer is an instance method and hasattr(r.initializer, "__self__") # and is a method of ColumnCreatorAndRequirer diff --git a/tests/framework/randomness/test_manager.py b/tests/framework/randomness/test_manager.py index 99667528..6b309991 100644 --- a/tests/framework/randomness/test_manager.py +++ b/tests/framework/randomness/test_manager.py @@ -1,6 +1,7 @@ import pandas as pd import pytest +from tests.helpers import ColumnCreator, ColumnRequirer from vivarium.framework.randomness.index_map import IndexMap from vivarium.framework.randomness.manager import RandomnessError, RandomnessManager from vivarium.framework.randomness.stream import get_hash @@ -12,6 +13,7 @@ def mock_clock(): def test_randomness_manager_get_randomness_stream(): seed = 123456 + component = ColumnCreator() rm = RandomnessManager() rm._add_constraint = lambda f, **kwargs: f @@ -19,15 +21,16 @@ def test_randomness_manager_get_randomness_stream(): rm._clock_ = mock_clock rm._key_columns = ["age", "sex"] rm._key_mapping_ = IndexMap(["age", "sex"]) - stream 
= rm._get_randomness_stream("test") + stream = rm._get_randomness_stream("test", component) assert stream.key == "test" assert stream.seed == seed assert stream.clock is mock_clock assert set(rm._decision_points.keys()) == {"test"} + assert stream.component == component with pytest.raises(RandomnessError): - rm.get_randomness_stream("test") + rm._get_randomness_stream("test", ColumnRequirer()) def test_randomness_manager_register_simulants(): diff --git a/tests/framework/resource/test_manager.py b/tests/framework/resource/test_manager.py index a63b37ca..8a2a1ff5 100644 --- a/tests/framework/resource/test_manager.py +++ b/tests/framework/resource/test_manager.py @@ -1,12 +1,13 @@ from __future__ import annotations -from collections.abc import Callable, Mapping +from collections.abc import Callable from datetime import datetime from typing import Any import pytest import pytest_mock +from tests.helpers import ColumnCreator, ColumnCreatorAndRequirer, ColumnRequirer from vivarium import Component from vivarium.framework.population import SimulantData from vivarium.framework.randomness import RandomnessStream @@ -24,9 +25,33 @@ def manager(mocker: pytest_mock.MockFixture) -> ResourceManager: return manager +@pytest.fixture +def resource_producers() -> dict[int, ResourceProducer]: + return {i: ResourceProducer(f"test_{i}") for i in range(5)} + + +@pytest.fixture +def manager_with_resources( + manager: ResourceManager, resource_producers: dict[int, ResourceProducer] +) -> ResourceManager: + stream = RandomnessStream( + "B", lambda: datetime.now(), 1, IndexMap(), resource_producers[1] + ) + pipeline = Pipeline("C", resource_producers[2]) + + manager.add_resources(resource_producers[3], ["D"], [stream, pipeline]) + manager.add_resources(stream.component, [stream], ["A"]) + manager.add_resources(pipeline.component, [pipeline], ["A"]) + manager.add_resources(resource_producers[0], ["A"], []) + manager.add_resources(resource_producers[4], [], [stream]) + return manager + + 
@pytest.fixture def randomness_stream() -> RandomnessStream: - return RandomnessStream("stream.1", lambda: datetime.now(), 1, IndexMap()) + return RandomnessStream( + "stream.1", lambda: datetime.now(), 1, IndexMap(), component=ColumnCreator() + ) class ResourceProducer(Component): @@ -38,7 +63,7 @@ def __init__(self, name: str): super().__init__() self._name = name - def initializer(self, _simulant_data: SimulantData) -> None: + def on_initialize_simulants(self, _simulant_data: SimulantData) -> None: pass @@ -49,7 +74,7 @@ def initializer(self, _simulant_data: SimulantData) -> None: (ValueSource, [Pipeline("foo"), lambda: 1], "value_source", False), (ValueModifier, [Pipeline("foo"), lambda: 1], "value_modifier", False), (Column, ["foo"], "column", True), - (NullResource, ["foo"], "null", True), + (NullResource, [1], "null", True), ], ids=lambda x: [x.__name__ if isinstance(x, type) else x], ) @@ -60,104 +85,117 @@ def test_resource_manager_get_resource_group( is_initializer: bool, manager: ResourceManager, ) -> None: - initializer = ResourceProducer("base").initializer + component = ColumnCreator() group = manager._get_resource_group( - [resource_class(*init_args)], [], initializer if is_initializer else None + component, [resource_class(*init_args, component=component)], [] ) assert group.type == type_string - assert group.names == [r.resource_id for r in group._resources.values()] + assert group.names == [r.resource_id for r in group.resources.values()] assert not group.dependencies - assert group.is_initializer == is_initializer - if is_initializer: - assert group.initializer == initializer - else: - with pytest.raises(ResourceError, match="does not have an initializer"): - _ = group.initializer + assert group.is_initialized == is_initializer + assert group.initializer == component.on_initialize_simulants def test_resource_manager_get_resource_group_null(manager: ResourceManager) -> None: - initializer = ResourceProducer("base").initializer + component_1 = 
ColumnCreator() + component_2 = ColumnCreatorAndRequirer() - group_1 = manager._get_resource_group([], [], initializer) - group_2 = manager._get_resource_group([], [], initializer) + group_1 = manager._get_resource_group(component_1, [], []) + group_2 = manager._get_resource_group(component_2, [], []) assert group_1.type == "null" assert group_1.names == ["null.0"] - assert group_1.initializer == initializer + assert group_1.initializer == component_1.on_initialize_simulants assert not group_1.dependencies assert group_2.type == "null" assert group_2.names == ["null.1"] - assert group_2.initializer == initializer + assert group_2.initializer == component_2.on_initialize_simulants assert not group_2.dependencies -def test_resource_manager_add_same_column_twice(manager: ResourceManager) -> None: - r1 = [str(i) for i in range(5)] - r2 = [str(i) for i in range(5, 10)] + ["1"] - - manager.add_resources(r1, [], ResourceProducer("1").initializer) - with pytest.raises(ResourceError, match="initializers for column.1"): - manager.add_resources(r2, [], ResourceProducer("2").initializer) - +def test_add_resource_wrong_component(manager: ResourceManager) -> None: + resource = Pipeline("foo", ColumnCreatorAndRequirer()) + error_message = "All initialized resources must have the component 'column_creator'." 
+ with pytest.raises(ResourceError, match=error_message): + manager.add_resources(ColumnCreator(), [resource], []) -def test_resource_manager_add_same_pipeline_twice(manager: ResourceManager) -> None: - r1 = [Pipeline(str(i)) for i in range(5)] - r2 = [Pipeline(str(i)) for i in range(5, 10)] + [Pipeline("1")] - manager.add_resources(r1, [], None) - with pytest.raises(ResourceError, match="registered more than once"): - manager.add_resources(r2, [], None) +@pytest.mark.parametrize( + "resource_type, resource_creator", + [ + ("column", lambda name, component: name), + ("value", lambda name, component: Pipeline(name, component)), + ], +) +def test_resource_manager_add_same_resource_twice( + resource_type: str, + resource_creator: Callable[[str, Component], Any], + manager: ResourceManager, +) -> None: + c1 = ColumnCreator() + c2 = ColumnCreatorAndRequirer() + r1 = [resource_creator(str(i), c1) for i in range(5)] + r2 = [resource_creator(str(i), c2) for i in range(5, 10)] + [resource_creator("1", c2)] + + manager.add_resources(c1, r1, []) + error_message = ( + f"Component '{c2.name}' is attempting to register resource" + f" '{resource_type}.1' but it is already registered by '{c1.name}'." 
+ ) + with pytest.raises(ResourceError, match=error_message): + manager.add_resources(c2, r2, []) def test_resource_manager_sorted_nodes_two_node_cycle( manager: ResourceManager, randomness_stream: RandomnessStream ) -> None: - manager.add_resources(["c_1"], [randomness_stream], ResourceProducer("1").initializer) - manager.add_resources([randomness_stream], ["c_1"], None) + manager.add_resources(ColumnCreatorAndRequirer(), ["c_1"], [randomness_stream]) + manager.add_resources(randomness_stream.component, [randomness_stream], ["c_1"]) - with pytest.raises(ResourceError, match="cycle"): + with pytest.raises(ResourceError, match="The resource pool contains at least one cycle"): _ = manager.sorted_nodes def test_resource_manager_sorted_nodes_three_node_cycle( manager: ResourceManager, randomness_stream: RandomnessStream ) -> None: - pipeline = Pipeline("some_pipeline") + pipeline = Pipeline("some_pipeline", ColumnRequirer()) - manager.add_resources(["c_1"], [randomness_stream], ResourceProducer("1").initializer) - manager.add_resources([pipeline], ["c_1"], None) - manager.add_resources([randomness_stream], [pipeline], None) + manager.add_resources(ColumnCreatorAndRequirer(), ["c_1"], [randomness_stream]) + manager.add_resources(pipeline.component, [pipeline], ["c_1"]) + manager.add_resources(randomness_stream.component, [randomness_stream], [pipeline]) - with pytest.raises(ResourceError, match="cycle"): + with pytest.raises(ResourceError, match="The resource pool contains at least one cycle"): _ = manager.sorted_nodes def test_resource_manager_sorted_nodes_large_cycle(manager: ResourceManager) -> None: + component = ColumnCreator() for i in range(10): - manager.add_resources([f"c_{i}"], [f"c_{i % 10}"], ResourceProducer("1").initializer) + manager.add_resources(component, [f"c_{i}"], [f"c_{i % 10}"]) with pytest.raises(ResourceError, match="cycle"): _ = manager.sorted_nodes def test_large_dependency_chain(manager: ResourceManager) -> None: + component = ColumnCreator() 
for i in range(9, 0, -1): - manager.add_resources( - [f"c_{i}"], [f"c_{i - 1}"], ResourceProducer(f"p_{i}").initializer - ) - manager.add_resources(["c_0"], [], ResourceProducer("producer_0").initializer) + manager.add_resources(component, [f"c_{i}"], [f"c_{i - 1}"]) + manager.add_resources(component, ["c_0"], []) for i, resource in enumerate(manager.sorted_nodes): assert str(resource) == f"(column.c_{i})" -def test_resource_manager_sorted_nodes_acyclic(manager: ResourceManager) -> None: - _add_resources(manager) +def test_resource_manager_sorted_nodes_acyclic( + manager_with_resources: ResourceManager, +) -> None: - n = [str(node) for node in manager.sorted_nodes] + n = [str(node) for node in manager_with_resources.sorted_nodes] assert n.index("(column.A)") < n.index("(stream.B)") assert n.index("(column.A)") < n.index("(value.C)") @@ -169,31 +207,12 @@ def test_resource_manager_sorted_nodes_acyclic(manager: ResourceManager) -> None assert n.index("(stream.B)") < n.index(f"(null.0)") -def test_get_population_initializers(manager: ResourceManager) -> None: - producers = _add_resources(manager) - initializers = manager.get_population_initializers() +def test_get_population_initializers( + manager_with_resources: ResourceManager, resource_producers: dict[int, ResourceProducer] +) -> None: + initializers = manager_with_resources.get_population_initializers() assert len(initializers) == 3 - assert initializers[0] == producers[0] - assert producers[3] in initializers - assert producers[4] in initializers - - -#################### -# Helper functions # -#################### - - -def _add_resources(manager: ResourceManager) -> Mapping[int, Callable[[SimulantData], None]]: - producers = {i: ResourceProducer(f"test_{i}").initializer for i in range(5)} - - stream = RandomnessStream("B", lambda: datetime.now(), 1, IndexMap()) - pipeline = Pipeline("C") - - manager.add_resources(["D"], [stream, pipeline], producers[3]) - manager.add_resources([stream], ["A"], None) - 
manager.add_resources([pipeline], ["A"], None) - manager.add_resources(["A"], [], producers[0]) - manager.add_resources([], [stream], producers[4]) - - return producers + assert initializers[0] == resource_producers[0].on_initialize_simulants + assert resource_producers[3].on_initialize_simulants in initializers + assert resource_producers[4].on_initialize_simulants in initializers diff --git a/tests/framework/resource/test_resource.py b/tests/framework/resource/test_resource.py index ee23641e..9f06d37d 100644 --- a/tests/framework/resource/test_resource.py +++ b/tests/framework/resource/test_resource.py @@ -1,6 +1,30 @@ +from datetime import datetime + +import pytest + +from tests.helpers import ColumnCreator +from vivarium.framework.randomness import RandomnessStream +from vivarium.framework.randomness.index_map import IndexMap from vivarium.framework.resource import Resource +from vivarium.framework.resource.resource import Column, NullResource +from vivarium.framework.values import Pipeline, ValueModifier, ValueSource def test_resource_id() -> None: - resource = Resource("value_source", "test") + resource = Resource("value_source", "test", ColumnCreator()) assert resource.resource_id == "value_source.test" + + +@pytest.mark.parametrize( + "resource, is_initialized", + [ + (Pipeline("foo"), False), + (ValueSource(Pipeline("bar"), lambda: 1, ColumnCreator()), False), + (ValueModifier(Pipeline("baz"), lambda: 1, ColumnCreator()), False), + (Column("foo", ColumnCreator()), True), + (RandomnessStream("bar", lambda: datetime.now(), 1, IndexMap()), False), + (NullResource(0, ColumnCreator()), True), + ], +) +def test_resource_is_initialized(resource: Resource, is_initialized: bool) -> None: + assert resource.is_initialized == is_initialized diff --git a/tests/framework/resource/test_resource_group.py b/tests/framework/resource/test_resource_group.py index 5f3d6b6c..e7f86790 100644 --- a/tests/framework/resource/test_resource_group.py +++ 
b/tests/framework/resource/test_resource_group.py @@ -4,33 +4,31 @@ import pytest -from vivarium.framework.population import SimulantData +from tests.helpers import ColumnCreator, ColumnRequirer from vivarium.framework.randomness import RandomnessStream from vivarium.framework.randomness.index_map import IndexMap from vivarium.framework.resource.exceptions import ResourceError from vivarium.framework.resource.group import ResourceGroup -from vivarium.framework.resource.resource import Column +from vivarium.framework.resource.resource import Column, NullResource, Resource from vivarium.framework.values import Pipeline, ValueModifier, ValueSource -def dummy_initializer(_simulant_data: SimulantData) -> None: - pass - - def test_resource_group() -> None: - resources = [Column(str(i)) for i in range(5)] + component = ColumnCreator() + resources = [Column(str(i), component) for i in range(5)] r_dependencies = [ - Column("an_interesting_column"), + Column("an_interesting_column", None), Pipeline("baz"), RandomnessStream("bar", lambda: datetime.now(), 1, IndexMap()), - ValueSource(Pipeline("foo"), lambda: 1), + ValueSource(Pipeline("foo"), lambda: 1, None), ] - rg = ResourceGroup(resources, r_dependencies, dummy_initializer) + rg = ResourceGroup(resources, r_dependencies) + assert rg.component == component assert rg.type == "column" assert rg.names == [f"column.{i}" for i in range(5)] - assert rg.initializer == dummy_initializer + assert rg.initializer == component.on_initialize_simulants assert rg.dependencies == [ "column.an_interesting_column", "value.baz", @@ -40,23 +38,42 @@ def test_resource_group() -> None: assert list(rg) == rg.names -def test_resource_group_is_initializer() -> None: - resources = [ValueModifier(Pipeline("foo"), lambda: 1)] - rg = ResourceGroup(resources, [Column("bar")]) - - with pytest.raises(ResourceError, match="does not have an initializer"): - _ = rg.initializer +@pytest.mark.parametrize( + "resource, has_initializer", + [ + (Pipeline("foo"), 
False), + (ValueSource(Pipeline("bar"), lambda: 1, ColumnCreator()), False), + (ValueModifier(Pipeline("baz"), lambda: 1, ColumnCreator()), False), + (Column("foo", ColumnCreator()), True), + (RandomnessStream("bar", lambda: datetime.now(), 1, IndexMap()), False), + (NullResource(0, ColumnCreator()), True), + ], +) +def test_resource_group_is_initializer(resource: Resource, has_initializer: bool) -> None: + rg = ResourceGroup([resource], [Column("bar", None)]) + assert rg.is_initialized == has_initializer def test_resource_group_with_no_resources() -> None: with pytest.raises(ResourceError, match="must have at least one resource"): - _ = ResourceGroup([], [Column("foo")]) + _ = ResourceGroup([], [Column("foo", None)]) + + +def test_resource_group_with_multiple_components() -> None: + resources = [ + ValueModifier(Pipeline("foo"), lambda: 1, ColumnCreator()), + ValueSource(Pipeline("bar"), lambda: 2, ColumnRequirer()), + ] + + with pytest.raises(ResourceError, match="resources must have the same component"): + _ = ResourceGroup(resources, []) def test_resource_group_with_multiple_resource_types() -> None: + component = ColumnCreator() resources = [ - ValueModifier(Pipeline("foo"), lambda: 1), - ValueSource(Pipeline("bar"), lambda: 2), + ValueModifier(Pipeline("foo"), lambda: 1, component), + ValueSource(Pipeline("bar"), lambda: 2, component), ] with pytest.raises(ResourceError, match="resources must be of the same type"): diff --git a/tests/helpers.py b/tests/helpers.py index b79bf622..fd565ab6 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,3 +1,4 @@ +# mypy: ignore-errors from __future__ import annotations from typing import Any, Dict, List, Optional From 6c454ce7a4b2a84a764c0ecb3a61e9243c4cd4c0 Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Tue, 12 Nov 2024 10:42:42 -0800 Subject: [PATCH 18/22] fix broken vph tests (#533) --- src/vivarium/component.py | 22 +++++++--------------- 
src/vivarium/framework/state_machine.py | 20 ++++++++++++-------- src/vivarium/framework/values/pipeline.py | 3 +++ 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/src/vivarium/component.py b/src/vivarium/component.py index 5cbade2c..7d54308d 100644 --- a/src/vivarium/component.py +++ b/src/vivarium/component.py @@ -603,6 +603,10 @@ def build_lookup_table( If the data source is invalid. """ data = self.get_data(builder, data_source) + # TODO update this to use vivarium.types.LookupTableData once we drop + # support for Python 3.9 + if not isinstance(data, (Number, timedelta, datetime, pd.DataFrame, list, tuple)): + raise ConfigurationError(f"Data '{data}' must be a LookupTableData instance.") if isinstance(data, list): return builder.lookup.build_table(data, value_columns=list(value_columns)) @@ -656,7 +660,7 @@ def get_data( self, builder: Builder, data_source: LookupTableData | str | Callable[[Builder], LookupTableData], - ) -> float | pd.DataFrame: + ) -> Any: """Retrieves data from a data source. If the data source is a float or a DataFrame, it is treated as the data @@ -683,12 +687,7 @@ def get_data( layered_config_tree.exceptions.ConfigurationError If the data source is invalid. """ - # TODO update this to use vivarium.types.LookupTableData once we drop - # support for Python 3.9 - valid_data_types = (Number, timedelta, datetime, pd.DataFrame, list, tuple) - if isinstance(data_source, valid_data_types): - data = data_source - elif isinstance(data_source, str): + if isinstance(data_source, str): if "::" in data_source: module, method = data_source.split("::") try: @@ -716,15 +715,8 @@ def get_data( elif isinstance(data_source, Callable): data = data_source(builder) else: - raise ConfigurationError( - f"Data source is of type '{type(data_source)}'. It must be a " - "LookupTableData instance, a string corresponding to an " - "artifact key, a callable that returns a LookupTableData " - "instance, or a string defining such a callable." 
- ) + data = data_source - if not isinstance(data, valid_data_types): - raise ConfigurationError(f"Data '{data}' must be a LookupTableData instance.") return data def _set_population_view(self, builder: "Builder") -> None: diff --git a/src/vivarium/framework/state_machine.py b/src/vivarium/framework/state_machine.py index e91b59f7..7f04be57 100644 --- a/src/vivarium/framework/state_machine.py +++ b/src/vivarium/framework/state_machine.py @@ -27,10 +27,6 @@ from vivarium.types import ClockTime, LookupTableData -def default_initializer(_builder: Builder) -> LookupTableData: - return 0.0 - - def _next_state( index: pd.Index, event_time: ClockTime, @@ -221,7 +217,7 @@ def __init__( self, state_id: str, allow_self_transition: bool = False, - initialization_weights: Callable[[Builder], LookupTableData] = default_initializer, + initialization_weights: Callable[[Builder], LookupTableData] | None = None, ) -> None: super().__init__() self.state_id = state_id @@ -294,7 +290,10 @@ def allow_self_transitions(self) -> None: ################## def get_initialization_weights(self, builder: Builder) -> LookupTableData: - return self.initialization_weights(builder) + if self.initialization_weights: + return self.initialization_weights(builder) + else: + return 0.0 def transition_side_effect(self, index: pd.Index, event_time: ClockTime) -> None: pass @@ -516,7 +515,7 @@ def __init__( self.add_states(states) states_with_initialization_weights = [ - s for s in self.states if s.initialization_weights != default_initializer + state for state in self.states if state.initialization_weights ] if initial_state is not None: @@ -533,7 +532,12 @@ def __init__( initial_state.initialization_weights = lambda _builder: 1.0 - elif not states_with_initialization_weights: + # TODO: [MIC-5403] remove this on_initialize_simulants check once + # VPH's DiseaseModel has a compatible initialization strategy + elif ( + type(self).on_initialize_simulants == Machine.on_initialize_simulants + and not 
states_with_initialization_weights + ): raise ValueError( "Must specify either an initial state or provide" " initialization weights to states." diff --git a/src/vivarium/framework/values/pipeline.py b/src/vivarium/framework/values/pipeline.py index 894a32bc..6cc81a65 100644 --- a/src/vivarium/framework/values/pipeline.py +++ b/src/vivarium/framework/values/pipeline.py @@ -186,6 +186,9 @@ def _call(self, *args: Any, skip_post_processor: bool = False, **kwargs: Any) -> def __repr__(self) -> str: return f"_Pipeline({self.name})" + def __hash__(self) -> int: + return hash(self.name) + def get_value_modifier( self, modifier: Callable[..., Any], component: Component | None ) -> ValueModifier: From 8491afb1c75e401b8e2e0528662e1874dd1f45a6 Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Tue, 12 Nov 2024 13:04:32 -0800 Subject: [PATCH 19/22] fix imprecise resource type-hinting (#534) --- src/vivarium/component.py | 5 ++--- src/vivarium/examples/disease_model/risk.py | 5 ++--- src/vivarium/framework/state_machine.py | 5 ++--- tests/helpers.py | 5 ++--- 4 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/vivarium/component.py b/src/vivarium/component.py index 7d54308d..76f2caf1 100644 --- a/src/vivarium/component.py +++ b/src/vivarium/component.py @@ -32,8 +32,7 @@ from vivarium.framework.event import Event from vivarium.framework.lookup import LookupTable from vivarium.framework.population import PopulationView, SimulantData - from vivarium.framework.randomness import RandomnessStream - from vivarium.framework.values import Pipeline + from vivarium.framework.resource import Resource from vivarium.types import LookupTableData DEFAULT_EVENT_PRIORITY = 5 @@ -239,7 +238,7 @@ def columns_required(self) -> Optional[List[str]]: @property def initialization_requirements( self, - ) -> list[str | Pipeline | RandomnessStream]: + ) -> list[str | Resource]: """A list containing the columns, pipelines, and randomness streams 
required by this component's simulant initializer.""" return [] diff --git a/src/vivarium/examples/disease_model/risk.py b/src/vivarium/examples/disease_model/risk.py index 419efc6e..0969813f 100644 --- a/src/vivarium/examples/disease_model/risk.py +++ b/src/vivarium/examples/disease_model/risk.py @@ -9,8 +9,7 @@ if TYPE_CHECKING: from vivarium.framework.engine import Builder - from vivarium.framework.randomness import RandomnessStream - from vivarium.framework.values import Pipeline + from vivarium.framework.resource import Resource class Risk(Component): @@ -33,7 +32,7 @@ def columns_created(self) -> List[str]: return [self.propensity_column] @property - def initialization_requirements(self) -> list[str | Pipeline | RandomnessStream]: + def initialization_requirements(self) -> list[str | Resource]: return [self.randomness] ##################### diff --git a/src/vivarium/framework/state_machine.py b/src/vivarium/framework/state_machine.py index 7f04be57..7a2ab9a4 100644 --- a/src/vivarium/framework/state_machine.py +++ b/src/vivarium/framework/state_machine.py @@ -22,8 +22,7 @@ from vivarium.framework.engine import Builder from vivarium.framework.event import Event from vivarium.framework.population import PopulationView, SimulantData - from vivarium.framework.randomness import RandomnessStream - from vivarium.framework.values import Pipeline + from vivarium.framework.resource import Resource from vivarium.types import ClockTime, LookupTableData @@ -495,7 +494,7 @@ def columns_created(self) -> List[str]: @property def initialization_requirements( self, - ) -> list[str | Pipeline | RandomnessStream]: + ) -> list[str | Resource]: return [self.randomness] ##################### diff --git a/tests/helpers.py b/tests/helpers.py index fd565ab6..f9c0a0e2 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -9,8 +9,7 @@ from vivarium.framework.engine import Builder from vivarium.framework.event import Event from vivarium.framework.population import SimulantData -from 
vivarium.framework.randomness import RandomnessStream -from vivarium.framework.values import Pipeline +from vivarium.framework.resource import Resource class MockComponentA(Observer): @@ -247,7 +246,7 @@ def columns_created(self) -> List[str]: return ["test_column_4"] @property - def initialization_requirements(self) -> list[str | Pipeline | RandomnessStream]: + def initialization_requirements(self) -> list[str | Resource]: return ["test_column_2", self.pipeline, self.randomness] def setup(self, builder: Builder) -> None: From ed09d1920cd9e3318ba72651e2c9c1bd5659082c Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Tue, 12 Nov 2024 16:18:09 -0800 Subject: [PATCH 20/22] update changelog (#535) --- CHANGELOG.rst | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 4fd9f1b8..c25fbcd9 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,7 +1,19 @@ -**3.2.0 - TBD** - - - Enable Machine to be used directly to model a state machine - - Support passing callables directly when building lookup tables +**3.2.0 - 11/12/24** + + - Feature: Supports passing callables directly when building lookup tables + - Feature: Enables columns and pipelines to specify dependencies directly, instead of by name + - Feature: Enables identification of which component produced a Pipeline or RandomnessStream + - Bugfix: Enables Machine to be used directly to model a state machine + - Bugfix: Ensures that a Pipeline will always have a name + - Bugfix: Appropriately declares dependencies in example models + - Testing: Adds coverage for example DiseaseModel + - Refactor: Converts resource module into a package + - Refactor: Converts values module into a package + - Refactor: Simplifies code to allow Managers to create columns + - Refactor: Converts ResourceManager __iter__ to a well-named instance method + - Refactor: Creates ResourceTypes for each type of resource + - Refactor: 
Makes Pipeline and RandomnessStream inherit from Resource + - Refactor: Creates ValueSource and ValueModifier resources and attaches them to Pipelines **3.1.0 - 11/07/24** From 2a81a36d566bb195ddab40f40e981a4b60fa1806 Mon Sep 17 00:00:00 2001 From: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com> Date: Tue, 12 Nov 2024 16:23:01 -0800 Subject: [PATCH 21/22] Machine calls cleanup on time-step cleanup (#536) --- src/vivarium/framework/state_machine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/vivarium/framework/state_machine.py b/src/vivarium/framework/state_machine.py index 7a2ab9a4..c44d7b6f 100644 --- a/src/vivarium/framework/state_machine.py +++ b/src/vivarium/framework/state_machine.py @@ -563,6 +563,9 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None: def on_time_step(self, event: Event) -> None: self.transition(event.index, event.time) + def on_time_step_cleanup(self, event: Event) -> None: + self.cleanup(event.index, event.time) + ################## # Public methods # ################## From 881d6b8df66a3b1400b617b01ca26d2d520211ae Mon Sep 17 00:00:00 2001 From: Steve Bachmeier <23350991+stevebachmeier@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:30:02 -0700 Subject: [PATCH 22/22] Sbachmei/mic 5549/mypy results context (#538) --- CHANGELOG.rst | 4 ++ docs/source/concepts/results.rst | 26 ++++----- pyproject.toml | 1 - src/vivarium/framework/results/context.py | 54 +++++++++---------- src/vivarium/framework/results/manager.py | 5 +- src/vivarium/framework/results/observation.py | 17 +++--- src/vivarium/framework/results/observer.py | 2 +- .../framework/results/stratification.py | 7 +-- src/vivarium/types.py | 4 ++ 9 files changed, 65 insertions(+), 55 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c25fbcd9..982bc029 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,7 @@ +**3.2.1 - TBD/TBD/TBD** + + - Fix mypy errors in vivarium/framework/results/context.py + **3.2.0 - 11/12/24** - Feature: 
Supports passing callables directly when building lookup tables diff --git a/docs/source/concepts/results.rst b/docs/source/concepts/results.rst index 3581aed5..4463a7fb 100644 --- a/docs/source/concepts/results.rst +++ b/docs/source/concepts/results.rst @@ -303,7 +303,7 @@ A couple other more specific and commonly used observations are provided as well that gathers new results and concatenates them to any existing results. Ideally, all concrete classes should inherit from the -:class:`BaseObservation ` +:class:`Observation ` abstract base class, which contains the common attributes between observation types: .. list-table:: **Common Observation Attributes** @@ -312,40 +312,40 @@ abstract base class, which contains the common attributes between observation ty * - Attribute - Description - * - | :attr:`name ` + * - | :attr:`name ` - | Name of the observation. It will also be the name of the output results file | for this particular observation. - * - | :attr:`pop_filter ` + * - | :attr:`pop_filter ` - | A Pandas query filter string to filter the population down to the simulants | who should be considered for the observation. - * - | :attr:`when ` + * - | :attr:`when ` - | Name of the lifecycle phase the observation should happen. Valid values are: | "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics". - * - | :attr:`results_initializer ` + * - | :attr:`results_initializer ` - | Method or function that initializes the raw observation results | prior to starting the simulation. This could return, for example, an empty | DataFrame or one with a complete set of stratifications as the index and | all values set to 0.0. - * - | :attr:`results_gatherer ` + * - | :attr:`results_gatherer ` - | Method or function that gathers the new observation results. - * - | :attr:`results_updater ` + * - | :attr:`results_updater ` - | Method or function that updates existing raw observation results with newly | gathered results. 
- * - | :attr:`results_formatter ` + * - | :attr:`results_formatter ` - | Method or function that formats the raw observation results. - * - | :attr:`stratifications ` + * - | :attr:`stratifications ` - | Optional tuple of column names for the observation to stratify by. - * - | :attr:`to_observe ` + * - | :attr:`to_observe ` - | Method or function that determines whether to perform an observation on this Event. -The **BaseObservation** also contains the -:meth:`observe ` +The **Observation** also contains the +:meth:`observe ` method which is called at each :ref:`event ` and :ref:`time step ` to determine whether or not the observation should be recorded, and if so, gathers the results and stores them in the results system. .. note:: - All four observation types discussed above inherit from the **BaseObservation** + All four observation types discussed above inherit from the **Observation** abstract base class. What differentiates them are the assigned attributes (e.g. defining the **results_updater** to be an adding method for the **AddingObservation**) or adding other attributes as necessary (e.g. 
diff --git a/pyproject.toml b/pyproject.toml index 3f89b0c9..ccc81506 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,6 @@ exclude = [ 'src/vivarium/framework/lookup/manager.py', 'src/vivarium/framework/population/manager.py', 'src/vivarium/framework/population/population_view.py', - 'src/vivarium/framework/results/context.py', 'src/vivarium/framework/results/interface.py', 'src/vivarium/framework/results/manager.py', 'src/vivarium/framework/results/observer.py', diff --git a/src/vivarium/framework/results/context.py b/src/vivarium/framework/results/context.py index 353cecba..987c5777 100644 --- a/src/vivarium/framework/results/context.py +++ b/src/vivarium/framework/results/context.py @@ -1,4 +1,3 @@ -# mypy: ignore-errors """ =============== Results Context @@ -6,8 +5,11 @@ """ +from __future__ import annotations + from collections import defaultdict -from typing import Callable, Generator, List, Optional, Tuple, Type, Union +from collections.abc import Callable, Generator +from typing import Any import pandas as pd from pandas.core.groupby.generic import DataFrameGroupBy @@ -15,13 +17,13 @@ from vivarium.framework.engine import Builder from vivarium.framework.event import Event from vivarium.framework.results.exceptions import ResultsConfigurationError -from vivarium.framework.results.observation import BaseObservation +from vivarium.framework.results.observation import Observation from vivarium.framework.results.stratification import ( Stratification, get_mapped_col_name, get_original_col_name, ) -from vivarium.types import ScalarValue +from vivarium.types import ScalarMapper, VectorMapper class ResultsContext: @@ -52,10 +54,12 @@ class ResultsContext: """ def __init__(self) -> None: - self.default_stratifications: List[str] = [] - self.stratifications: List[Stratification] = [] + self.default_stratifications: list[str] = [] + self.stratifications: list[Stratification] = [] self.excluded_categories: dict[str, list[str]] = {} - self.observations: 
defaultdict = defaultdict(lambda: defaultdict(list)) + self.observations: defaultdict[ + str, defaultdict[tuple[str, tuple[str, ...] | None], list[Observation]] + ] = defaultdict(lambda: defaultdict(list)) @property def name(self) -> str: @@ -73,7 +77,7 @@ def setup(self, builder: Builder) -> None: ) # noinspection PyAttributeOutsideInit - def set_default_stratifications(self, default_grouping_columns: List[str]) -> None: + def set_default_stratifications(self, default_grouping_columns: list[str]) -> None: """Set the default stratifications to be used by stratified observations. Parameters @@ -96,15 +100,10 @@ def set_default_stratifications(self, default_grouping_columns: List[str]) -> No def add_stratification( self, name: str, - sources: List[str], - categories: List[str], - excluded_categories: Optional[List[str]], - mapper: Optional[ - Union[ - Callable[[Union[pd.Series, pd.DataFrame]], pd.Series], - Callable[[ScalarValue], str], - ] - ], + sources: list[str], + categories: list[str], + excluded_categories: list[str] | None, + mapper: VectorMapper | ScalarMapper | None, is_vectorized: bool, ) -> None: """Add a stratification to the results context. @@ -187,11 +186,11 @@ def add_stratification( def register_observation( self, - observation_type: Type[BaseObservation], + observation_type: type[Observation], name: str, pop_filter: str, when: str, - **kwargs, + **kwargs: Any, ) -> None: """Add an observation to the results context. @@ -242,10 +241,10 @@ def register_observation( def gather_results( self, population: pd.DataFrame, lifecycle_phase: str, event: Event ) -> Generator[ - Tuple[ - Optional[pd.DataFrame], - Optional[str], - Optional[Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame]], + tuple[ + pd.DataFrame | None, + str | None, + Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame] | None, ], None, None, @@ -302,6 +301,7 @@ def gather_results( if filtered_pop.empty: yield None, None, None else: + pop: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] 
| str] if stratification_names is None: pop = filtered_pop else: @@ -317,7 +317,7 @@ def _filter_population( self, population: pd.DataFrame, pop_filter: str, - stratification_names: Optional[tuple[str, ...]], + stratification_names: tuple[str, ...] | None, ) -> pd.DataFrame: """Filter out simulants not to observe.""" pop = population.query(pop_filter) if pop_filter else population.copy() @@ -334,8 +334,8 @@ def _filter_population( @staticmethod def _get_groups( - stratifications: Tuple[str, ...], filtered_pop: pd.DataFrame - ) -> DataFrameGroupBy: + stratifications: tuple[str, ...], filtered_pop: pd.DataFrame + ) -> DataFrameGroupBy[tuple[str, ...] | str]: """Group the population by stratification. Notes @@ -356,7 +356,7 @@ def _get_groups( ) else: pop_groups = filtered_pop.groupby(lambda _: "all") - return pop_groups + return pop_groups # type: ignore[return-value] def _rename_stratification_columns(self, results: pd.DataFrame) -> None: """Convert the temporary stratified mapped index names back to their original names.""" diff --git a/src/vivarium/framework/results/manager.py b/src/vivarium/framework/results/manager.py index 6c18d279..ee686476 100644 --- a/src/vivarium/framework/results/manager.py +++ b/src/vivarium/framework/results/manager.py @@ -8,12 +8,13 @@ from collections import defaultdict from enum import Enum -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, Union import pandas as pd from vivarium.framework.event import Event from vivarium.framework.results.context import ResultsContext +from vivarium.framework.results.observation import Observation from vivarium.framework.values import Pipeline from vivarium.manager import Manager from vivarium.types import ScalarValue @@ -301,7 +302,7 @@ def _bin_data(data: Union[pd.Series, pd.DataFrame]) -> pd.Series: def register_observation( self, - observation_type, + observation_type: Type[Observation], 
is_stratified: bool, name: str, pop_filter: str, diff --git a/src/vivarium/framework/results/observation.py b/src/vivarium/framework/results/observation.py index f400900d..8037573f 100644 --- a/src/vivarium/framework/results/observation.py +++ b/src/vivarium/framework/results/observation.py @@ -6,7 +6,7 @@ An observation is a class object that records simulation results; they are responsible for initializing, gathering, updating, and formatting results. -The provided :class:`BaseObservation` class is an abstract base class that should +The provided :class:`Observation` class is an abstract base class that should be subclassed by concrete observations. While there are no required abstract methods to define when subclassing, the class does provide common attributes as well as an `observe` method that determines whether to observe results for a given event. @@ -24,7 +24,6 @@ from abc import ABC from collections.abc import Callable from dataclasses import dataclass -from typing import Any import pandas as pd from pandas.api.types import CategoricalDtype @@ -37,7 +36,7 @@ @dataclass -class BaseObservation(ABC): +class Observation(ABC): """An abstract base dataclass to be inherited by concrete observations. This class includes an :meth:`observe ` method that determines whether @@ -60,7 +59,8 @@ class BaseObservation(ABC): DataFrame or one with a complete set of stratifications as the index and all values set to 0.0.""" results_gatherer: Callable[ - [pd.DataFrame | DataFrameGroupBy[str], tuple[str, ...] | None], pd.DataFrame + [pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str], tuple[str, ...] | None], + pd.DataFrame, ] """Method or function that gathers the new observation results.""" results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame] @@ -76,7 +76,7 @@ class BaseObservation(ABC): def observe( self, event: Event, - df: pd.DataFrame | DataFrameGroupBy[str], + df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] 
| str], stratifications: tuple[str, ...] | None, ) -> pd.DataFrame | None: """Determine whether to observe the given event, and if so, gather the results. @@ -100,7 +100,7 @@ def observe( return self.results_gatherer(df, stratifications) -class UnstratifiedObservation(BaseObservation): +class UnstratifiedObservation(Observation): """Concrete class for observing results that are not stratified. The parent class `stratifications` are set to None and the `results_initializer` @@ -139,7 +139,8 @@ def __init__( to_observe: Callable[[Event], bool] = lambda event: True, ): def _wrap_results_gatherer( - df: pd.DataFrame | DataFrameGroupBy[str], _: tuple[str, ...] | None + df: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str], + _: tuple[str, ...] | None, ) -> pd.DataFrame: if isinstance(df, DataFrameGroupBy): raise TypeError( @@ -181,7 +182,7 @@ def create_empty_df( return pd.DataFrame() -class StratifiedObservation(BaseObservation): +class StratifiedObservation(Observation): """Concrete class for observing stratified results. The parent class `results_initializer` and `results_gatherer` methods are diff --git a/src/vivarium/framework/results/observer.py b/src/vivarium/framework/results/observer.py index 7a37a5db..a04c71b2 100644 --- a/src/vivarium/framework/results/observer.py +++ b/src/vivarium/framework/results/observer.py @@ -5,7 +5,7 @@ ========= An observer is a component that is responsible for registering -:class:`observations <vivarium.framework.results.observation.BaseObservation>` +:class:`observations <vivarium.framework.results.observation.Observation>` to the simulation.
The provided :class:`Observer` class is an abstract base class that should be subclassed diff --git a/src/vivarium/framework/results/stratification.py b/src/vivarium/framework/results/stratification.py index 7be52813..e0d4d1f7 100644 --- a/src/vivarium/framework/results/stratification.py +++ b/src/vivarium/framework/results/stratification.py @@ -4,19 +4,20 @@ =============== """ + from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable +from typing import Any import pandas as pd from pandas.api.types import CategoricalDtype +from vivarium.types import ScalarMapper, VectorMapper + STRATIFICATION_COLUMN_SUFFIX: str = "mapped_values" # TODO: Parameterizing pandas objects fails below python 3.12 -VectorMapper = Callable[[pd.DataFrame], pd.Series] # type: ignore [type-arg] -ScalarMapper = Callable[[pd.Series], str] # type: ignore [type-arg] @dataclass diff --git a/src/vivarium/types.py b/src/vivarium/types.py index 89da4389..0de4ea49 100644 --- a/src/vivarium/types.py +++ b/src/vivarium/types.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from datetime import datetime, timedelta from numbers import Number from typing import Union @@ -25,3 +26,6 @@ float, int, ] + +VectorMapper = Callable[[pd.DataFrame], pd.Series] # type: ignore [type-arg] +ScalarMapper = Callable[[pd.Series], str] # type: ignore [type-arg]