Skip to content

Commit

Permalink
allow users to define initialization weights as LookupTableData or an…
Browse files Browse the repository at this point in the history
… artifact key
  • Loading branch information
rmudambi committed Nov 14, 2024
1 parent 05827ba commit b2f3bf3
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 18 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**3.2.3 - 11/15/24**

- Feature: Allow users to define initialization weights as LookupTableData or an artifact key

**3.2.2 - 11/14/24**

- Feature: Enable adding transition to a state by defining the output state and the transition probability
Expand Down
25 changes: 21 additions & 4 deletions src/vivarium/framework/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,10 @@ def __init__(
self,
state_id: str,
allow_self_transition: bool = False,
initialization_weights: Callable[[Builder], LookupTableData] | None = None,
initialization_weights: LookupTableData
| str
| Callable[[Builder], LookupTableData]
| None = None,
) -> None:
super().__init__()
self.state_id = state_id
Expand All @@ -236,6 +239,20 @@ def __init__(
# Public methods #
##################

def is_initial_state(self) -> bool:
"""Determines if simulants could be initialized into this state.
Note: this will incorrectly return True if initialization_weights is a
callable or an artifact key that will always return 0.
"""
if self.initialization_weights is None:
return False
if isinstance(self.initialization_weights, pd.DataFrame):
return not (self.initialization_weights == 0.0).all().all()
if isinstance(self.initialization_weights, pd.Series):
return not (self.initialization_weights == 0.0).all()
return bool(self.initialization_weights)

def set_model(self, model_name: str) -> None:
"""Defines the column name for the model this state belongs to"""
self._model = model_name
Expand Down Expand Up @@ -323,8 +340,8 @@ def allow_self_transitions(self) -> None:
##################

def get_initialization_weights(self, builder: Builder) -> LookupTableData:
if self.initialization_weights:
return self.initialization_weights(builder)
if self.is_initial_state():
return self.get_data(builder, self.initialization_weights)
else:
return 0.0

Expand Down Expand Up @@ -548,7 +565,7 @@ def __init__(
self.add_states(states)

states_with_initialization_weights = [
state for state in self.states if state.initialization_weights
state for state in self.states if state.is_initial_state()
]

if initial_state is not None:
Expand Down
52 changes: 38 additions & 14 deletions tests/framework/test_state_machine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from collections.abc import Callable

import numpy as np
import pandas as pd
import pytest
Expand All @@ -8,10 +10,11 @@
from tests.helpers import ColumnCreator
from vivarium import InteractiveContext
from vivarium.framework.configuration import build_simulation_configuration
from vivarium.framework.engine import Builder
from vivarium.framework.population import SimulantData
from vivarium.framework.resource import Resource
from vivarium.framework.state_machine import Machine, State, Transition
from vivarium.types import ClockTime
from vivarium.types import ClockTime, LookupTableData


def test_initialize_allowing_self_transition() -> None:
Expand All @@ -32,27 +35,45 @@ def test_initialize_with_initial_state() -> None:
assert simulation.get_population()["state"].unique() == ["start"]


@pytest.mark.parametrize("weights_type", ["artifact", "callable", "scalar"])
def test_initialize_with_scalar_initialization_weights(
base_config: LayeredConfigTree,
base_config: LayeredConfigTree, weights_type: str
) -> None:
state_weights = {"state_a.weights": 0.2, "state_b.weights": 0.8}

def mock_load(key: str) -> float:
return state_weights.get(key)

base_config.update(
{"population": {"population_size": 10000}, "randomness": {"key_columns": []}}
)
state_a = State("a", initialization_weights=lambda _: 0.2)
state_b = State("b", initialization_weights=lambda _: 0.8)

def initialization_weights(
key: str,
) -> LookupTableData | str | Callable[[Builder], LookupTableData]:
return {
"artifact": key,
"callable": lambda _: state_weights[key],
"scalar": state_weights[key],
}[weights_type]

state_a = State("a", initialization_weights=initialization_weights("state_a.weights"))
state_b = State("b", initialization_weights=initialization_weights("state_b.weights"))
machine = Machine("state", states=[state_a, state_b])
simulation = InteractiveContext(components=[machine], configuration=base_config)
simulation = InteractiveContext(
components=[machine], configuration=base_config, setup=False
)
simulation._builder.data.load = mock_load
simulation.setup()

state = simulation.get_population()["state"]
assert np.all(simulation.get_population().state != "start")
assert round((state == "a").mean(), 1) == 0.2
assert round((state == "b").mean(), 1) == 0.8


@pytest.mark.parametrize(
"use_artifact", [True, False], ids=["with_artifact", "without_artifact"]
)
def test_initialize_with_array_initialization_weights(use_artifact) -> None:
@pytest.mark.parametrize("weights_type", ["artifact", "callable", "dataframe"])
def test_initialize_with_array_initialization_weights(weights_type: str) -> None:
state_weights = {
"state_a.weights": pd.DataFrame(
{"test_column_1": [0, 1, 2], "value": [0.2, 0.7, 0.4]}
Expand All @@ -78,11 +99,14 @@ def initialization_requirements(self) -> list[str | Resource]:
# specified by the states or the configuration.
return ["test_column_1"]

def initialization_weights(key: str):
if use_artifact:
return lambda builder: builder.data.load(key)
else:
return lambda _: state_weights[key]
def initialization_weights(
key: str,
) -> LookupTableData | str | Callable[[Builder], LookupTableData]:
return {
"artifact": key,
"callable": lambda _: state_weights[key],
"dataframe": state_weights[key],
}[weights_type]

state_a = State("a", initialization_weights=initialization_weights("state_a.weights"))
state_b = State("b", initialization_weights=initialization_weights("state_b.weights"))
Expand Down

0 comments on commit b2f3bf3

Please sign in to comment.