Skip to content

Commit

Permalink
cleanup test skipping conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
stefdoerr committed Jan 24, 2025
1 parent 531e434 commit ca0da03
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ jobs:
run: pytest -v -s
env:
CPU_ONLY: 1
WINDOWS_TESTS: ${{ runner.os == 'Windows' && 'true' || 'false' }}
SKIP_TORCH_COMPILE: ${{ runner.os == 'Windows' && 'true' || 'false' }}
OMP_PREFIX: ${{ runner.os == 'macOS' && '/Users/runner/miniconda3/envs/test' || '' }}
CPU_TRAIN: ${{ runner.os == 'macOS' && 'true' || 'false' }}

- name: Test torchmd-train utility
run: torchmd-train --help
5 changes: 2 additions & 3 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,10 @@ def test_load_model():
@mark.parametrize("precision", [32, 64])
def test_train(model_name, use_atomref, precision, tmpdir):
import torch
import platform

accelerator = "auto"
if platform.system() == "Darwin" and (os.getenv("CI", None) is not None):
# MPS backend runs out of memory on Github
if os.getenv("CPU_TRAIN", "false") == "true":
# OSX MPS backend runs out of memory on Github Actions
torch.set_default_device("cpu")
accelerator = "cpu"

Expand Down
6 changes: 3 additions & 3 deletions tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,10 +671,10 @@ def test_per_batch_box(device, strategy, n_batches, use_forward):
def test_torch_compile(device, dtype, loop, include_transpose):
import sys

# Skip if WINDOWS_TESTS is set to true
if os.environ.get("WINDOWS_TESTS", "false") == "true":
# Skip if SKIP_TORCH_COMPILE is set to true
if os.environ.get("SKIP_TORCH_COMPILE", "false") == "true":
# torch.compile doesn't detect cl.exe on Windows Github Actions
pytest.skip("Skipping test on Windows")
pytest.skip("Skipping torch compile test")

if sys.version_info >= (3, 12):
pytest.skip("Not available in this version")
Expand Down

0 comments on commit ca0da03

Please sign in to comment.