Skip to content

Commit

Permalink
Add secondary custom callback for logging step wise information (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
danielwolff1 authored Jan 24, 2025
2 parents 2b56da7 + d75b94b commit 7dd63b9
Show file tree
Hide file tree
Showing 14 changed files with 185 additions and 17 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ repos:
- id: detect-private-key
- id: end-of-file-fixer
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.2
hooks:
Expand Down
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ sphinx is installed ignore the first two lines.
The following command line calls create a conda environment with all necessary
dependencies for building the documentation.
``` console
(base) $ conda create -n sphinx python=3.9
(base) $ conda create -n sphinx python=3.11
(base) $ conda activate sphinx
(sphinx) $ pip install ".[docs]"
```
Expand Down Expand Up @@ -72,7 +72,7 @@ The packages can be installed via pip or conda with the following commands:
**conda**

``` console
(base) $ conda create -n releso python=3.9 "pydantic<2" tensorboard
(base) $ conda create -n releso python=3.11 "pydantic<2" tensorboard
(base) $ conda activate releso
(releso) $ pip install stable-baselines3 hjson
```
Expand Down Expand Up @@ -107,5 +107,7 @@ running the command below in the main repository folder.
**Development**

``` console
(releso) $ pip install -e "dev."
(releso) $ pip install -e ".[dev]"
(releso) $ pip install pre-commit
(releso) $ pre-commit install
```
4 changes: 2 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pydantic import BaseModel
from sphinx.ext.napoleon import _skip_member

from releso.__version__ import version
from releso.__version__ import __version__

# sys.path.insert(0, str(releso_dir / "util"))
# -- Project information -----------------------------------------------------
Expand All @@ -25,7 +25,7 @@
author = "Clemens Fricke"

# The full version, including alpha/beta/rc tags
release = version
release = __version__


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[tool.setuptools.dynamic]
version = {attr = "releso.__version__.version"}
version = {attr = "releso.__version__.__version__"}

[tool.black]
line-length = 79
Expand Down
1 change: 1 addition & 0 deletions releso/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
Please view the Documentation under :doc:`/index` for the information
about this library and framework.
"""
from releso.__version__ import __version__ as __version__
8 changes: 4 additions & 4 deletions releso/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import stable_baselines3
import torch

from releso.__version__ import version
from releso.__version__ import __version__
from releso.base_parser import BaseParser

try:
Expand Down Expand Up @@ -59,7 +59,7 @@ def main(args) -> None:

shutil.copy(file_path, optimization_object.save_location / file_path.name)
versions = [
f"releso version: {version}; ",
f"releso version: {__version__}; ",
f"stable-baselines3 version: {stable_baselines3.__version__}; ",
f"torch version: {torch.__version__}; ",
f"gymnasium version: {gymnasium.__version__} ",
Expand Down Expand Up @@ -101,7 +101,7 @@ def entry():
"Toolbox. This python program loads a problem "
"definition and trains the resulting problem. Further the "
"model can be evaluated"
f"The package version is: {version}."
f"The package version is: {__version__}."
)
parser.add_argument(
"-i",
Expand Down Expand Up @@ -133,7 +133,7 @@ def entry():
)
args = parser.parse_args()
if args.version:
print(f"releso: {version}")
print(f"releso: {__version__}")
return
if args.input_file is None:
print(
Expand Down
2 changes: 1 addition & 1 deletion releso/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
Current version.
"""
version = "0.1.2"
__version__ = "0.1.3"
20 changes: 19 additions & 1 deletion releso/base_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from releso.agent import AgentTypeDefinition
from releso.base_model import BaseModel
from releso.callback import EpisodeLogCallback
from releso.callback import EpisodeLogCallback, StepLogCallback
from releso.exceptions import ValidationNotSet
from releso.parser_environment import Environment
from releso.validation import Validation
Expand Down Expand Up @@ -59,6 +59,14 @@ class BaseParser(BaseModel):
#: updated at the end of the training in any case. But making this number
#: higher will lower the computational overhead. Defaults to 100.
episode_log_update: conint(ge=1) = 100
#: Flag indicating whether the step information (like actions,
#: observations, ...) should be logged to file. Defaults to False.
export_step_log: bool = False
#: Number of steps after which the step_log is updated. It will be
#: updated at the end of the training in any case. But making this number
#: higher will lower the computational overhead. Defaults to 0 which
#: triggers the output after every episode.
step_log_update: conint(ge=0) = 0

# internal objects
#: Holds the trainable agent for the RL use case. The
Expand Down Expand Up @@ -115,6 +123,16 @@ def learn(self) -> None:
update_n_episodes=self.episode_log_update,
),
]

if self.export_step_log:
callbacks.append(
StepLogCallback(
step_log_location=self.save_location / "step_log.csv",
verbose=1,
update_n_steps=self.step_log_update,
),
)

if self.number_of_episodes is not None:
num = self.number_of_episodes
if self.normalize_training_values:
Expand Down
105 changes: 105 additions & 0 deletions releso/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,108 @@ def _on_step(self) -> bool:
def _on_training_end(self) -> None:
"""Function is called when training is terminated."""
self._export()


class StepLogCallback(BaseCallback):
    """Step Callback class.

    This class tracks all step-wise information that might come in handy
    during evaluation. Originally created to easily track the actions
    undertaken in each episode for better evaluation of the learned policy.

    The information is buffered in memory and written to a csv file either
    after every completed episode or every ``update_n_steps`` timesteps.
    Whatever is still buffered when the training ends is written as well.
    """

    def __init__(
        self,
        step_log_location: Path,
        verbose: int = 0,
        update_n_steps: int = 0,
    ):
        """Constructor for the Callback using SB3 interface.

        Args:
            step_log_location (Path): Path to the step log file.
            verbose (int, optional): Verbosity of the callback. Defaults to 0.
            update_n_steps (int, optional): Update the step log file every n
                steps. Defaults to 0 which triggers the update after every
                episode.
        """
        super().__init__(verbose)
        self.step_log_location: Path = step_log_location
        self.step_log_location.parent.mkdir(parents=True, exist_ok=True)
        self.current_episode: int = 0
        # NOTE: despite its name this attribute holds a number of *steps*
        # (0 means "export after every episode"). The name is kept because
        # it is read from outside the class.
        self.update_n_episodes: int = update_n_steps
        self.first_export: bool = True

        self._reset_internal_storage()

    def _reset_internal_storage(self) -> None:
        """Reset the internally used lists.

        The lists store the step-wise information gathered since the log
        file was updated last.
        """
        self.episodes: list = []  # Store episode numbers
        self.timesteps: list = []  # Store step numbers
        self.actions: list = []  # Store actions
        self.observations: list = []  # Store observations
        self.rewards: list = []  # Optionally store rewards

    def _export(self) -> None:
        """Convert the step-wise information to a dataframe and export it.

        The first export (over)writes the csv file including the header,
        every later export appends rows without a header. The internal
        storage is emptied afterwards.
        """
        # Combine all relevant information into a pandas DataFrame
        export_data_frame = pd.DataFrame(
            {
                "episodes": self.episodes,
                "actions": self.actions,
                "observations": self.observations,
                "rewards": self.rewards,
            }
        )
        export_data_frame.index = self.timesteps
        export_data_frame.index.name = "timesteps"
        # Write the data to file
        export_data_frame.to_csv(
            self.step_log_location,
            mode="w" if self.first_export else "a",
            header=self.first_export,
        )
        # data frame has been exported already at least once, so reset the flag
        self.first_export = False
        # reset the internal storage
        self._reset_internal_storage()

    def _on_step(self) -> bool:
        """Function that is called after a step was performed.

        Returns:
            bool: If the callback returns False, training is aborted early.
        """
        # Store the step-wise information that we want to keep track of
        self.episodes.append(self.current_episode)
        self.timesteps.append(self.num_timesteps)
        self.actions.append(self.locals["actions"])  # Agent's actions
        self.observations.append(self.locals["new_obs"])  # Resulting obs
        self.rewards.append(self.locals["rewards"])  # Rewards (optional)

        # Check if the environment has completed an episode
        episode_finished = any(self.locals["dones"])

        if self.update_n_episodes == 0:
            # Export is supposed to happen after each completed episode
            if episode_finished:
                self._export()
        elif self.num_timesteps % self.update_n_episodes == 0:
            # Export with the given step frequency. The buffer is emptied on
            # every export, hence only the timestep that was just appended
            # can be a multiple of the frequency; older entries do not need
            # to be checked again.
            self._export()

        # Always increase the episode counter once an episode is done
        if episode_finished:
            self.current_episode += 1

        return True

    def _on_training_end(self) -> None:
        """Function is called when training is terminated.

        Writes the steps that were recorded after the last export so that
        they are not lost. If nothing has been exported so far, the file is
        written in any case so that it exists after the training.
        """
        if self.timesteps or self.first_export:
            self._export()
2 changes: 2 additions & 0 deletions releso/spor.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ def validate_additional_observations(
# We might want to deprecate the hidden option to instantiate
# additional observations via an int. Building up a default
# observation definition.
# TODO: change name to something more meaningful for example
# using the function name values["name"]
v = {
"name": f"unnamed_{str(uuid4())}",
"value_min": -1,
Expand Down
8 changes: 8 additions & 0 deletions tests/test_base_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def test_base_parser_validation(
"n_environments",
"normalize_training_values",
"multi_env_sequential",
"export_step_log",
],
[
(
Expand All @@ -168,6 +169,7 @@ def test_base_parser_validation(
5,
None,
False,
False,
),
(
None,
Expand All @@ -182,6 +184,7 @@ def test_base_parser_validation(
5,
True,
None,
False,
),
(
None,
Expand All @@ -196,6 +199,7 @@ def test_base_parser_validation(
5,
None,
True,
True,
),
(
None,
Expand All @@ -210,6 +214,7 @@ def test_base_parser_validation(
5,
True,
False,
False,
),
(
None,
Expand All @@ -229,6 +234,7 @@ def test_base_parser_validation(
None,
None,
None,
False,
),
],
indirect=["basic_agent_definition"],
Expand All @@ -244,6 +250,7 @@ def test_base_parser_learn(
n_environments,
normalize_training_values,
multi_env_sequential,
export_step_log,
dir_save_location,
basic_environment_definition,
basic_verbosity_definition,
Expand All @@ -255,6 +262,7 @@ def test_base_parser_learn(
"agent": basic_agent_definition,
"environment": basic_environment_definition,
"number_of_timesteps": number_of_timesteps,
"export_step_log": export_step_log,
}
if agent_additions is not None:
calling_dict["agent"].update(agent_additions)
Expand Down
33 changes: 32 additions & 1 deletion tests/test_callback.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pandas as pd
from stable_baselines3 import PPO

from releso.callback import EpisodeLogCallback
import pytest

from releso.callback import EpisodeLogCallback, StepLogCallback


def test_callback_episode_log_callback(
Expand All @@ -21,3 +23,32 @@ def test_callback_episode_log_callback(
episode_log = pd.read_csv(dir_save_location / "test.csv")
assert episode_log.shape[0] == 10
clean_up_provider(dir_save_location)


@pytest.mark.parametrize(
    ["update_n_steps"],
    [
        (0,),
        (20,),
    ],
)
def test_callback_step_information_log_callback(
    dir_save_location,
    clean_up_provider,
    provide_dummy_environment,
    update_n_steps,
):
    """Train a PPO agent briefly and check the written step log.

    First the state of a freshly constructed callback is checked, then the
    agent is trained for 100 timesteps and the exported csv file is expected
    to hold one row per timestep.
    """
    # this test is not very good, but it is a start
    # TODO: improve this test
    log_file = dir_save_location / "test.csv"
    step_callback = StepLogCallback(
        step_log_location=log_file,
        update_n_steps=update_n_steps,
    )

    # state directly after construction
    assert step_callback.step_log_location == log_file
    assert step_callback.current_episode == 0
    assert step_callback.update_n_episodes == update_n_steps
    assert step_callback.first_export

    # run a short training with the callback attached
    ppo_agent = PPO(
        "MlpPolicy",
        provide_dummy_environment,
        verbose=0,
        n_steps=100,
    )
    ppo_agent.learn(100, callback=step_callback)

    # every performed timestep has to show up as one row in the log
    step_log = pd.read_csv(log_file)
    assert step_log.shape[0] == 100
    clean_up_provider(dir_save_location)
6 changes: 3 additions & 3 deletions tests/test_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,10 @@ def test_mixd_mesh_class(
# setup mixd
if not load_sample_file.with_suffix(".mien").exists():
from gustaf.io import load
from gustaf.io.mixd import export
from gustaf.io.mixd import export as export_io

mesh = load(load_sample_file)
export(mesh, load_sample_file.with_suffix(".xns"))
mesh = load(load_sample_file)[-1]
export_io(load_sample_file.with_suffix(".xns"), mesh)
shutil.copy(
load_sample_file.with_suffix(".mxyz"),
load_sample_file.parent / "mxyz",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_util_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
({"test": np.ones(12)}, np.ones(12).tolist()),
({"test": np.int64(123)}, 123),
({"test": bytes("test", "utf-8")}, "test"),
({"test": np.float128(13.1)}, False),
({"test": np.longdouble(13.1)}, False),
],
)
def test_json_encode(dictionary, wanted_value, capsys):
Expand Down

0 comments on commit 7dd63b9

Please sign in to comment.