Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add backup writing to vivarium simulation Context #455

Merged
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"networkx",
"loguru",
"pyarrow",
"dill",
# Type stubs
"pandas-stubs",
]
Expand Down
34 changes: 29 additions & 5 deletions src/vivarium/framework/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@

from pathlib import Path
from pprint import pformat
from time import time
from typing import Any, Dict, List, Optional, Set, Union

import dill
import numpy as np
import pandas as pd
import yaml
Expand All @@ -42,7 +44,7 @@
from vivarium.framework.randomness import RandomnessInterface
from vivarium.framework.resource import ResourceInterface
from vivarium.framework.results import ResultsInterface
from vivarium.framework.time import TimeInterface
from vivarium.framework.time import Time, TimeInterface
from vivarium.framework.values import ValuesInterface


Expand Down Expand Up @@ -206,6 +208,11 @@ def __init__(
def name(self) -> str:
return self._name

@property
def current_time(self) -> Time:
patricktnast marked this conversation as resolved.
Show resolved Hide resolved
"""Return the current simulation time, read from the simulation clock (``self._clock.time``)."""
return self._clock.time

def get_results(self) -> Dict[str, pd.DataFrame]:
    """Fetch the formatted results from the results manager and hand them back."""
    formatted_results = self._results.get_results()
    return formatted_results
Expand Down Expand Up @@ -246,7 +253,7 @@ def initialize_simulants(self) -> None:
self._clock.step_forward(self.get_population().index)

def step(self) -> None:
self._logger.debug(self._clock.time)
self._logger.debug(self.current_time)
for event in self.time_step_events:
self._logger.debug(f"Event: {event}")
self._lifecycle.set_state(event)
Expand All @@ -258,9 +265,22 @@ def step(self) -> None:
self.time_step_emitters[event](pop_to_update)
self._clock.step_forward(self.get_population().index)

def run(self) -> None:
while self._clock.time < self._clock.stop_time:
self.step()
def run(
    self,
    backup_path: Optional[Path] = None,
    backup_freq: Optional[Union[int, float]] = None,
) -> None:
    """Step the simulation until the clock reaches its stop time.

    Optionally writes a backup of the simulation whenever at least
    ``backup_freq`` seconds of wall-clock time have elapsed since the
    previous backup (or since the run started).

    Parameters
    ----------
    backup_path
        File to write backups to. Required when ``backup_freq`` is set.
    backup_freq
        Minimum wall-clock interval, in seconds, between backups. A falsy
        value (``None`` or ``0``) disables backups entirely.

    Raises
    ------
    ValueError
        If ``backup_freq`` is set but ``backup_path`` is not. Raised before
        any step is taken, rather than failing inside ``write_backup`` after
        the first backup interval has already been simulated.
    """
    backups_enabled = bool(backup_freq)
    if backups_enabled and backup_path is None:
        raise ValueError("A backup_path must be provided when backup_freq is set.")

    # Only consult the wall clock when backups are enabled.
    time_to_save = time() + backup_freq if backups_enabled else None
    while self.current_time < self._clock.stop_time:
        self.step()
        # The deadline is checked only between steps, so a backup is never
        # written in the middle of a time step.
        if backups_enabled and time() >= time_to_save:
            self._logger.debug(f"Writing Simulation Backup to {backup_path}")
            self.write_backup(backup_path)
            time_to_save = time() + backup_freq

def finalize(self) -> None:
self._lifecycle.set_state("simulation_end")
Expand Down Expand Up @@ -296,6 +316,10 @@ def _write_results(self, results: dict[str, pd.DataFrame]) -> None:
except ConfigurationKeyError:
self._logger.info("No results directory set; results are not written to disk.")

def write_backup(self, backup_path: Path) -> None:
    """Serialize this simulation to ``backup_path`` with dill.

    The dump is written to a sibling temporary file and then renamed over
    ``backup_path``. If serialization fails part-way, the previous backup at
    ``backup_path`` is left untouched instead of being truncated.

    Parameters
    ----------
    backup_path
        Destination file for the backup. Any path-like object is accepted.
    """
    backup_path = Path(backup_path)
    tmp_path = backup_path.with_name(backup_path.name + ".tmp")
    try:
        with open(tmp_path, "wb") as f:
            dill.dump(self, f, protocol=dill.HIGHEST_PROTOCOL)
        # Rename only after the dump has completed successfully.
        tmp_path.replace(backup_path)
    finally:
        # After a successful rename the temp file is already gone; on failure
        # this removes the partial dump.
        tmp_path.unlink(missing_ok=True)

def get_performance_metrics(self) -> pd.DataFrame:
timing_dict = self._lifecycle.timings
total_time = np.sum([np.sum(v) for v in timing_dict.values()])
Expand Down
41 changes: 41 additions & 0 deletions tests/framework/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import math
from itertools import product
from pathlib import Path
from time import time
from typing import Dict, List

import dill
import pandas as pd
import pytest

Expand Down Expand Up @@ -347,6 +349,45 @@ def test_SimulationContext_report_write(SimulationContext, base_config, componen
assert results.equals(written_results)


def test_SimulationContext_write_backup(mocker, SimulationContext, tmpdir):
    """Writing a backup creates the target file, and it can be loaded back."""
    # TODO MIC-5216: Remove mocks when we can use dill in pytest.
    mocker.patch("vivarium.framework.engine.dill.dump")
    mocker.patch("vivarium.framework.engine.dill.load", return_value=SimulationContext())

    simulation = SimulationContext()
    backup_file = tmpdir / "backup.pkl"
    simulation.write_backup(backup_file)
    assert backup_file.exists()

    with open(backup_file, "rb") as handle:
        restored = dill.load(handle)
    assert isinstance(restored, SimulationContext)

patricktnast marked this conversation as resolved.
Show resolved Hide resolved

def test_SimulationContext_run_with_backup(mocker, SimulationContext, base_config, tmpdir):
    """With a fake clock advancing 5s per reading, a backup is written after every step."""
    mocker.patch("vivarium.framework.engine.SimulationContext.write_backup")
    start = time()

    def fake_clock():
        # Each reading of the patched clock is 5 seconds later than the last.
        tick = start
        while True:
            yield tick
            tick += 5

    mocker.patch("vivarium.framework.engine.time", side_effect=fake_clock())

    sim = SimulationContext(
        base_config,
        [
            Hogwarts(),
            HousePointsObserver(),
            NoStratificationsQuidditchWinsObserver(),
            QuidditchWinsObserver(),
            HogwartsResultsStratifier(),
        ],
        configuration=HARRY_POTTER_CONFIG,
    )
    backup_file = tmpdir / "backup.pkl"
    sim.setup()
    sim.initialize_simulants()
    sim.run(backup_path=backup_file, backup_freq=5)

    expected_backups = _get_num_steps(sim)
    assert sim.write_backup.call_count == expected_backups


def test_get_results_formatting(SimulationContext, base_config):
"""Test formatted results are as expected"""
components = [
Expand Down
Loading