create tiny train test for conda-forge to pass
stefdoerr committed Jan 24, 2025
1 parent ca0da03 commit 6cf27a5
Showing 2 changed files with 55 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
@@ -86,6 +86,7 @@ jobs:
SKIP_TORCH_COMPILE: ${{ runner.os == 'Windows' && 'true' || 'false' }}
OMP_PREFIX: ${{ runner.os == 'macOS' && '/Users/runner/miniconda3/envs/test' || '' }}
CPU_TRAIN: ${{ runner.os == 'macOS' && 'true' || 'false' }}
LONG_TRAIN: "true"

- name: Test torchmd-train utility
run: torchmd-train --help
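
The new LONG_TRAIN variable is consumed by the skipif gate added to tests/test_module.py below. A minimal sketch of that condition (long_train_enabled is a hypothetical helper name; the comparison mirrors the skipif in the diff):

import os

def long_train_enabled() -> bool:
    # Unset or "false" keeps the expensive test_train matrix skipped;
    # the CI workflow above opts in by setting LONG_TRAIN: "true".
    return os.getenv("LONG_TRAIN", "false") != "false"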
54 changes: 54 additions & 0 deletions tests/test_module.py
@@ -30,6 +30,9 @@ def test_load_model():
@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("use_atomref", [True, False])
@mark.parametrize("precision", [32, 64])
@mark.skipif(
os.getenv("LONG_TRAIN", "false") == "false", reason="Skipping long train test"
)
def test_train(model_name, use_atomref, precision, tmpdir):
    import torch

@@ -71,3 +74,54 @@ def test_train(model_name, use_atomref, precision, tmpdir):
    )
    trainer.fit(module, datamodule)
    trainer.test(module, datamodule)


@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("use_atomref", [True, False])
@mark.parametrize("precision", [32, 64])
def test_dummy_train(model_name, use_atomref, precision, tmpdir):
    import torch

    accelerator = "auto"
    if os.getenv("CPU_TRAIN", "false") == "true":
        # OSX MPS backend runs out of memory on Github Actions
        torch.set_default_device("cpu")
        accelerator = "cpu"

    extra_args = {}
    if model_name != "tensornet":
        # tensornet is attention-free and does not take a num_heads argument
        extra_args["num_heads"] = 2

    # tiny splits and model hyperparameters keep this training run short
    args = load_example_args(
        model_name,
        remove_prior=not use_atomref,
        train_size=0.05,
        val_size=0.01,
        test_size=0.01,
        log_dir=tmpdir,
        derivative=True,
        embedding_dimension=2,
        num_layers=1,
        num_rbf=4,
        batch_size=2,
        precision=precision,
        **extra_args,
    )
    datamodule = DataModule(args, DummyDataset(has_atomref=use_atomref))

    prior = None
    if use_atomref:
        prior = getattr(priors, args["prior_model"])(dataset=datamodule.dataset)
        args["prior_args"] = prior.get_init_args()

    module = LNNP(args, prior_model=prior)

    trainer = pl.Trainer(
        max_steps=10,
        default_root_dir=tmpdir,
        precision=args["precision"],
        inference_mode=False,
        accelerator=accelerator,
    )
    trainer.fit(module, datamodule)
    trainer.test(module, datamodule)
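
To exercise these tests locally the same way CI does, opt in through the environment before pytest collects the tests. A hypothetical runner (not part of this commit), assuming a standard checkout with pytest installed:

# Hypothetical local runner: enable the long training tests exactly as CI
# does, then run both the full and the tiny train suites.
import os
import pytest

os.environ["LONG_TRAIN"] = "true"  # drop this line to run only test_dummy_train
raise SystemExit(pytest.main(["tests/test_module.py", "-k", "train"]))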
