From fb8c5bf75b30ab6b706e99a71b96ed0a88220728 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20K=C3=A4nzig?= <36882833+nkaenzig@users.noreply.github.com> Date: Wed, 19 Feb 2025 19:01:30 +0100 Subject: [PATCH 1/8] Removed outdated `torch` version checks from transform functions (#8359) Fixes #8348 ### Description Support for `torch` versions prior to `1.13` has been dropped, so those `1.8` version checks are not required anymore. Furthermore, as reported in the issue description, those checks led to unstable behaviour when using certain transforms in data pipelines. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Nicolas Kaenzig --- monai/transforms/utility/array.py | 14 ++------------ .../transforms/utils_pytorch_numpy_unification.py | 6 ++---- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 2963c8a2f8..8491e4739c 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -66,7 +66,6 @@ optional_import, ) from monai.utils.enums import TransformBackends -from monai.utils.misc import is_module_ver_at_least from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype PILImageImage, has_pil = optional_import("PIL.Image", name="Image") @@ -939,19 +938,10 @@ def __call__( data = img[[*select_labels]] else: where: Callable = np.where if isinstance(img, np.ndarray) else torch.where # type: ignore - if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): - data = where(in1d(img, select_labels), True, False).reshape(img.shape) - # pre pytorch 1.8.0, need to use 1/0 instead of True/False - else: - data = where( - in1d(img, select_labels), torch.tensor(1, device=img.device), torch.tensor(0, device=img.device) - ).reshape(img.shape) + data = where(in1d(img, select_labels), True, False).reshape(img.shape) if merge_channels or self.merge_channels: - if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): - return data.any(0)[None] - # pre pytorch 1.8.0 compatibility - return data.to(torch.uint8).any(0)[None].to(bool) # type: ignore + return data.any(0)[None] return data diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 365bd1eab5..8f22d00674 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -18,7 +18,6 @@ import torch from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor -from monai.utils.misc import is_module_ver_at_least from monai.utils.type_conversion import convert_data_type, convert_to_dst_type __all__ = [ @@ -215,10 +214,9 @@ def floor_divide(a: NdarrayOrTensor, b) -> NdarrayOrTensor: Element-wise floor division between two arrays/tensors. """ if isinstance(a, torch.Tensor): - if is_module_ver_at_least(torch, (1, 8, 0)): - return torch.div(a, b, rounding_mode="floor") return torch.floor_divide(a, b) - return np.floor_divide(a, b) + else: + return np.floor_divide(a, b) def unravel_index(idx, shape) -> NdarrayOrTensor: From b0ed253440a563b54c5e929dc10d95d10ec2f628 Mon Sep 17 00:00:00 2001 From: Bartosz Grabowski <58475557+bartosz-grabowski@users.noreply.github.com> Date: Thu, 20 Feb 2025 09:44:10 +0100 Subject: [PATCH 2/8] Fix CommonKeys docstring (#8342) ### Description `CommonKeys()` docstring mentions `INFO` which doesn't exist. Instead there is a `METADATA` field, so the docstring was updated accordingly. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Bartosz Grabowski <58475557+bartosz-grabowski@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/utils/enums.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 1fbf3ffa05..ac14134acc 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -335,7 +335,7 @@ class CommonKeys(StrEnum): `LABEL` is the training or evaluation label of segmentation or classification task. `PRED` is the prediction data of model output. `LOSS` is the loss value of current iteration. - `INFO` is some useful information during training or evaluation, like loss value, etc. + `METADATA` is some useful information during training or evaluation, like loss value, etc. """ From af54a17b83a469113eb43a7c768ae3d1fdcc0f9b Mon Sep 17 00:00:00 2001 From: Thibault de Varax <154365476+thibaultdvx@users.noreply.github.com> Date: Thu, 20 Feb 2025 13:27:37 +0100 Subject: [PATCH 3/8] Add Average Precision to metrics (#8089) Fixes #8085. ### Description Average Precision is very similar to ROCAUC, so I was very much inspired by the ROCAUC implementation. More precisely, I created: - `AveragePrecisionMetric` and `compute_average_precision` in `monai.metrics`, - a handler called `AveragePrecision` in `monai.handlers`, - three unittest modules: `test_compute_average_precision.py`, `test_handler_average_precision.py` and `test_handler_average_precision_dist.py`. I also modified the docs to mention Average Precision. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: thibaultdvx Signed-off-by: Thibault de Varax <154365476+thibaultdvx@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- docs/source/handlers.rst | 6 + docs/source/metrics.rst | 7 + monai/handlers/__init__.py | 1 + monai/handlers/average_precision.py | 53 +++++ monai/metrics/__init__.py | 1 + monai/metrics/average_precision.py | 187 ++++++++++++++++++ monai/utils/enums.py | 3 +- .../test_handler_average_precision.py | 79 ++++++++ .../metrics/test_compute_average_precision.py | 162 +++++++++++++++ tests/min_tests.py | 1 + 10 files changed, 499 insertions(+), 1 deletion(-) create mode 100644 monai/handlers/average_precision.py create mode 100644 monai/metrics/average_precision.py create mode 100644 tests/handlers/test_handler_average_precision.py create mode 100644 tests/metrics/test_compute_average_precision.py diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 270083f717..49c84dab28 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -53,6 +53,12 @@ ROC AUC metrics handler :members: +Average Precision metric handler +-------------------------------- +.. autoclass:: AveragePrecision + :members: + + Confusion matrix metrics handler -------------------------------- .. autoclass:: ConfusionMatrix diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 616f0fe385..45e0827cf9 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -80,6 +80,13 @@ Metrics .. autoclass:: ROCAUCMetric :members: +`Average Precision` +------------------- +.. autofunction:: compute_average_precision + +.. autoclass:: AveragePrecisionMetric + :members: + `Confusion matrix` ------------------ .. autofunction:: get_confusion_matrix diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index c1fa448f25..ed5db8a7f3 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -11,6 +11,7 @@ from __future__ import annotations +from .average_precision import AveragePrecision from .checkpoint_loader import CheckpointLoader from .checkpoint_saver import CheckpointSaver from .classification_saver import ClassificationSaver diff --git a/monai/handlers/average_precision.py b/monai/handlers/average_precision.py new file mode 100644 index 0000000000..608d7eea72 --- /dev/null +++ b/monai/handlers/average_precision.py @@ -0,0 +1,53 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable + +from monai.handlers.ignite_metric import IgniteMetricHandler +from monai.metrics import AveragePrecisionMetric +from monai.utils import Average + + +class AveragePrecision(IgniteMetricHandler): + """ + Computes Average Precision (AP). + accumulating predictions and the ground-truth during an epoch and applying `compute_average_precision`. + + Args: + average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``} + Type of averaging performed if not binary classification. Defaults to ``"macro"``. + + - ``"macro"``: calculate metrics for each label, and find their unweighted mean. + This does not take label imbalance into account. + - ``"weighted"``: calculate metrics for each label, and find their average, + weighted by support (the number of true instances for each label). + - ``"micro"``: calculate metrics globally by considering each element of the label + indicator matrix as a label. + - ``"none"``: the scores for each class are returned. + + output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then + construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or + lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. + + Note: + Average Precision expects y to be comprised of 0's and 1's. + y_pred must either be probability estimates or confidence values. + + """ + + def __init__(self, average: Average | str = Average.MACRO, output_transform: Callable = lambda x: x) -> None: + metric_fn = AveragePrecisionMetric(average=Average(average)) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False) diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index 201acdfa50..7176f3311f 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -12,6 +12,7 @@ from __future__ import annotations from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score +from .average_precision import AveragePrecisionMetric, compute_average_precision from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix from .cumulative_average import CumulativeAverage from .f_beta_score import FBetaScore diff --git a/monai/metrics/average_precision.py b/monai/metrics/average_precision.py new file mode 100644 index 0000000000..53c41aeca5 --- /dev/null +++ b/monai/metrics/average_precision.py @@ -0,0 +1,187 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, cast + +import numpy as np + +if TYPE_CHECKING: + import numpy.typing as npt + +import torch + +from monai.utils import Average, look_up_option + +from .metric import CumulativeIterationMetric + + +class AveragePrecisionMetric(CumulativeIterationMetric): + """ + Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are + imbalanced. It can take values between 0.0 and 1.0, 1.0 being the best possible score. + It summarizes a Precision-Recall curve as the weighted mean of precisions achieved at each + threshold, with the increase in recall from the previous threshold used as the weight: + + .. math:: + \\text{AP} = \\sum_n (R_n - R_{n-1}) P_n + :label: ap + + where :math:`P_n` and :math:`R_n` are the precision and recall at the :math:`n^{th}` threshold. + + Referring to: `sklearn.metrics.average_precision_score + `_. + + The input `y_pred` and `y` can be a list of `channel-first` Tensor or a `batch-first` Tensor. + + Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. + + Args: + average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``} + Type of averaging performed if not binary classification. + Defaults to ``"macro"``. + + - ``"macro"``: calculate metrics for each label, and find their unweighted mean. + This does not take label imbalance into account. + - ``"weighted"``: calculate metrics for each label, and find their average, + weighted by support (the number of true instances for each label). + - ``"micro"``: calculate metrics globally by considering each element of the label + indicator matrix as a label. + - ``"none"``: the scores for each class are returned. + + """ + + def __init__(self, average: Average | str = Average.MACRO) -> None: + super().__init__() + self.average = average + + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # type: ignore[override] + return y_pred, y + + def aggregate(self, average: Average | str | None = None) -> np.ndarray | float | npt.ArrayLike: + """ + Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration, + This function reads the buffers and computes the Average Precision. + + Args: + average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``} + Type of averaging performed if not binary classification. Defaults to `self.average`. + + """ + y_pred, y = self.get_buffer() + # compute final value and do metric reduction + if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor): + raise ValueError("y_pred and y must be PyTorch Tensor.") + + return compute_average_precision(y_pred=y_pred, y=y, average=average or self.average) + + +def _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float: + if not (y.ndimension() == y_pred.ndimension() == 1 and len(y) == len(y_pred)): + raise AssertionError("y and y_pred must be 1 dimension data with same length.") + y_unique = y.unique() + if len(y_unique) == 1: + warnings.warn(f"y values can not be all {y_unique.item()}, skip AP computation and return `Nan`.") + return float("nan") + if not y_unique.equal(torch.tensor([0, 1], dtype=y.dtype, device=y.device)): + warnings.warn(f"y values must be 0 or 1, but in {y_unique.tolist()}, skip AP computation and return `Nan`.") + return float("nan") + + n = len(y) + indices = y_pred.argsort(descending=True) + y = y[indices].cpu().numpy() # type: ignore[assignment] + y_pred = y_pred[indices].cpu().numpy() # type: ignore[assignment] + npos = ap = tmp_pos = 0.0 + + for i in range(n): + y_i = cast(float, y[i]) + if i + 1 < n and y_pred[i] == y_pred[i + 1]: + tmp_pos += y_i + else: + tmp_pos += y_i + npos += tmp_pos + ap += tmp_pos * npos / (i + 1) + tmp_pos = 0 + + return ap / npos + + +def compute_average_precision( + y_pred: torch.Tensor, y: torch.Tensor, average: Average | str = Average.MACRO +) -> np.ndarray | float | npt.ArrayLike: + """Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are + imbalanced. It summarizes a Precision-Recall according to equation :eq:`ap`. + Referring to: `sklearn.metrics.average_precision_score + `_. + + Args: + y_pred: input data to compute, typical classification model output. + the first dim must be batch, if multi-classes, it must be in One-Hot format. + for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data. + y: ground truth to compute AP metric, the first dim must be batch. + if multi-classes, it must be in One-Hot format. + for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data. + average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``} + Type of averaging performed if not binary classification. + Defaults to ``"macro"``. + + - ``"macro"``: calculate metrics for each label, and find their unweighted mean. + This does not take label imbalance into account. + - ``"weighted"``: calculate metrics for each label, and find their average, + weighted by support (the number of true instances for each label). + - ``"micro"``: calculate metrics globally by considering each element of the label + indicator matrix as a label. + - ``"none"``: the scores for each class are returned. + + Raises: + ValueError: When ``y_pred`` dimension is not one of [1, 2]. + ValueError: When ``y`` dimension is not one of [1, 2]. + ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"]. + + Note: + Average Precision expects y to be comprised of 0's and 1's. `y_pred` must be either prob. estimates or confidence values. + + """ + y_pred_ndim = y_pred.ndimension() + y_ndim = y.ndimension() + if y_pred_ndim not in (1, 2): + raise ValueError( + f"Predictions should be of shape (batch_size, num_classes) or (batch_size, ), got {y_pred.shape}." + ) + if y_ndim not in (1, 2): + raise ValueError(f"Targets should be of shape (batch_size, num_classes) or (batch_size, ), got {y.shape}.") + if y_pred_ndim == 2 and y_pred.shape[1] == 1: + y_pred = y_pred.squeeze(dim=-1) + y_pred_ndim = 1 + if y_ndim == 2 and y.shape[1] == 1: + y = y.squeeze(dim=-1) + + if y_pred_ndim == 1: + return _calculate(y_pred, y) + + if y.shape != y_pred.shape: + raise ValueError(f"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.") + + average = look_up_option(average, Average) + if average == Average.MICRO: + return _calculate(y_pred.flatten(), y.flatten()) + y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1) + ap_values = [_calculate(y_pred_, y_) for y_pred_, y_ in zip(y_pred, y)] + if average == Average.NONE: + return ap_values + if average == Average.MACRO: + return np.mean(ap_values) + if average == Average.WEIGHTED: + weights = [sum(y_) for y_ in y] + return np.average(ap_values, weights=weights) # type: ignore[no-any-return] + raise ValueError(f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].') diff --git a/monai/utils/enums.py b/monai/utils/enums.py index ac14134acc..3463a92e4b 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -213,7 +213,8 @@ class GridSamplePadMode(StrEnum): class Average(StrEnum): """ - See also: :py:class:`monai.metrics.rocauc.compute_roc_auc` + See also: :py:class:`monai.metrics.rocauc.compute_roc_auc` or + :py:class:`monai.metrics.average_precision.compute_average_precision` """ MACRO = "macro" diff --git a/tests/handlers/test_handler_average_precision.py b/tests/handlers/test_handler_average_precision.py new file mode 100644 index 0000000000..7f52a5ee9c --- /dev/null +++ b/tests/handlers/test_handler_average_precision.py @@ -0,0 +1,79 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +import torch.distributed as dist + +from monai.handlers import AveragePrecision +from monai.transforms import Activations, AsDiscrete +from tests.test_utils import DistCall, DistTestCase + + +class TestHandlerAveragePrecision(unittest.TestCase): + + def test_compute(self): + ap_metric = AveragePrecision() + act = Activations(softmax=True) + to_onehot = AsDiscrete(to_onehot=2) + + y_pred = [torch.Tensor([0.1, 0.9]), torch.Tensor([0.3, 1.4])] + y = [torch.Tensor([0]), torch.Tensor([1])] + y_pred = [act(p) for p in y_pred] + y = [to_onehot(y_) for y_ in y] + ap_metric.update([y_pred, y]) + + y_pred = [torch.Tensor([0.2, 0.1]), torch.Tensor([0.1, 0.5])] + y = [torch.Tensor([0]), torch.Tensor([1])] + y_pred = [act(p) for p in y_pred] + y = [to_onehot(y_) for y_ in y] + + ap_metric.update([y_pred, y]) + + ap = ap_metric.compute() + np.testing.assert_allclose(0.8333333, ap) + + +class DistributedAveragePrecision(DistTestCase): + + @DistCall(nnodes=1, nproc_per_node=2, node_rank=0) + def test_compute(self): + ap_metric = AveragePrecision() + act = Activations(softmax=True) + to_onehot = AsDiscrete(to_onehot=2) + + device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu" + if dist.get_rank() == 0: + y_pred = [torch.tensor([0.1, 0.9], device=device), torch.tensor([0.3, 1.4], device=device)] + y = [torch.tensor([0], device=device), torch.tensor([1], device=device)] + + if dist.get_rank() == 1: + y_pred = [ + torch.tensor([0.2, 0.1], device=device), + torch.tensor([0.1, 0.5], device=device), + torch.tensor([0.3, 0.4], device=device), + ] + y = [torch.tensor([0], device=device), torch.tensor([1], device=device), torch.tensor([1], device=device)] + + y_pred = [act(p) for p in y_pred] + y = [to_onehot(y_) for y_ in y] + ap_metric.update([y_pred, y]) + + result = ap_metric.compute() + np.testing.assert_allclose(0.7778, result, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/metrics/test_compute_average_precision.py b/tests/metrics/test_compute_average_precision.py new file mode 100644 index 0000000000..819bb61a42 --- /dev/null +++ b/tests/metrics/test_compute_average_precision.py @@ -0,0 +1,162 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data import decollate_batch +from monai.metrics import AveragePrecisionMetric, compute_average_precision +from monai.transforms import Activations, AsDiscrete, Compose, ToTensor + +_device = "cuda:0" if torch.cuda.is_available() else "cpu" +TEST_CASE_1 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]], device=_device), + torch.tensor([[0], [0], [1], [1]], device=_device), + True, + 2, + "macro", + 0.41667, +] + +TEST_CASE_2 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]], device=_device), + torch.tensor([[1], [1], [0], [0]], device=_device), + True, + 2, + "micro", + 0.85417, +] + +TEST_CASE_3 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]], device=_device), + torch.tensor([[0], [1], [0], [1]], device=_device), + True, + 2, + "macro", + 0.83333, +] + +TEST_CASE_4 = [ + torch.tensor([[0.5], [0.5], [0.2], [8.3]]), + torch.tensor([[0], [1], [0], [1]]), + False, + None, + "macro", + 0.83333, +] + +TEST_CASE_5 = [torch.tensor([[0.5], [0.5], [0.2], [8.3]]), torch.tensor([0, 1, 0, 1]), False, None, "macro", 0.83333] + +TEST_CASE_6 = [torch.tensor([0.5, 0.5, 0.2, 8.3]), torch.tensor([0, 1, 0, 1]), False, None, "macro", 0.83333] + +TEST_CASE_7 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), + torch.tensor([[0], [1], [0], [1]]), + True, + 2, + "none", + [0.83333, 0.83333], +] + +TEST_CASE_8 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]), + torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]), + True, + None, + "weighted", + 0.66667, +] + +TEST_CASE_9 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]), + torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]), + True, + None, + "micro", + 0.71111, +] + +TEST_CASE_10 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), + torch.tensor([[0], [0], [0], [0]]), + True, + 2, + "macro", + float("nan"), +] + +TEST_CASE_11 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), + torch.tensor([[1], [1], [1], [1]]), + True, + 2, + "macro", + float("nan"), +] + +TEST_CASE_12 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), + torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3]]), + True, + None, + "macro", + float("nan"), +] + +ALL_TESTS = [ + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, + TEST_CASE_10, + TEST_CASE_11, + TEST_CASE_12, +] + + +class TestComputeAveragePrecision(unittest.TestCase): + + @parameterized.expand(ALL_TESTS) + def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value): + y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)]) + y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)]) + y_pred = torch.stack([y_pred_trans(i) for i in decollate_batch(y_pred)], dim=0) + y = torch.stack([y_trans(i) for i in decollate_batch(y)], dim=0) + result = compute_average_precision(y_pred=y_pred, y=y, average=average) + np.testing.assert_allclose(expected_value, result, rtol=1e-5) + + @parameterized.expand(ALL_TESTS) + def test_class_value(self, y_pred, y, softmax, to_onehot, average, expected_value): + y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)]) + y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)]) + y_pred = [y_pred_trans(i) for i in decollate_batch(y_pred)] + y = [y_trans(i) for i in decollate_batch(y)] + metric = AveragePrecisionMetric(average=average) + metric(y_pred=y_pred, y=y) + result = metric.aggregate() + np.testing.assert_allclose(expected_value, result, rtol=1e-5) + result = metric.aggregate(average=average) # test optional argument + metric.reset() + np.testing.assert_allclose(expected_value, result, rtol=1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/min_tests.py b/tests/min_tests.py index 1fc3da4a19..12f494be9c 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -76,6 +76,7 @@ def run_testsuit(): "test_grid_patch", "test_gmm", "test_handler_metrics_reloaded", + "test_handler_average_precision", "test_handler_checkpoint_loader", "test_handler_checkpoint_saver", "test_handler_classification_saver", From d98f348118b121e5f44fdf87c008851edbf68504 Mon Sep 17 00:00:00 2001 From: Rafael Garcia-Dias Date: Thu, 20 Feb 2025 14:11:20 +0000 Subject: [PATCH 4/8] Solves path problem in test_bundle_trt_export.py (#8357) Fixes #8354 ### Description Fixes path on test that is only run on special conditions. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: R. Garcia-Dias Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- README.md | 3 ++- tests/bundle/test_bundle_trt_export.py | 2 +- tests/networks/test_convert_to_onnx.py | 2 +- tests/test_utils.py | 18 +++++++++++++++++- tests/transforms/test_gibbs_noise.py | 8 +++----- 5 files changed, 24 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index e5607ccb02..69cd1c657f 100644 --- a/README.md +++ b/README.md @@ -18,12 +18,13 @@ MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE) framework for deep learning in healthcare imaging, part of the [PyTorch Ecosystem](https://pytorch.org/ecosystem/). Its ambitions are as follows: + - Developing a community of academic, industrial and clinical researchers collaborating on a common foundation; - Creating state-of-the-art, end-to-end training workflows for healthcare imaging; - Providing researchers with the optimized and standardized way to create and evaluate deep learning models. - ## Features + > _Please see [the technical highlights](https://docs.monai.io/en/latest/highlights.html) and [What's New](https://docs.monai.io/en/latest/whatsnew.html) of the milestone releases._ - flexible pre-processing for multi-dimensional medical imaging data; diff --git a/tests/bundle/test_bundle_trt_export.py b/tests/bundle/test_bundle_trt_export.py index a7c570438d..5168fcfdb5 100644 --- a/tests/bundle/test_bundle_trt_export.py +++ b/tests/bundle/test_bundle_trt_export.py @@ -70,7 +70,7 @@ def tearDown(self): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) @unittest.skipUnless(has_torchtrt and has_tensorrt, "Torch-TensorRT is required for conversion!") def test_trt_export(self, convert_precision, input_shape, dynamic_batch): - tests_dir = Path(__file__).resolve().parent + tests_dir = Path(__file__).resolve().parents[1] meta_file = os.path.join(tests_dir, "testing_data", "metadata.json") config_file = os.path.join(tests_dir, "testing_data", "inference.json") with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/networks/test_convert_to_onnx.py b/tests/networks/test_convert_to_onnx.py index cfc356d5a4..106f15dc9d 100644 --- a/tests/networks/test_convert_to_onnx.py +++ b/tests/networks/test_convert_to_onnx.py @@ -64,7 +64,7 @@ def test_unet(self, device, use_trace, use_ort): rtol=rtol, atol=atol, ) - self.assertTrue(isinstance(onnx_model, onnx.ModelProto)) + self.assertTrue(isinstance(onnx_model, onnx.ModelProto)) @parameterized.expand(TESTS_ORT) @SkipIfBeforePyTorchVersion((1, 12)) diff --git a/tests/test_utils.py b/tests/test_utils.py index c494bb547c..97a3181c44 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -30,9 +30,10 @@ import warnings from contextlib import contextmanager from functools import partial, reduce +from itertools import product from pathlib import Path from subprocess import PIPE, Popen -from typing import Callable +from typing import Callable, Literal from urllib.error import ContentTooShortError, HTTPError import numpy as np @@ -862,6 +863,21 @@ def equal_state_dict(st_1, st_2): if torch.cuda.is_available(): TEST_DEVICES.append([torch.device("cuda")]) + +def dict_product(trailing=False, format: Literal["list", "dict"] = "dict", **items): + keys = items.keys() + values = items.values() + for pvalues in product(*values): + dict_comb = dict(zip(keys, pvalues)) + if format == "dict": + if trailing: + yield [dict_comb] + list(pvalues) + else: + yield dict_comb + else: + yield pvalues + + if __name__ == "__main__": parser = argparse.ArgumentParser(prog="util") parser.add_argument("-c", "--count", default=2, help="max number of gpus") diff --git a/tests/transforms/test_gibbs_noise.py b/tests/transforms/test_gibbs_noise.py index 2aa2a44d10..1f96595a26 100644 --- a/tests/transforms/test_gibbs_noise.py +++ b/tests/transforms/test_gibbs_noise.py @@ -21,14 +21,12 @@ from monai.transforms import GibbsNoise from monai.utils.misc import set_determinism from monai.utils.module import optional_import -from tests.test_utils import TEST_NDARRAYS, assert_allclose +from tests.test_utils import TEST_NDARRAYS, assert_allclose, dict_product _, has_torch_fft = optional_import("torch.fft", name="fftshift") -TEST_CASES = [] -for shape in ((128, 64), (64, 48, 80)): - for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: - TEST_CASES.append((shape, input_type)) +params = {"shape": ((128, 64), (64, 48, 80)), "input_type": TEST_NDARRAYS if has_torch_fft else [np.array]} +TEST_CASES = list(dict_product(format="list", **params)) class TestGibbsNoise(unittest.TestCase): From a7905909e785d1ef24103c32a2d3a5a36e1059a2 Mon Sep 17 00:00:00 2001 From: Rafael Garcia-Dias Date: Fri, 21 Feb 2025 16:02:52 +0000 Subject: [PATCH 5/8] 8354 fix path at test onnx trt export (#8361) Fixes #8354 ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: R. Garcia-Dias --- README.md | 31 ++++++++++++++------------ tests/bundle/test_bundle_trt_export.py | 2 +- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 69cd1c657f..5e006f5d64 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,6 @@ Its ambitions are as follows: - customizable design for varying user expertise; - multi-GPU multi-node data parallelism support. - ## Installation To install [the current release](https://pypi.org/project/monai/), you can simply run: @@ -54,30 +53,34 @@ Technical documentation is available at [docs.monai.io](https://docs.monai.io). ## Citation -If you have used MONAI in your research, please cite us! The citation can be exported from: https://arxiv.org/abs/2211.02701. +If you have used MONAI in your research, please cite us! The citation can be exported from: . ## Model Zoo + [The MONAI Model Zoo](https://github.com/Project-MONAI/model-zoo) is a place for researchers and data scientists to share the latest and great models from the community. Utilizing [the MONAI Bundle format](https://docs.monai.io/en/latest/bundle_intro.html) makes it easy to [get started](https://github.com/Project-MONAI/tutorials/tree/main/model_zoo) building workflows with MONAI. ## Contributing + For guidance on making a contribution to MONAI, see the [contributing guidelines](https://github.com/Project-MONAI/MONAI/blob/dev/CONTRIBUTING.md). ## Community + Join the conversation on Twitter/X [@ProjectMONAI](https://twitter.com/ProjectMONAI) or join our [Slack channel](https://forms.gle/QTxJq3hFictp31UM9). Ask and answer questions over on [MONAI's GitHub Discussions tab](https://github.com/Project-MONAI/MONAI/discussions). ## Links -- Website: https://monai.io/ -- API documentation (milestone): https://docs.monai.io/ -- API documentation (latest dev): https://docs.monai.io/en/latest/ -- Code: https://github.com/Project-MONAI/MONAI -- Project tracker: https://github.com/Project-MONAI/MONAI/projects -- Issue tracker: https://github.com/Project-MONAI/MONAI/issues -- Wiki: https://github.com/Project-MONAI/MONAI/wiki -- Test status: https://github.com/Project-MONAI/MONAI/actions -- PyPI package: https://pypi.org/project/monai/ -- conda-forge: https://anaconda.org/conda-forge/monai -- Weekly previews: https://pypi.org/project/monai-weekly/ -- Docker Hub: https://hub.docker.com/r/projectmonai/monai + +- Website: +- API documentation (milestone): +- API documentation (latest dev): +- Code: +- Project tracker: +- Issue tracker: +- Wiki: +- Test status: +- PyPI package: +- conda-forge: +- Weekly previews: +- Docker Hub: diff --git a/tests/bundle/test_bundle_trt_export.py b/tests/bundle/test_bundle_trt_export.py index 5168fcfdb5..730338ad4e 100644 --- a/tests/bundle/test_bundle_trt_export.py +++ b/tests/bundle/test_bundle_trt_export.py @@ -108,7 +108,7 @@ def test_trt_export(self, convert_precision, input_shape, dynamic_batch): has_onnx and has_torchtrt and has_tensorrt, "Onnx and TensorRT are required for onnx-trt conversion!" ) def test_onnx_trt_export(self, convert_precision, input_shape, dynamic_batch): - tests_dir = Path(__file__).resolve().parent + tests_dir = Path(__file__).resolve().parents[1] meta_file = os.path.join(tests_dir, "testing_data", "metadata.json") config_file = os.path.join(tests_dir, "testing_data", "inference.json") with tempfile.TemporaryDirectory() as tempdir: From ab0752343b65198dae12d7389e441cfafeb9890a Mon Sep 17 00:00:00 2001 From: Virginia Fernandez <61539159+virginiafdez@users.noreply.github.com> Date: Mon, 24 Feb 2025 10:11:58 +0000 Subject: [PATCH 6/8] =?UTF-8?q?Modify=20ControlNet=20inferer=20so=20that?= =?UTF-8?q?=20it=20takes=20in=20context=20when=20the=20diffus=E2=80=A6=20(?= =?UTF-8?q?#8360)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #8344 ### Description The ControlNet inferers (latent and not latent) work in such a way that, when conditioning is used, the ControlNet does not take in the conditioning. It should, in theory, exhibit the same behaviour as the diffusion model. I've changed this behaviour, which has included modifying ControlNetDiffusionInferer and ControlNetLatentDiffusionInferer; the methods call, sample and get_likelihood. I've also modified the tests to take this into account. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes (modified, rather than new) - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. Signed-off-by: Virginia Fernandez Co-authored-by: Virginia Fernandez Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/inferers/inferer.py | 40 ++++++++++++++++------ tests/inferers/test_controlnet_inferers.py | 9 +++++ 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 7083373859..156677d992 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -1334,13 +1334,15 @@ def __call__( # type: ignore[override] raise NotImplementedError(f"{mode} condition is not supported") noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) - down_block_res_samples, mid_block_res_sample = controlnet( - x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond - ) + if mode == "concat" and condition is not None: noisy_image = torch.cat([noisy_image, condition], dim=1) condition = None + down_block_res_samples, mid_block_res_sample = controlnet( + x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition + ) + diffuse = diffusion_model if isinstance(diffusion_model, SPADEDiffusionModelUNet): diffuse = partial(diffusion_model, seg=seg) @@ -1396,17 +1398,21 @@ def sample( # type: ignore[override] progress_bar = iter(scheduler.timesteps) intermediates = [] for t in progress_bar: - # 1. ControlNet forward - down_block_res_samples, mid_block_res_sample = controlnet( - x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond - ) - # 2. predict noise model_output diffuse = diffusion_model if isinstance(diffusion_model, SPADEDiffusionModelUNet): diffuse = partial(diffusion_model, seg=seg) if mode == "concat" and conditioning is not None: + # 1. Conditioning model_input = torch.cat([image, conditioning], dim=1) + # 2. ControlNet forward + down_block_res_samples, mid_block_res_sample = controlnet( + x=model_input, + timesteps=torch.Tensor((t,)).to(input_noise.device), + controlnet_cond=cn_cond, + context=None, + ) + # 3. predict noise model_output model_output = diffuse( model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), @@ -1415,6 +1421,12 @@ def sample( # type: ignore[override] mid_block_additional_residual=mid_block_res_sample, ) else: + down_block_res_samples, mid_block_res_sample = controlnet( + x=image, + timesteps=torch.Tensor((t,)).to(input_noise.device), + controlnet_cond=cn_cond, + context=conditioning, + ) model_output = diffuse( image, timesteps=torch.Tensor((t,)).to(input_noise.device), @@ -1485,9 +1497,6 @@ def get_likelihood( # type: ignore[override] for t in progress_bar: timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) - down_block_res_samples, mid_block_res_sample = controlnet( - x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond - ) diffuse = diffusion_model if isinstance(diffusion_model, SPADEDiffusionModelUNet): @@ -1495,6 +1504,9 @@ def get_likelihood( # type: ignore[override] if mode == "concat" and conditioning is not None: noisy_image = torch.cat([noisy_image, conditioning], dim=1) + down_block_res_samples, mid_block_res_sample = controlnet( + x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond, context=None + ) model_output = diffuse( noisy_image, timesteps=timesteps, @@ -1503,6 +1515,12 @@ def get_likelihood( # type: ignore[override] mid_block_additional_residual=mid_block_res_sample, ) else: + down_block_res_samples, mid_block_res_sample = controlnet( + x=noisy_image, + timesteps=torch.Tensor((t,)).to(inputs.device), + controlnet_cond=cn_cond, + context=conditioning, + ) model_output = diffuse( x=noisy_image, timesteps=timesteps, diff --git a/tests/inferers/test_controlnet_inferers.py b/tests/inferers/test_controlnet_inferers.py index 2ab5cec335..909f2cf398 100644 --- a/tests/inferers/test_controlnet_inferers.py +++ b/tests/inferers/test_controlnet_inferers.py @@ -550,6 +550,8 @@ def test_ddim_sampler(self, model_params, controlnet_params, input_shape): def test_sampler_conditioned(self, model_params, controlnet_params, input_shape): model_params["with_conditioning"] = True model_params["cross_attention_dim"] = 3 + controlnet_params["with_conditioning"] = True + controlnet_params["cross_attention_dim"] = 3 model = DiffusionModelUNet(**model_params) controlnet = ControlNet(**controlnet_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -619,8 +621,11 @@ def test_sampler_conditioned_concat(self, model_params, controlnet_params, input model_params = model_params.copy() n_concat_channel = 2 model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel model_params["cross_attention_dim"] = None + controlnet_params["cross_attention_dim"] = None model_params["with_conditioning"] = False + controlnet_params["with_conditioning"] = False model = DiffusionModelUNet(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) @@ -1023,8 +1028,10 @@ def test_prediction_shape_conditioned_concat( if ae_model_type == "SPADEAutoencoderKL": stage_1 = SPADEAutoencoderKL(**autoencoder_params) stage_2_params = stage_2_params.copy() + controlnet_params = controlnet_params.copy() n_concat_channel = 3 stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel if dm_model_type == "SPADEDiffusionModelUNet": stage_2 = SPADEDiffusionModelUNet(**stage_2_params) else: @@ -1106,8 +1113,10 @@ def test_sample_shape_conditioned_concat( if ae_model_type == "SPADEAutoencoderKL": stage_1 = SPADEAutoencoderKL(**autoencoder_params) stage_2_params = stage_2_params.copy() + controlnet_params = controlnet_params.copy() n_concat_channel = 3 stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel if dm_model_type == "SPADEDiffusionModelUNet": stage_2 = SPADEDiffusionModelUNet(**stage_2_params) else: From a09c1f08461cec3d2131fde3939ef38c3c4ad5fc Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Tue, 25 Feb 2025 22:59:12 +0800 Subject: [PATCH 7/8] Update monaihosting download method (#8364) Related to https://github.com/Project-MONAI/model-zoo/pull/723. ### Description Currently, bundle download on source "monaihosting" uses fixed download url according to the function `_get_monaihosting_bundle_url`. A possible enhancement if to support on bundles that are hosted in different places. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Yiheng Wang Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/bundle/scripts.py | 55 +++++++++++++++++++++++++++++------------ 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 5089f0c045..b43f7e0fa0 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -15,6 +15,7 @@ import json import os import re +import urllib import warnings import zipfile from collections.abc import Mapping, Sequence @@ -58,7 +59,7 @@ validate, _ = optional_import("jsonschema", name="validate") ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError") Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint") -requests_get, has_requests = optional_import("requests", name="get") +requests, has_requests = optional_import("requests") onnx, _ = optional_import("onnx") huggingface_hub, _ = optional_import("huggingface_hub") @@ -206,6 +207,16 @@ def _download_from_monaihosting(download_path: Path, filename: str, version: str extractall(filepath=filepath, output_dir=download_path, has_base=True) +def _download_from_bundle_info(download_path: Path, filename: str, version: str, progress: bool) -> None: + bundle_info = get_bundle_info(bundle_name=filename, version=version) + if not bundle_info: + raise ValueError(f"Bundle info not found for {filename} v{version}.") + url = bundle_info["browser_download_url"] + filepath = download_path / f"{filename}_v{version}.zip" + download_url(url=url, filepath=filepath, hash_val=None, progress=progress) + extractall(filepath=filepath, output_dir=download_path, has_base=True) + + def _add_ngc_prefix(name: str, prefix: str = "monai_") -> str: if name.startswith(prefix): return name @@ -222,7 +233,7 @@ def _get_all_download_files(request_url: str, headers: dict | None = None) -> li if not has_requests: raise ValueError("requests package is required, please install it.") headers = {} if headers is None else headers - response = requests_get(request_url, headers=headers) + response = requests.get(request_url, headers=headers) response.raise_for_status() model_info = json.loads(response.text) @@ -266,7 +277,7 @@ def _download_from_ngc_private( request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo) if has_requests: headers = {} if headers is None else headers - response = requests_get(request_url, headers=headers) + response = requests.get(request_url, headers=headers) response.raise_for_status() else: raise ValueError("NGC API requires requests package. Please install it.") @@ -289,7 +300,7 @@ def _get_ngc_token(api_key, retry=0): url = "https://authn.nvidia.com/token?service=ngc" headers = {"Accept": "application/json", "Authorization": "ApiKey " + api_key} if has_requests: - response = requests_get(url, headers=headers) + response = requests.get(url, headers=headers) if not response.ok: # retry 3 times, if failed, raise an error. if retry < 3: @@ -303,14 +314,17 @@ def _get_ngc_token(api_key, retry=0): def _get_latest_bundle_version_monaihosting(name): full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}" - requests_get, has_requests = optional_import("requests", name="get") if has_requests: - resp = requests_get(full_url) - resp.raise_for_status() - else: - raise ValueError("NGC API requires requests package. Please install it.") - model_info = json.loads(resp.text) - return model_info["model"]["latestVersionIdStr"] + resp = requests.get(full_url) + try: + resp.raise_for_status() + model_info = json.loads(resp.text) + return model_info["model"]["latestVersionIdStr"] + except requests.exceptions.HTTPError: + # for monaihosting bundles, if cannot find the version, get from model zoo model_info.json + return get_bundle_versions(name)["latest_version"] + + raise ValueError("NGC API requires requests package. Please install it.") def _examine_monai_version(monai_version: str) -> tuple[bool, str]: @@ -388,14 +402,14 @@ def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers: version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements if headers: version_header.update(headers) - resp = requests_get(version_endpoint, headers=version_header) + resp = requests.get(version_endpoint, headers=version_header) resp.raise_for_status() model_info = json.loads(resp.text) latest_versions = _list_latest_versions(model_info) for version in latest_versions: file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json" - resp = requests_get(file_endpoint, headers=headers) + resp = requests.get(file_endpoint, headers=headers) metadata = json.loads(resp.text) resp.raise_for_status() # if the package version is not available or the model is compatible with the package version @@ -585,7 +599,16 @@ def download( name_ver = "_v".join([name_, version_]) if version_ is not None else name_ _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_) elif source_ == "monaihosting": - _download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_) + try: + _download_from_monaihosting( + download_path=bundle_dir_, filename=name_, version=version_, progress=progress_ + ) + except urllib.error.HTTPError: + # for monaihosting bundles, if cannot download from default host, download according to bundle_info + _download_from_bundle_info( + download_path=bundle_dir_, filename=name_, version=version_, progress=progress_ + ) + elif source_ == "ngc": _download_from_ngc( download_path=bundle_dir_, @@ -792,9 +815,9 @@ def _get_all_bundles_info( if auth_token is not None: headers = {"Authorization": f"Bearer {auth_token}"} - resp = requests_get(request_url, headers=headers) + resp = requests.get(request_url, headers=headers) else: - resp = requests_get(request_url) + resp = requests.get(request_url) resp.raise_for_status() else: raise ValueError("requests package is required, please install it.") From 2e391c82d9e5fc565917ac46fdeb3a900c96acaf Mon Sep 17 00:00:00 2001 From: James Butler Date: Tue, 4 Mar 2025 10:33:31 -0500 Subject: [PATCH 8/8] Bump torch minimum to mitigate CVE-2024-31580 & CVE-2024-31583 and enable numpy 2 compatibility (#8368) This is a follow-up to the comments made in https://github.com/Project-MONAI/MONAI/pull/8296#issuecomment-2587338931. ### Description This bumps the minimum required `torch` version from 1.13.1 to 2.2.0 in the first commit. See https://github.com/advisories/GHSA-5pcm-hx3q-hm94 and https://github.com/advisories/GHSA-pg7h-5qx3-wjr3 for more details regarding the "High" severity scoring. - https://nvd.nist.gov/vuln/detail/CVE-2024-31580 - https://nvd.nist.gov/vuln/detail/CVE-2024-31583 Additionally, PyTorch added support for numpy 2 starting with PyTorch 2.3.0. The second commit in this PR allows for numpy 1 or numpy 2 to be used with torch>=2.3.0. I have included this commit in this PR as upgrading to torch 2.2 means you might as well update to 2.3 to get the numpy 2 compatibility. A special case is being handled on Windows as PyTorch Windows binaries had compatibilities issues with numpy 2 that were fixed in torch 2.4.1 (see https://github.com/pytorch/pytorch/issues/131668#issuecomment-2307447045). Maintainers will need to update the required status checks for the [`dev`](https://github.com/Project-MONAI/MONAI/tree/dev) branch to: - Remove min-dep-pytorch (2.0.1) ### Types of changes - [X] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. --------- Signed-off-by: James Butler --- .github/workflows/cron.yml | 10 +++---- .github/workflows/pythonapp-gpu.yml | 26 ++++++++++--------- .github/workflows/pythonapp-min.yml | 2 +- .github/workflows/pythonapp.yml | 6 ++--- docs/requirements.txt | 4 +-- environment-dev.yml | 4 +-- monai/engines/evaluator.py | 11 +++----- monai/engines/trainer.py | 10 ++----- monai/networks/blocks/crossattention.py | 7 +---- monai/networks/blocks/selfattention.py | 7 +---- monai/networks/blocks/upsample.py | 14 +++------- pyproject.toml | 2 +- requirements.txt | 5 ++-- setup.cfg | 5 ++-- .../test_integration_bundle_run.py | 6 ++--- tests/metrics/test_surface_dice.py | 6 ++--- tests/nonconfig_workflow.py | 2 +- 17 files changed, 48 insertions(+), 79 deletions(-) diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index 2e7921ec94..77fe9ca3a2 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -13,17 +13,13 @@ jobs: strategy: matrix: environment: - - "PT113+CUDA118" - - "PT210+CUDA121" + - "PT230+CUDA121" - "PT240+CUDA126" - "PTLATEST+CUDA126" include: # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes - - environment: PT113+CUDA118 - pytorch: "torch==1.13.1 torchvision==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu121" - base: "nvcr.io/nvidia/pytorch:22.10-py3" # CUDA 11.8 - - environment: PT210+CUDA121 - pytorch: "pytorch==2.1.0 torchvision==0.16.0 --extra-index-url https://download.pytorch.org/whl/cu121" + - environment: PT230+CUDA121 + pytorch: "pytorch==2.3.0 torchvision==0.18.0 --extra-index-url https://download.pytorch.org/whl/cu121" base: "nvcr.io/nvidia/pytorch:23.08-py3" # CUDA 12.1 - environment: PT240+CUDA126 pytorch: "pytorch==2.4.0 torchvision==0.19.0 --extra-index-url https://download.pytorch.org/whl/cu121" diff --git a/.github/workflows/pythonapp-gpu.yml b/.github/workflows/pythonapp-gpu.yml index cd916f2ebb..6b0a5084a2 100644 --- a/.github/workflows/pythonapp-gpu.yml +++ b/.github/workflows/pythonapp-gpu.yml @@ -22,19 +22,21 @@ jobs: strategy: matrix: environment: - - "PT113+CUDA116" - - "PT210+CUDA121DOCKER" + - "PT230+CUDA124DOCKER" + - "PT240+CUDA125DOCKER" + - "PT250+CUDA126DOCKER" include: # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes - - environment: PT113+CUDA116 - pytorch: "torch==1.13.1 torchvision==0.14.1" - base: "nvcr.io/nvidia/cuda:11.6.1-devel-ubuntu18.04" - - environment: PT210+CUDA121DOCKER - # 23.08: 2.1.0a0+29c30b1 + - environment: PT230+CUDA124DOCKER + # 24.04: 2.3.0a0+6ddf5cf85e pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error - base: "nvcr.io/nvidia/pytorch:23.08-py3" - - environment: PT210+CUDA121DOCKER - # 24.08: 2.3.0a0+40ec155e58.nv24.3 + base: "nvcr.io/nvidia/pytorch:24.04-py3" + - environment: PT240+CUDA125DOCKER + # 24.06: 2.4.0a0+f70bd71a48 + pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error + base: "nvcr.io/nvidia/pytorch:24.06-py3" + - environment: PT250+CUDA126DOCKER + # 24.08: 2.5.0a0+872d972e41 pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error base: "nvcr.io/nvidia/pytorch:24.08-py3" container: @@ -49,7 +51,7 @@ jobs: apt-get update apt-get install -y wget - if [ ${{ matrix.environment }} = "PT113+CUDA116" ] + if [ ${{ matrix.environment }} = "PT230+CUDA124" ] then PYVER=3.9 PYSFX=3 DISTUTILS=python3-distutils && \ apt-get update && apt-get install -y --no-install-recommends \ @@ -114,7 +116,7 @@ jobs: # build for the current self-hosted CI Tesla V100 BUILD_MONAI=1 TORCH_CUDA_ARCH_LIST="7.0" ./runtests.sh --build --disttests ./runtests.sh --quick --unittests - if [ ${{ matrix.environment }} = "PT113+CUDA116" ]; then + if [ ${{ matrix.environment }} = "PT230+CUDA124" ]; then # test the clang-format tool downloading once coverage run -m tests.clang_format_utils fi diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml index 19e30f86bb..afc9f6f6d4 100644 --- a/.github/workflows/pythonapp-min.yml +++ b/.github/workflows/pythonapp-min.yml @@ -124,7 +124,7 @@ jobs: strategy: fail-fast: false matrix: - pytorch-version: ['1.13.1', '2.0.1', '2.2.2', '2.3.1', '2.4.1', 'latest'] + pytorch-version: ['2.3.1', '2.4.1', '2.5.1', 'latest'] timeout-minutes: 40 steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index f175cc3f7c..5d6fd06afa 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -94,7 +94,7 @@ jobs: - if: runner.os == 'windows' name: Install torch cpu from pytorch.org (Windows only) run: | - python -m pip install torch==1.13.1+cpu torchvision==0.14.1+cpu -f https://download.pytorch.org/whl/torch_stable.html + python -m pip install torch==2.4.1 torchvision==0.19.1+cpu --index-url https://download.pytorch.org/whl/cpu - if: runner.os == 'Linux' name: Install itk pre-release (Linux only) run: | @@ -103,7 +103,7 @@ jobs: - name: Install the dependencies run: | python -m pip install --user --upgrade pip wheel - python -m pip install torch==1.13.1 torchvision==0.14.1 + python -m pip install torch==2.4.1 torchvision==0.19.1 cat "requirements-dev.txt" python -m pip install -r requirements-dev.txt python -m pip list @@ -155,7 +155,7 @@ jobs: # install the latest pytorch for testing # however, "pip install monai*.tar.gz" will build cpp/cuda with an isolated # fresh torch installation according to pyproject.toml - python -m pip install torch>=1.13.1 torchvision + python -m pip install torch>=2.3.0 torchvision - name: Check packages run: | pip uninstall monai diff --git a/docs/requirements.txt b/docs/requirements.txt index d657580743..b314e10640 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,5 @@ --f https://download.pytorch.org/whl/cpu/torch-1.13.1%2Bcpu-cp39-cp39-linux_x86_64.whl -torch>=1.13.1 +-f https://download.pytorch.org/whl/cpu/torch-2.3.0%2Bcpu-cp39-cp39-linux_x86_64.whl +torch>=2.3.0 pytorch-ignite==0.4.11 numpy>=1.20 itk>=5.2 diff --git a/environment-dev.yml b/environment-dev.yml index 8617a3b9cb..9358cdc83b 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -5,8 +5,8 @@ channels: - nvidia - conda-forge dependencies: - - numpy>=1.24,<2.0 - - pytorch>=1.13.1 + - numpy>=1.24,<3.0 + - pytorch>=2.3.0 - torchio - torchvision - pytorch-cuda>=11.6 diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index d70a39726b..35d4928465 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -28,7 +28,7 @@ from monai.utils import ForwardMode, IgniteInfo, ensure_tuple, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from monai.utils.enums import EngineStatsKeys as ESKeys -from monai.utils.module import look_up_option, pytorch_after +from monai.utils.module import look_up_option if TYPE_CHECKING: from ignite.engine import Engine, EventEnum @@ -269,13 +269,8 @@ def __init__( amp_kwargs=amp_kwargs, ) if compile: - if pytorch_after(2, 1): - compile_kwargs = {} if compile_kwargs is None else compile_kwargs - network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] - else: - warnings.warn( - "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done" - ) + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] self.network = network self.compile = compile self.inferer = SimpleInferer() if inferer is None else inferer diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index a0be86bae5..fdb45fbab8 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -27,7 +27,6 @@ from monai.utils import AdversarialIterationEvents, AdversarialKeys, GanKeys, IgniteInfo, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from monai.utils.enums import EngineStatsKeys as ESKeys -from monai.utils.module import pytorch_after if TYPE_CHECKING: from ignite.engine import Engine, EventEnum @@ -183,13 +182,8 @@ def __init__( amp_kwargs=amp_kwargs, ) if compile: - if pytorch_after(2, 1): - compile_kwargs = {} if compile_kwargs is None else compile_kwargs - network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] - else: - warnings.warn( - "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done" - ) + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] self.network = network self.compile = compile self.optimizer = optimizer diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index bdecf63168..be31d2d8fb 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -17,7 +17,7 @@ import torch.nn as nn from monai.networks.layers.utils import get_rel_pos_embedding_layer -from monai.utils import optional_import, pytorch_after +from monai.utils import optional_import Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -84,11 +84,6 @@ def __init__( if causal and sequence_length is None: raise ValueError("sequence_length is necessary for causal attention.") - if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0): - raise ValueError( - "use_flash_attention is only supported for PyTorch versions >= 2.0." - "Upgrade your PyTorch or set the flag to False." - ) if use_flash_attention and save_attn: raise ValueError( "save_attn has been set to True, but use_flash_attention is also set" diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 86e1b1d3ae..360579f3df 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -18,7 +18,7 @@ import torch.nn.functional as F from monai.networks.layers.utils import get_rel_pos_embedding_layer -from monai.utils import optional_import, pytorch_after +from monai.utils import optional_import Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -90,11 +90,6 @@ def __init__( if causal and sequence_length is None: raise ValueError("sequence_length is necessary for causal attention.") - if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0): - raise ValueError( - "use_flash_attention is only supported for PyTorch versions >= 2.0." - "Upgrade your PyTorch or set the flag to False." - ) if use_flash_attention and save_attn: raise ValueError( "save_attn has been set to True, but use_flash_attention is also set" diff --git a/monai/networks/blocks/upsample.py b/monai/networks/blocks/upsample.py index 50fd39a70b..62908e9825 100644 --- a/monai/networks/blocks/upsample.py +++ b/monai/networks/blocks/upsample.py @@ -17,8 +17,8 @@ import torch.nn as nn from monai.networks.layers.factories import Conv, Pad, Pool -from monai.networks.utils import CastTempType, icnr_init, pixelshuffle -from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option, pytorch_after +from monai.networks.utils import icnr_init, pixelshuffle +from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option __all__ = ["Upsample", "UpSample", "SubpixelUpsample", "Subpixelupsample", "SubpixelUpSample"] @@ -164,15 +164,7 @@ def __init__( align_corners=align_corners, ) - # Cast to float32 as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # https://github.com/pytorch/pytorch/issues/86679. This issue is solved in PyTorch 2.1 - if pytorch_after(major=2, minor=1): - self.add_module("upsample_non_trainable", upsample) - else: - self.add_module( - "upsample_non_trainable", - CastTempType(initial_type=torch.bfloat16, temporary_type=torch.float32, submodule=upsample), - ) + self.add_module("upsample_non_trainable", upsample) if post_conv: self.add_module("postconv", post_conv) elif up_mode == UpsampleMode.PIXELSHUFFLE: diff --git a/pyproject.toml b/pyproject.toml index 8ad55b1c2c..588d6d22d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ requires = [ "wheel", "setuptools", - "torch>=1.13.1", + "torch>=2.3.0", "ninja", "packaging" ] diff --git a/requirements.txt b/requirements.txt index 5203b43128..452a62adda 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ -torch>=1.13.1,<2.6 -numpy>=1.24,<2.0 +torch>=2.3.0,<2.6; sys_platform != 'win32' +torch>=2.4.1,<2.6; sys_platform == 'win32' +numpy>=1.24,<3.0 diff --git a/setup.cfg b/setup.cfg index 66d9e19609..2b06df64de 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,8 +42,9 @@ setup_requires = ninja packaging install_requires = - torch>=1.13.1 - numpy>=1.24,<2.0 + torch>=2.3.0; sys_platform != 'win32' + torch>=2.4.1; sys_platform == 'win32' + numpy>=1.24,<3.0 [options.extras_require] all = diff --git a/tests/integration/test_integration_bundle_run.py b/tests/integration/test_integration_bundle_run.py index cfbbcfe154..7f366d4745 100644 --- a/tests/integration/test_integration_bundle_run.py +++ b/tests/integration/test_integration_bundle_run.py @@ -76,8 +76,7 @@ def test_tiny(self): ) with open(meta_file, "w") as f: json.dump( - {"version": "0.1.0", "monai_version": "1.1.0", "pytorch_version": "1.13.1", "numpy_version": "1.22.2"}, - f, + {"version": "0.1.0", "monai_version": "1.1.0", "pytorch_version": "2.3.0", "numpy_version": "1.22.2"}, f ) cmd = ["coverage", "run", "-m", "monai.bundle"] # test both CLI entry "run" and "run_workflow" @@ -114,8 +113,7 @@ def test_scripts_fold(self): ) with open(meta_file, "w") as f: json.dump( - {"version": "0.1.0", "monai_version": "1.1.0", "pytorch_version": "1.13.1", "numpy_version": "1.22.2"}, - f, + {"version": "0.1.0", "monai_version": "1.1.0", "pytorch_version": "2.3.0", "numpy_version": "1.22.2"}, f ) os.mkdir(scripts_dir) diff --git a/tests/metrics/test_surface_dice.py b/tests/metrics/test_surface_dice.py index 01f80bd01e..a3d03e9937 100644 --- a/tests/metrics/test_surface_dice.py +++ b/tests/metrics/test_surface_dice.py @@ -82,7 +82,7 @@ def test_tolerance_euclidean_distance_with_spacing(self): expected_res0[1, 1] = np.nan for b, c in np.ndindex(batch_size, n_class): np.testing.assert_allclose(expected_res0[b, c], res0[b, c].cpu()) - np.testing.assert_array_equal(agg0.cpu(), np.nanmean(np.nanmean(expected_res0, axis=1), axis=0)) + np.testing.assert_allclose(agg0.cpu(), np.nanmean(np.nanmean(expected_res0, axis=1), axis=0)) np.testing.assert_equal(not_nans.cpu(), torch.tensor(2)) def test_tolerance_euclidean_distance(self): @@ -126,7 +126,7 @@ def test_tolerance_euclidean_distance(self): expected_res0[1, 1] = np.nan for b, c in np.ndindex(batch_size, n_class): np.testing.assert_allclose(expected_res0[b, c], res0[b, c].cpu()) - np.testing.assert_array_equal(agg0.cpu(), np.nanmean(np.nanmean(expected_res0, axis=1), axis=0)) + np.testing.assert_allclose(agg0.cpu(), np.nanmean(np.nanmean(expected_res0, axis=1), axis=0)) np.testing.assert_equal(not_nans.cpu(), torch.tensor(2)) def test_tolerance_euclidean_distance_3d(self): @@ -173,7 +173,7 @@ def test_tolerance_euclidean_distance_3d(self): expected_res0[1, 1] = np.nan for b, c in np.ndindex(batch_size, n_class): np.testing.assert_allclose(expected_res0[b, c], res0[b, c].cpu()) - np.testing.assert_array_equal(agg0.cpu(), np.nanmean(np.nanmean(expected_res0, axis=1), axis=0)) + np.testing.assert_allclose(agg0.cpu(), np.nanmean(np.nanmean(expected_res0, axis=1), axis=0)) np.testing.assert_equal(not_nans.cpu(), torch.tensor(2)) def test_tolerance_all_distances(self): diff --git a/tests/nonconfig_workflow.py b/tests/nonconfig_workflow.py index fcfc5b2951..bcbdc67b71 100644 --- a/tests/nonconfig_workflow.py +++ b/tests/nonconfig_workflow.py @@ -65,7 +65,7 @@ def initialize(self): self._monai_version = "1.1.0" if self._pytorch_version is None: - self._pytorch_version = "1.13.1" + self._pytorch_version = "2.3.0" if self._numpy_version is None: self._numpy_version = "1.22.2"