Skip to content

Commit

Permalink
Merge pull request #345 from AntonioMirarchi/update_csv_logger
Browse files Browse the repository at this point in the history
Add backup for metrics.csv
  • Loading branch information
stefdoerr authored Jan 28, 2025
2 parents f013b80 + e1bc7ef commit 8374e96
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
12 changes: 10 additions & 2 deletions torchmdnet/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
from torchmdnet.models import output_modules
from torchmdnet.models.model import create_prior_models
from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping, dtype_mapping
from torchmdnet.utils import LoadFromFile, LoadFromCheckpoint, save_argparse, number
from torchmdnet.utils import (
LoadFromFile,
LoadFromCheckpoint,
save_argparse,
number,
check_logs,
)
from lightning_utilities.core.rank_zero import rank_zero_warn


Expand Down Expand Up @@ -219,9 +225,11 @@ def main():
args.early_stopping_monitor, patience=args.early_stopping_patience
)
callbacks.append(early_stopping)


check_logs(args.log_dir)
csv_logger = CSVLogger(args.log_dir, name="", version="")
_logger = [csv_logger]

if args.wandb_use:
wandb_logger = WandbLogger(
project=args.wandb_project,
Expand Down
10 changes: 10 additions & 0 deletions torchmdnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,3 +397,13 @@ def wrapped_init(self, *args, **kwargs):

cls.__init__ = wrapped_init
return cls

def check_logs(log_dir):
import os
import time
metr_file_path = os.path.join(log_dir, 'metrics.csv')
if os.path.exists(metr_file_path):
# we make a backup of the metrics file (rename)
bckp_date = f'{time.strftime("%Y%m%d")}-{time.strftime("%H%M%S")}'
os.rename(metr_file_path, metr_file_path.replace(".csv", f"_{bckp_date}.csv"))
return

0 comments on commit 8374e96

Please sign in to comment.