Skip to content

Commit

Permalink
Add backup writing to vivarium simulation Context (#455)
Browse files Browse the repository at this point in the history
* add feature and test

* remove unused import

* add skipped test

* lint

* lint

* add test and adjust sim.run()

* lint

* change protocol

* add logging statement

* Update test_engine.py

add ticket

* add self.current_time

* Revert "add self.current_time"

This reverts commit 322fd19.

* use current time but do it right instead of wrong
  • Loading branch information
patricktnast authored Aug 12, 2024
1 parent f17cc78 commit 0e60bf7
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 5 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"networkx",
"loguru",
"pyarrow",
"dill",
# Type stubs
"pandas-stubs",
]
Expand Down
34 changes: 29 additions & 5 deletions src/vivarium/framework/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@

from pathlib import Path
from pprint import pformat
from time import time
from typing import Any, Dict, List, Optional, Set, Union

import dill
import numpy as np
import pandas as pd
import yaml
Expand All @@ -42,7 +44,7 @@
from vivarium.framework.randomness import RandomnessInterface
from vivarium.framework.resource import ResourceInterface
from vivarium.framework.results import ResultsInterface
from vivarium.framework.time import TimeInterface
from vivarium.framework.time import Time, TimeInterface
from vivarium.framework.values import ValuesInterface


Expand Down Expand Up @@ -206,6 +208,11 @@ def __init__(
def name(self) -> str:
return self._name

@property
def current_time(self) -> Time:
    """The time currently reported by this simulation's clock."""
    clock = self._clock
    return clock.time

def get_results(self) -> Dict[str, pd.DataFrame]:
"""Return the formatted results."""
return self._results.get_results()
Expand Down Expand Up @@ -246,7 +253,7 @@ def initialize_simulants(self) -> None:
self._clock.step_forward(self.get_population().index)

def step(self) -> None:
self._logger.debug(self._clock.time)
self._logger.debug(self.current_time)
for event in self.time_step_events:
self._logger.debug(f"Event: {event}")
self._lifecycle.set_state(event)
Expand All @@ -258,9 +265,22 @@ def step(self) -> None:
self.time_step_emitters[event](pop_to_update)
self._clock.step_forward(self.get_population().index)

def run(self) -> None:
while self._clock.time < self._clock.stop_time:
self.step()
def run(
    self,
    backup_path: Optional[Path] = None,
    backup_freq: Optional[Union[int, float]] = None,
) -> None:
    """Step the simulation until the clock reaches its stop time.

    Parameters
    ----------
    backup_path
        File the simulation backup is written to. Required whenever
        ``backup_freq`` is set.
    backup_freq
        Minimum number of wall-clock seconds between backups. When it is
        ``None`` (or zero) no backups are written.

    Raises
    ------
    ValueError
        If ``backup_freq`` is set but no ``backup_path`` was provided.
    """
    # Fail before doing any work instead of crashing inside ``open(None)``
    # when the first backup comes due part-way through the run.
    if backup_freq and backup_path is None:
        raise ValueError("A backup_path is required when backup_freq is set.")

    # ``None`` means backups are disabled; the wall clock is only consulted
    # when they are enabled.
    time_to_save = time() + backup_freq if backup_freq else None

    while self.current_time < self._clock.stop_time:
        self.step()
        if time_to_save is not None and time() >= time_to_save:
            self._logger.debug(f"Writing Simulation Backup to {backup_path}")
            self.write_backup(backup_path)
            # Schedule the next backup relative to when this one finished.
            time_to_save = time() + backup_freq

def finalize(self) -> None:
self._lifecycle.set_state("simulation_end")
Expand Down Expand Up @@ -296,6 +316,10 @@ def _write_results(self, results: dict[str, pd.DataFrame]) -> None:
except ConfigurationKeyError:
self._logger.info("No results directory set; results are not written to disk.")

def write_backup(self, backup_path: Path) -> None:
    """Serialize this simulation to ``backup_path`` with dill.

    The backup is first written to a sibling ``<name>.tmp`` file and then
    renamed over ``backup_path``, so an interruption during the write does
    not leave a truncated file in place of an earlier backup.

    Parameters
    ----------
    backup_path
        Destination file for the serialized simulation.
    """
    backup_path = Path(backup_path)
    tmp_path = backup_path.with_name(backup_path.name + ".tmp")
    with open(tmp_path, "wb") as f:
        dill.dump(self, f, protocol=dill.HIGHEST_PROTOCOL)
    # Rename only after the dump has completed and the file is closed.
    tmp_path.replace(backup_path)

def get_performance_metrics(self) -> pd.DataFrame:
timing_dict = self._lifecycle.timings
total_time = np.sum([np.sum(v) for v in timing_dict.values()])
Expand Down
41 changes: 41 additions & 0 deletions tests/framework/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import math
from itertools import product
from pathlib import Path
from time import time
from typing import Dict, List

import dill
import pandas as pd
import pytest

Expand Down Expand Up @@ -347,6 +349,45 @@ def test_SimulationContext_report_write(SimulationContext, base_config, componen
assert results.equals(written_results)


def test_SimulationContext_write_backup(mocker, SimulationContext, tmpdir):
    """Check that write_backup creates the backup file and it can be loaded."""
    # TODO MIC-5216: Remove mocks when we can use dill in pytest.
    mocker.patch("vivarium.framework.engine.dill.dump")
    mocker.patch("vivarium.framework.engine.dill.load", return_value=SimulationContext())

    backup_path = tmpdir / "backup.pkl"
    simulation = SimulationContext()
    simulation.write_backup(backup_path)
    assert backup_path.exists()

    with open(backup_path, "rb") as backup_file:
        restored = dill.load(backup_file)
    assert isinstance(restored, SimulationContext)


def test_SimulationContext_run_with_backup(mocker, SimulationContext, base_config, tmpdir):
    """Check that run() writes a backup after every step when each clock read is 5s later."""
    mocker.patch("vivarium.framework.engine.SimulationContext.write_backup")
    original_time = time()

    def fake_clock():
        # Every read of the patched clock is 5 seconds after the previous one.
        tick = original_time
        while True:
            yield tick
            tick += 5

    mocker.patch("vivarium.framework.engine.time", side_effect=fake_clock())

    backup_path = tmpdir / "backup.pkl"
    sim = SimulationContext(
        base_config,
        [
            Hogwarts(),
            HousePointsObserver(),
            NoStratificationsQuidditchWinsObserver(),
            QuidditchWinsObserver(),
            HogwartsResultsStratifier(),
        ],
        configuration=HARRY_POTTER_CONFIG,
    )
    sim.setup()
    sim.initialize_simulants()
    sim.run(backup_path=backup_path, backup_freq=5)

    expected_backups = _get_num_steps(sim)
    assert sim.write_backup.call_count == expected_backups


def test_get_results_formatting(SimulationContext, base_config):
"""Test formatted results are as expected"""
components = [
Expand Down

0 comments on commit 0e60bf7

Please sign in to comment.