Skip to content

Commit

Permalink
Update errors and ignore typing
Browse files Browse the repository at this point in the history
  • Loading branch information
albrja committed Nov 13, 2024
1 parent 4f8ed49 commit d286246
Showing 1 changed file with 54 additions and 44 deletions.
98 changes: 54 additions & 44 deletions src/vivarium/framework/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
"""

from __future__ import annotations

import math
from collections.abc import Callable
from functools import partial
from typing import TYPE_CHECKING, Callable, List
from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd
Expand All @@ -25,6 +28,7 @@
from vivarium.framework.population.population_view import PopulationView
from vivarium.framework.event import Event
from vivarium.framework.population import SimulantData
from vivarium.framework.values import PostProcessor, ValuesManager

from vivarium.framework.values import list_combiner
from vivarium.manager import Interface, Manager
Expand All @@ -34,15 +38,15 @@ class SimulationClock(Manager):
"""A base clock that includes global clock and a pandas series of clocks for each simulant"""

@property
def name(self):
def name(self) -> str:
return "simulation_clock"

@property
def columns_created(self) -> List[str]:
def columns_created(self) -> list[str]:
return ["next_event_time", "step_size"]

@property
def columns_required(self) -> List[str]:
def columns_required(self) -> list[str]:
return ["tracked"]

@property
Expand Down Expand Up @@ -87,24 +91,28 @@ def step_size(self) -> ClockStepSize:
@property
def event_time(self) -> ClockTime:
"Convenience method for event time, or clock + step"
return self.time + self.step_size
return self.time + self.step_size # type: ignore [operator]

@property
def time_steps_remaining(self) -> int:
return math.ceil((self.stop_time - self.time) / self.step_size)

def __init__(self):
self._clock_time: ClockTime = None
self._stop_time: ClockTime = None
self._minimum_step_size: ClockStepSize = None
self._standard_step_size: ClockStepSize = None
self._clock_step_size: ClockStepSize = None
self._individual_clocks: PopulationView = None
# return math.ceil((self.stop_time - self.time) / self.step_size) # type: ignore [operator]
number_steps_remaining = (self.stop_time - self.time) / self.step_size # type: ignore [operator]
if not isinstance(number_steps_remaining, (float, int)):
raise ValueError("Invalid type for number of steps remaining")
return math.ceil(number_steps_remaining)

def __init__(self) -> None:
self._clock_time: ClockTime | None = None
self._stop_time: ClockTime | None = None
self._minimum_step_size: ClockStepSize | None = None
self._standard_step_size: ClockStepSize | None = None
self._clock_step_size: ClockStepSize | None = None
self._individual_clocks: PopulationView | None = None
self._pipeline_name = "simulant_step_size"
# TODO: Delegate this functionality to "tracked" or similar when appropriate
self._simulants_to_snooze = pd.Index([])

def setup(self, builder: "Builder"):
def setup(self, builder: "Builder") -> None:
self._step_size_pipeline = builder.value.register_value_producer(
self._pipeline_name,
source=lambda idx: [pd.Series(np.nan, index=idx).astype("timedelta64[ns]")],
Expand Down Expand Up @@ -140,15 +148,15 @@ def on_initialize_simulants(self, pop_data: "SimulantData") -> None:
)
self._individual_clocks.update(clocks_to_initialize)

def simulant_next_event_times(self, index: pd.Index) -> pd.Series:
def simulant_next_event_times(self, index: pd.Index[int]) -> pd.Series[ClockTime]:
"""The next time each simulant will be updated."""
if not self._individual_clocks:
return pd.Series(self.event_time, index=index)
return self._individual_clocks.subview(["next_event_time", "tracked"]).get(index)[
"next_event_time"
]

def simulant_step_sizes(self, index: pd.Index) -> pd.Series:
def simulant_step_sizes(self, index: pd.Index[int]) -> pd.Series[ClockStepSize]:
"""The step size for each simulant."""
if not self._individual_clocks:
return pd.Series(self.step_size, index=index)
Expand All @@ -158,41 +166,43 @@ def simulant_step_sizes(self, index: pd.Index) -> pd.Series:

def step_backward(self) -> None:
"""Rewinds the clock by the current step size."""
self._clock_time -= self.step_size
if self._clock_time is None:
raise ValueError("No start time provided")
self._clock_time -= self.step_size # type: ignore [operator]

def step_forward(self, index: pd.Index) -> None:
def step_forward(self, index: pd.Index[int]) -> None:
"""Advances the clock by the current step size, and updates aligned simulant clocks."""
self._clock_time += self.step_size
if self._individual_clocks and index.any():
self._clock_time += self.step_size # type: ignore [assignment, operator]
if self._individual_clocks and not index.empty:
update_index = self.get_active_simulants(index, self.time)
clocks_to_update = self._individual_clocks.get(update_index)
if not clocks_to_update.empty:
clocks_to_update["step_size"] = self._step_size_pipeline(update_index)
# Simulants that were flagged to get moved to the end should have a next event time
# of stop time + 1 minimum timestep
clocks_to_update.loc[self._simulants_to_snooze, "step_size"] = (
self.stop_time + self.minimum_step_size - self.time
self.stop_time + self.minimum_step_size - self.time # type: ignore [operator]
)
# TODO: Delegate this functionality to "tracked" or similar when appropriate
self._simulants_to_snooze = pd.Index([])
clocks_to_update["next_event_time"] = (
self.time + clocks_to_update["step_size"]
)
self._individual_clocks.update(clocks_to_update)
self._clock_step_size = self.simulant_next_event_times(index).min() - self.time
self._clock_step_size = self.simulant_next_event_times(index).min() - self.time # type: ignore [operator]

def get_active_simulants(self, index: pd.Index, time: ClockTime) -> pd.Index:
def get_active_simulants(self, index: pd.Index[int], time: ClockTime) -> pd.Index[int]:
"""Gets population that is aligned with global clock"""
if index.empty or not self._individual_clocks:
return index
next_event_times = self.simulant_next_event_times(index)
return next_event_times[next_event_times <= time].index

def move_simulants_to_end(self, index: pd.Index) -> None:
if self._individual_clocks and index.any():
def move_simulants_to_end(self, index: pd.Index[int]) -> None:
if self._individual_clocks and not index.empty:
self._simulants_to_snooze = self._simulants_to_snooze.union(index)

def step_size_post_processor(self, values: List[NumberLike], _) -> pd.Series:
def step_size_post_processor(self, value: Any, manager: ValuesManager) -> Any:
"""Computes the largest feasible step size for each simulant.
This is the smallest component-modified step size (rounded down to increments
Expand All @@ -209,10 +219,10 @@ def step_size_post_processor(self, values: List[NumberLike], _) -> pd.Series:
The largest feasible step size for each simulant
"""

min_modified = pd.DataFrame(values).min(axis=0).fillna(self.standard_step_size)
min_modified = pd.DataFrame(value).min(axis=0).fillna(self.standard_step_size)
## Rescale pipeline values to global minimum step size
discretized_step_sizes = (
np.floor(min_modified / self.minimum_step_size).replace(0, 1)
np.floor(min_modified / self.minimum_step_size).replace(0, 1) # type: ignore [attr-defined, operator]
* self.minimum_step_size
)
## Make sure we don't get zero
Expand All @@ -232,10 +242,10 @@ class SimpleClock(SimulationClock):
}

@property
def name(self):
def name(self) -> str:
return "simple_clock"

def setup(self, builder):
def setup(self, builder: Builder) -> None:
super().setup(builder)
time = builder.configuration.time
self._clock_time = time.start
Expand All @@ -246,11 +256,11 @@ def setup(self, builder):
)
self._clock_step_size = self._standard_step_size

def __repr__(self):
def __repr__(self) -> str:
return "SimpleClock()"


def get_time_stamp(time):
def get_time_stamp(time: dict[str, int]) -> pd.Timestamp:
return pd.Timestamp(time["year"], time["month"], time["day"])


Expand All @@ -271,10 +281,10 @@ class DateTimeClock(SimulationClock):
}

@property
def name(self):
def name(self) -> str:
return "datetime_clock"

def setup(self, builder):
def setup(self, builder: Builder) -> None:
super().setup(builder)
time = builder.configuration.time
self._clock_time = get_time_stamp(time.start)
Expand All @@ -291,12 +301,12 @@ def setup(self, builder):
)
self._clock_step_size = self._minimum_step_size

def __repr__(self):
def __repr__(self) -> str:
return "DateTimeClock()"


class TimeInterface(Interface):
def __init__(self, manager: SimulationClock):
def __init__(self, manager: SimulationClock) -> None:
self._manager = manager

def clock(self) -> Callable[[], ClockTime]:
Expand All @@ -307,24 +317,24 @@ def step_size(self) -> Callable[[], ClockStepSize]:
"""Gets a callable that returns the current simulation step size."""
return lambda: self._manager.step_size

def simulant_next_event_times(self) -> Callable[[pd.Index], pd.Series]:
def simulant_next_event_times(self) -> Callable[[pd.Index[int]], pd.Series[ClockTime]]:
"""Gets a callable that returns the next event times for simulants."""
return self._manager.simulant_next_event_times

def simulant_step_sizes(self) -> Callable[[pd.Index], pd.Series]:
def simulant_step_sizes(self) -> Callable[[pd.Index[int]], pd.Series[ClockStepSize]]:
"""Gets a callable that returns the simulant step sizes."""
return self._manager.simulant_step_sizes

def move_simulants_to_end(self) -> Callable[[pd.Index], None]:
def move_simulants_to_end(self) -> Callable[[pd.Index[int]], None]:
"""Gets a callable that moves simulants to the end of the simulation"""
return self._manager.move_simulants_to_end

def register_step_size_modifier(
self,
modifier: Callable[[pd.Index], pd.Series],
requires_columns: List[str] = (),
requires_values: List[str] = (),
requires_streams: List[str] = (),
modifier: Callable[[pd.Index[int]], pd.Series[ClockStepSize]],
requires_columns: list[str] = [],
requires_values: list[str] = [],
requires_streams: list[str] = [],
) -> None:
"""Registers a step size modifier.
Expand Down

0 comments on commit d286246

Please sign in to comment.