Skip to content

Commit

Permalink
Update typing to include manager
Browse files Browse the repository at this point in the history
  • Loading branch information
albrja committed Nov 13, 2024
1 parent 8acdaed commit e20c542
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
9 changes: 7 additions & 2 deletions src/vivarium/framework/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ def setup(self, builder: "Builder") -> None:
preferred_post_processor=self.step_size_post_processor,
)
self.register_step_modifier = partial(
builder.value.register_value_modifier, self._pipeline_name
builder.value.register_value_modifier,
self._pipeline_name,
component=self,
)
builder.population.initializes_simulants(self, creates_columns=self.columns_created)
builder.event.register_listener("post_setup", self.on_post_setup)
Expand Down Expand Up @@ -352,5 +354,8 @@ def register_step_size_modifier(
A list of the randomness streams that need to be properly sourced
before the modifier is called."""
return self._manager.register_step_modifier(
modifier, requires_columns, requires_values, requires_streams
modifier=modifier,
requires_columns=requires_columns,
requires_values=requires_values,
requires_streams=requires_streams,
)
4 changes: 2 additions & 2 deletions src/vivarium/framework/values/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def register_value_modifier(
value_name: str,
modifier: Callable[..., Any],
# TODO [MIC-5452]: all calls should have a component
component: Component | None = None,
component: Component | Manager | None = None,
requires_columns: Iterable[str] = (),
requires_values: Iterable[str] = (),
requires_streams: Iterable[str] = (),
Expand Down Expand Up @@ -371,7 +371,7 @@ def register_value_modifier(
value_name: str,
modifier: Callable[..., Any],
# TODO [MIC-5452]: all calls should have a component
component: Component | None = None,
component: Component | Manager | None = None,
requires_columns: Iterable[str] = (),
requires_values: Iterable[str] = (),
requires_streams: Iterable[str] = (),
Expand Down
5 changes: 3 additions & 2 deletions src/vivarium/framework/values/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from vivarium import Component
from vivarium.framework.resource import Resource
from vivarium.framework.values.exceptions import DynamicValueError
from vivarium.manager import Manager

if TYPE_CHECKING:
from vivarium.framework.values.combiners import ValueCombiner
Expand Down Expand Up @@ -52,7 +53,7 @@ def __init__(
self,
pipeline: Pipeline,
modifier: Callable[..., Any],
component: Component | None,
component: Component | Manager | None,
) -> None:
mutator_name = self._get_modifier_name(modifier)
mutator_index = len(pipeline.mutators) + 1
Expand Down Expand Up @@ -190,7 +191,7 @@ def __hash__(self) -> int:
return hash(self.name)

def get_value_modifier(
self, modifier: Callable[..., Any], component: Component | None
self, modifier: Callable[..., Any], component: Component | Manager | None
) -> ValueModifier:
"""Add a value modifier to the pipeline and return it.
Expand Down

0 comments on commit e20c542

Please sign in to comment.