Skip to content

Commit

Permalink
Solves path problem in test_bundle_trt_export.py (#8357)
Browse files Browse the repository at this point in the history
Fixes #8354

### Description

Fixes path on test that is only run on special conditions.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: R. Garcia-Dias <[email protected]>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
garciadias and KumoLiu authored Feb 20, 2025
1 parent af54a17 commit d98f348
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 9 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@

MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE) framework for deep learning in healthcare imaging, part of the [PyTorch Ecosystem](https://pytorch.org/ecosystem/).
Its ambitions are as follows:

- Developing a community of academic, industrial and clinical researchers collaborating on a common foundation;
- Creating state-of-the-art, end-to-end training workflows for healthcare imaging;
- Providing researchers with the optimized and standardized way to create and evaluate deep learning models.


## Features

> _Please see [the technical highlights](https://docs.monai.io/en/latest/highlights.html) and [What's New](https://docs.monai.io/en/latest/whatsnew.html) of the milestone releases._
- flexible pre-processing for multi-dimensional medical imaging data;
Expand Down
2 changes: 1 addition & 1 deletion tests/bundle/test_bundle_trt_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def tearDown(self):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
@unittest.skipUnless(has_torchtrt and has_tensorrt, "Torch-TensorRT is required for conversion!")
def test_trt_export(self, convert_precision, input_shape, dynamic_batch):
tests_dir = Path(__file__).resolve().parent
tests_dir = Path(__file__).resolve().parents[1]
meta_file = os.path.join(tests_dir, "testing_data", "metadata.json")
config_file = os.path.join(tests_dir, "testing_data", "inference.json")
with tempfile.TemporaryDirectory() as tempdir:
Expand Down
2 changes: 1 addition & 1 deletion tests/networks/test_convert_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_unet(self, device, use_trace, use_ort):
rtol=rtol,
atol=atol,
)
self.assertTrue(isinstance(onnx_model, onnx.ModelProto))
self.assertTrue(isinstance(onnx_model, onnx.ModelProto))

@parameterized.expand(TESTS_ORT)
@SkipIfBeforePyTorchVersion((1, 12))
Expand Down
18 changes: 17 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
import warnings
from contextlib import contextmanager
from functools import partial, reduce
from itertools import product
from pathlib import Path
from subprocess import PIPE, Popen
from typing import Callable
from typing import Callable, Literal
from urllib.error import ContentTooShortError, HTTPError

import numpy as np
Expand Down Expand Up @@ -862,6 +863,21 @@ def equal_state_dict(st_1, st_2):
if torch.cuda.is_available():
TEST_DEVICES.append([torch.device("cuda")])


def dict_product(trailing=False, format: Literal["list", "dict"] = "dict", **items):
keys = items.keys()
values = items.values()
for pvalues in product(*values):
dict_comb = dict(zip(keys, pvalues))
if format == "dict":
if trailing:
yield [dict_comb] + list(pvalues)
else:
yield dict_comb
else:
yield pvalues


if __name__ == "__main__":
parser = argparse.ArgumentParser(prog="util")
parser.add_argument("-c", "--count", default=2, help="max number of gpus")
Expand Down
8 changes: 3 additions & 5 deletions tests/transforms/test_gibbs_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@
from monai.transforms import GibbsNoise
from monai.utils.misc import set_determinism
from monai.utils.module import optional_import
from tests.test_utils import TEST_NDARRAYS, assert_allclose
from tests.test_utils import TEST_NDARRAYS, assert_allclose, dict_product

_, has_torch_fft = optional_import("torch.fft", name="fftshift")

TEST_CASES = []
for shape in ((128, 64), (64, 48, 80)):
for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]:
TEST_CASES.append((shape, input_type))
params = {"shape": ((128, 64), (64, 48, 80)), "input_type": TEST_NDARRAYS if has_torch_fft else [np.array]}
TEST_CASES = list(dict_product(format="list", **params))


class TestGibbsNoise(unittest.TestCase):
Expand Down

0 comments on commit d98f348

Please sign in to comment.