diff --git a/src/vivarium/framework/time.py b/src/vivarium/framework/time.py index 2aacab5e..4a239af5 100644 --- a/src/vivarium/framework/time.py +++ b/src/vivarium/framework/time.py @@ -120,7 +120,9 @@ def setup(self, builder: "Builder") -> None: preferred_post_processor=self.step_size_post_processor, ) self.register_step_modifier = partial( - builder.value.register_value_modifier, self._pipeline_name + builder.value.register_value_modifier, + self._pipeline_name, + component=self, ) builder.population.initializes_simulants(self, creates_columns=self.columns_created) builder.event.register_listener("post_setup", self.on_post_setup) @@ -352,5 +354,8 @@ def register_step_size_modifier( A list of the randomness streams that need to be properly sourced before the modifier is called.""" return self._manager.register_step_modifier( - modifier, requires_columns, requires_values, requires_streams + modifier=modifier, + requires_columns=requires_columns, + requires_values=requires_values, + requires_streams=requires_streams, ) diff --git a/src/vivarium/framework/values/manager.py b/src/vivarium/framework/values/manager.py index cb4bf3f1..9ad65a3a 100644 --- a/src/vivarium/framework/values/manager.py +++ b/src/vivarium/framework/values/manager.py @@ -109,7 +109,7 @@ def register_value_modifier( value_name: str, modifier: Callable[..., Any], # TODO [MIC-5452]: all calls should have a component - component: Component | None = None, + component: Component | Manager | None = None, requires_columns: Iterable[str] = (), requires_values: Iterable[str] = (), requires_streams: Iterable[str] = (), @@ -371,7 +371,7 @@ def register_value_modifier( value_name: str, modifier: Callable[..., Any], # TODO [MIC-5452]: all calls should have a component - component: Component | None = None, + component: Component | Manager | None = None, requires_columns: Iterable[str] = (), requires_values: Iterable[str] = (), requires_streams: Iterable[str] = (), diff --git a/src/vivarium/framework/values/pipeline.py b/src/vivarium/framework/values/pipeline.py index 6cc81a65..b7afe8eb 100644 --- a/src/vivarium/framework/values/pipeline.py +++ b/src/vivarium/framework/values/pipeline.py @@ -8,6 +8,7 @@ from vivarium import Component from vivarium.framework.resource import Resource from vivarium.framework.values.exceptions import DynamicValueError +from vivarium.manager import Manager if TYPE_CHECKING: from vivarium.framework.values.combiners import ValueCombiner @@ -52,7 +53,7 @@ def __init__( self, pipeline: Pipeline, modifier: Callable[..., Any], - component: Component | None, + component: Component | Manager | None, ) -> None: mutator_name = self._get_modifier_name(modifier) mutator_index = len(pipeline.mutators) + 1 @@ -190,7 +191,7 @@ def __hash__(self) -> int: return hash(self.name) def get_value_modifier( - self, modifier: Callable[..., Any], component: Component | None + self, modifier: Callable[..., Any], component: Component | Manager | None ) -> ValueModifier: """Add a value modifier to the pipeline and return it.