diff --git a/README.md b/README.md
index e5607ccb02..69cd1c657f 100644
--- a/README.md
+++ b/README.md
@@ -18,12 +18,13 @@
 MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE) framework for deep learning in healthcare imaging, part of the [PyTorch Ecosystem](https://pytorch.org/ecosystem/).
 Its ambitions are as follows:
+
 - Developing a community of academic, industrial and clinical researchers collaborating on a common foundation;
 - Creating state-of-the-art, end-to-end training workflows for healthcare imaging;
 - Providing researchers with the optimized and standardized way to create and evaluate deep learning models.
-
 ## Features
+
 > _Please see [the technical highlights](https://docs.monai.io/en/latest/highlights.html) and [What's New](https://docs.monai.io/en/latest/whatsnew.html) of the milestone releases._
 
 - flexible pre-processing for multi-dimensional medical imaging data;
diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst
index b48869d01e..729c86c34f 100644
--- a/docs/source/handlers.rst
+++ b/docs/source/handlers.rst
@@ -53,6 +53,12 @@ ROC AUC metrics handler
     :members:
 
+Average Precision metric handler
+--------------------------------
+.. autoclass:: AveragePrecision
+    :members:
+
+
 Confusion matrix metrics handler
 --------------------------------
 .. autoclass:: ConfusionMatrix
diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index 751c624405..b2bc2f114d 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -80,6 +80,13 @@ Metrics
 .. autoclass:: ROCAUCMetric
     :members:
 
+`Average Precision`
+-------------------
+.. autofunction:: compute_average_precision
+
+.. autoclass:: AveragePrecisionMetric
+    :members:
+
 `Confusion matrix`
 ------------------
 .. autofunction:: get_confusion_matrix
diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py
index fed8504722..39565c0903 100644
--- a/monai/handlers/__init__.py
+++ b/monai/handlers/__init__.py
@@ -11,6 +11,7 @@
 
 from __future__ import annotations
 
+from .average_precision import AveragePrecision
 from .checkpoint_loader import CheckpointLoader
 from .checkpoint_saver import CheckpointSaver
 from .classification_saver import ClassificationSaver
diff --git a/monai/handlers/average_precision.py b/monai/handlers/average_precision.py
new file mode 100644
index 0000000000..608d7eea72
--- /dev/null
+++ b/monai/handlers/average_precision.py
@@ -0,0 +1,53 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Callable
+
+from monai.handlers.ignite_metric import IgniteMetricHandler
+from monai.metrics import AveragePrecisionMetric
+from monai.utils import Average
+
+
+class AveragePrecision(IgniteMetricHandler):
+    """
+    Computes Average Precision (AP), accumulating predictions and the ground truth during an epoch
+    and applying `compute_average_precision`.
+
+    Args:
+        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
+            Type of averaging performed if not binary classification. Defaults to ``"macro"``.
+
+            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
+                This does not take label imbalance into account.
+            - ``"weighted"``: calculate metrics for each label, and find their average,
+                weighted by support (the number of true instances for each label).
+            - ``"micro"``: calculate metrics globally by considering each element of the label
+                indicator matrix as a label.
+            - ``"none"``: the scores for each class are returned.
+
+        output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` and
+            construct the `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
+            lists of `channel-first` Tensors. The form `(y_pred, y)` is required by `update()`.
+            `engine.state` and `output_transform` inherit from the ignite concept:
+            https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
+            https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
+
+    Note:
+        Average Precision expects `y` to be composed of 0's and 1's.
+        `y_pred` must be either probability estimates or confidence values.
+
+    """
+
+    def __init__(self, average: Average | str = Average.MACRO, output_transform: Callable = lambda x: x) -> None:
+        metric_fn = AveragePrecisionMetric(average=Average(average))
+        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False)
diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py
index 769b6cc0e7..7083373859 100644
--- a/monai/inferers/inferer.py
+++ b/monai/inferers/inferer.py
@@ -1202,15 +1202,16 @@ def sample(  # type: ignore[override]
 
         if self.autoencoder_latent_shape is not None:
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
-            latent_intermediates = [
-                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
-            ]
+            if save_intermediates:
+                latent_intermediates = [
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
+                    for l in latent_intermediates
+                ]
 
         decode = autoencoder_model.decode_stage_2_outputs
         if isinstance(autoencoder_model, SPADEAutoencoderKL):
             decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
         image = decode(latent / self.scale_factor)
-
         if save_intermediates:
             intermediates = []
             for latent_intermediate in latent_intermediates:
@@ -1727,9 +1728,11 @@ def sample(  # type: ignore[override]
 
         if self.autoencoder_latent_shape is not None:
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
-            latent_intermediates = [
-                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
-            ]
+            if save_intermediates:
+                latent_intermediates = [
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
+                    for l in latent_intermediates
+                ]
 
         decode = autoencoder_model.decode_stage_2_outputs
         if isinstance(autoencoder_model, SPADEAutoencoderKL):
diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py
index db0de24eb0..6368467c76 100644
--- a/monai/metrics/__init__.py
+++ b/monai/metrics/__init__.py
@@ -12,6 +12,7 @@
 from __future__ import annotations
 
 from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score
+from .average_precision import AveragePrecisionMetric, compute_average_precision
 from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
 from .cumulative_average import CumulativeAverage
 from .f_beta_score import FBetaScore
diff --git a/monai/metrics/average_precision.py b/monai/metrics/average_precision.py
new file mode 100644
index 0000000000..53c41aeca5
--- /dev/null
+++ b/monai/metrics/average_precision.py
@@ -0,0 +1,187 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import warnings
+from typing import TYPE_CHECKING, cast
+
+import numpy as np
+
+if TYPE_CHECKING:
+    import numpy.typing as npt
+
+import torch
+
+from monai.utils import Average, look_up_option
+
+from .metric import CumulativeIterationMetric
+
+
+class AveragePrecisionMetric(CumulativeIterationMetric):
+    """
+    Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are
+    imbalanced. It can take values between 0.0 and 1.0, with 1.0 being the best possible score.
+    It summarizes a Precision-Recall curve as the weighted mean of precisions achieved at each
+    threshold, with the increase in recall from the previous threshold used as the weight:
+
+    .. math::
+        \\text{AP} = \\sum_n (R_n - R_{n-1}) P_n
+        :label: ap
+
+    where :math:`P_n` and :math:`R_n` are the precision and recall at the :math:`n^{th}` threshold.
+
+    Referring to: `sklearn.metrics.average_precision_score
+    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html>`_.
+
+    The input `y_pred` and `y` can be a list of `channel-first` Tensors or a `batch-first` Tensor.
+
+    An example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
+
+    Args:
+        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
+            Type of averaging performed if not binary classification.
+            Defaults to ``"macro"``.
+
+            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
+                This does not take label imbalance into account.
+            - ``"weighted"``: calculate metrics for each label, and find their average,
+                weighted by support (the number of true instances for each label).
+            - ``"micro"``: calculate metrics globally by considering each element of the label
+                indicator matrix as a label.
+            - ``"none"``: the scores for each class are returned.
+
+    """
+
+    def __init__(self, average: Average | str = Average.MACRO) -> None:
+        super().__init__()
+        self.average = average
+
+    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:  # type: ignore[override]
+        return y_pred, y
+
+    def aggregate(self, average: Average | str | None = None) -> np.ndarray | float | npt.ArrayLike:
+        """
+        Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration;
+        this function reads the buffers and computes the Average Precision.
+
+        Args:
+            average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
+                Type of averaging performed if not binary classification. Defaults to `self.average`.
+
+        """
+        y_pred, y = self.get_buffer()
+        # compute final value and do metric reduction
+        if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):
+            raise ValueError("y_pred and y must be PyTorch Tensors.")
+
+        return compute_average_precision(y_pred=y_pred, y=y, average=average or self.average)
+
+
+def _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float:
+    if not (y.ndimension() == y_pred.ndimension() == 1 and len(y) == len(y_pred)):
+        raise AssertionError("y and y_pred must be one-dimensional and of the same length.")
+    y_unique = y.unique()
+    if len(y_unique) == 1:
+        warnings.warn(f"y values cannot all be {y_unique.item()}, skip AP computation and return `NaN`.")
+        return float("nan")
+    if not y_unique.equal(torch.tensor([0, 1], dtype=y.dtype, device=y.device)):
+        warnings.warn(f"y values must be 0 or 1, but found {y_unique.tolist()}, skip AP computation and return `NaN`.")
+        return float("nan")
+
+    n = len(y)
+    indices = y_pred.argsort(descending=True)
+    y = y[indices].cpu().numpy()  # type: ignore[assignment]
+    y_pred = y_pred[indices].cpu().numpy()  # type: ignore[assignment]
+    npos = ap = tmp_pos = 0.0
+
+    # walk the predictions in order of decreasing confidence; positives within a run of
+    # tied scores are accumulated in `tmp_pos` so that all ties are resolved at one threshold
+    for i in range(n):
+        y_i = cast(float, y[i])
+        if i + 1 < n and y_pred[i] == y_pred[i + 1]:
+            tmp_pos += y_i
+        else:
+            tmp_pos += y_i
+            npos += tmp_pos  # total positives retrieved up to this threshold
+            ap += tmp_pos * npos / (i + 1)  # precision at this cutoff, weighted by the recall gained
+            tmp_pos = 0
+
+    return ap / npos
+
+
+def compute_average_precision(
+    y_pred: torch.Tensor, y: torch.Tensor, average: Average | str = Average.MACRO
+) -> np.ndarray | float | npt.ArrayLike:
+    """Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are
+    imbalanced. It summarizes a Precision-Recall curve according to equation :eq:`ap`.
+    Referring to: `sklearn.metrics.average_precision_score
+    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html>`_.
+
+    Args:
+        y_pred: input data to compute, typically a classification model output.
+            The first dim must be batch; if multi-class, it must be in one-hot format.
+            For example: shape `[16]` or `[16, 1]` for binary data, shape `[16, 2]` for two-class data.
+        y: ground truth to compute AP metric, the first dim must be batch.
+            If multi-class, it must be in one-hot format.
+            For example: shape `[16]` or `[16, 1]` for binary data, shape `[16, 2]` for two-class data.
+        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
+            Type of averaging performed if not binary classification.
+            Defaults to ``"macro"``.
+
+            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
+                This does not take label imbalance into account.
+            - ``"weighted"``: calculate metrics for each label, and find their average,
+                weighted by support (the number of true instances for each label).
+            - ``"micro"``: calculate metrics globally by considering each element of the label
+                indicator matrix as a label.
+            - ``"none"``: the scores for each class are returned.
+
+    Raises:
+        ValueError: When ``y_pred`` dimension is not one of [1, 2].
+        ValueError: When ``y`` dimension is not one of [1, 2].
+        ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"].
+
+    Note:
+        Average Precision expects `y` to be composed of 0's and 1's.
+        `y_pred` must be either probability estimates or confidence values.
+
+    """
+    y_pred_ndim = y_pred.ndimension()
+    y_ndim = y.ndimension()
+    if y_pred_ndim not in (1, 2):
+        raise ValueError(
+            f"Predictions should be of shape (batch_size, num_classes) or (batch_size, ), got {y_pred.shape}."
+ ) + if y_ndim not in (1, 2): + raise ValueError(f"Targets should be of shape (batch_size, num_classes) or (batch_size, ), got {y.shape}.") + if y_pred_ndim == 2 and y_pred.shape[1] == 1: + y_pred = y_pred.squeeze(dim=-1) + y_pred_ndim = 1 + if y_ndim == 2 and y.shape[1] == 1: + y = y.squeeze(dim=-1) + + if y_pred_ndim == 1: + return _calculate(y_pred, y) + + if y.shape != y_pred.shape: + raise ValueError(f"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.") + + average = look_up_option(average, Average) + if average == Average.MICRO: + return _calculate(y_pred.flatten(), y.flatten()) + y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1) + ap_values = [_calculate(y_pred_, y_) for y_pred_, y_ in zip(y_pred, y)] + if average == Average.NONE: + return ap_values + if average == Average.MACRO: + return np.mean(ap_values) + if average == Average.WEIGHTED: + weights = [sum(y_) for y_ in y] + return np.average(ap_values, weights=weights) # type: ignore[no-any-return] + raise ValueError(f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].') diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 2963c8a2f8..8491e4739c 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -66,7 +66,6 @@ optional_import, ) from monai.utils.enums import TransformBackends -from monai.utils.misc import is_module_ver_at_least from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype PILImageImage, has_pil = optional_import("PIL.Image", name="Image") @@ -939,19 +938,10 @@ def __call__( data = img[[*select_labels]] else: where: Callable = np.where if isinstance(img, np.ndarray) else torch.where # type: ignore - if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): - data = where(in1d(img, select_labels), True, False).reshape(img.shape) - # pre pytorch 1.8.0, need to use 1/0 instead of True/False - else: - data = where( - in1d(img, select_labels), torch.tensor(1, device=img.device), torch.tensor(0, device=img.device) - ).reshape(img.shape) + data = where(in1d(img, select_labels), True, False).reshape(img.shape) if merge_channels or self.merge_channels: - if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): - return data.any(0)[None] - # pre pytorch 1.8.0 compatibility - return data.to(torch.uint8).any(0)[None].to(bool) # type: ignore + return data.any(0)[None] return data diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 365bd1eab5..8f22d00674 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -18,7 +18,6 @@ import torch from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor -from monai.utils.misc import is_module_ver_at_least from monai.utils.type_conversion import convert_data_type, convert_to_dst_type __all__ = [ @@ -215,10 +214,9 @@ def floor_divide(a: NdarrayOrTensor, b) -> NdarrayOrTensor: Element-wise floor division between two arrays/tensors. 
""" if isinstance(a, torch.Tensor): - if is_module_ver_at_least(torch, (1, 8, 0)): - return torch.div(a, b, rounding_mode="floor") return torch.floor_divide(a, b) - return np.floor_divide(a, b) + else: + return np.floor_divide(a, b) def unravel_index(idx, shape) -> NdarrayOrTensor: diff --git a/monai/utils/enums.py b/monai/utils/enums.py index df8722d03a..793f32a16f 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -214,7 +214,8 @@ class GridSamplePadMode(StrEnum): class Average(StrEnum): """ - See also: :py:class:`monai.metrics.rocauc.compute_roc_auc` + See also: :py:class:`monai.metrics.rocauc.compute_roc_auc` or + :py:class:`monai.metrics.average_precision.compute_average_precision` """ MACRO = "macro" @@ -346,7 +347,7 @@ class CommonKeys(StrEnum): `LABEL` is the training or evaluation label of segmentation or classification task. `PRED` is the prediction data of model output. `LOSS` is the loss value of current iteration. - `INFO` is some useful information during training or evaluation, like loss value, etc. + `METADATA` is some useful information during training or evaluation, like loss value, etc. """ diff --git a/monai/utils/module.py b/monai/utils/module.py index d3f2ff09f2..7bbbb4ab1e 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -540,11 +540,11 @@ def version_leq(lhs: str, rhs: str) -> bool: """ lhs, rhs = str(lhs), str(rhs) - pkging, has_ver = optional_import("packaging.Version") + pkging, has_ver = optional_import("packaging.version") if has_ver: try: - return cast(bool, pkging.version.Version(lhs) <= pkging.version.Version(rhs)) - except pkging.version.InvalidVersion: + return cast(bool, pkging.Version(lhs) <= pkging.Version(rhs)) + except pkging.InvalidVersion: return True lhs_, rhs_ = parse_version_strs(lhs, rhs) @@ -567,12 +567,12 @@ def version_geq(lhs: str, rhs: str) -> bool: """ lhs, rhs = str(lhs), str(rhs) - pkging, has_ver = optional_import("packaging.Version") + pkging, has_ver = optional_import("packaging.version") if has_ver: try: - return cast(bool, pkging.version.Version(lhs) >= pkging.version.Version(rhs)) - except pkging.version.InvalidVersion: + return cast(bool, pkging.Version(lhs) >= pkging.Version(rhs)) + except pkging.InvalidVersion: return True lhs_, rhs_ = parse_version_strs(lhs, rhs) diff --git a/tests/bundle/test_bundle_trt_export.py b/tests/bundle/test_bundle_trt_export.py index a7c570438d..5168fcfdb5 100644 --- a/tests/bundle/test_bundle_trt_export.py +++ b/tests/bundle/test_bundle_trt_export.py @@ -70,7 +70,7 @@ def tearDown(self): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) @unittest.skipUnless(has_torchtrt and has_tensorrt, "Torch-TensorRT is required for conversion!") def test_trt_export(self, convert_precision, input_shape, dynamic_batch): - tests_dir = Path(__file__).resolve().parent + tests_dir = Path(__file__).resolve().parents[1] meta_file = os.path.join(tests_dir, "testing_data", "metadata.json") config_file = os.path.join(tests_dir, "testing_data", "inference.json") with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/handlers/test_handler_average_precision.py b/tests/handlers/test_handler_average_precision.py new file mode 100644 index 0000000000..7f52a5ee9c --- /dev/null +++ b/tests/handlers/test_handler_average_precision.py @@ -0,0 +1,79 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +import torch.distributed as dist + +from monai.handlers import AveragePrecision +from monai.transforms import Activations, AsDiscrete +from tests.test_utils import DistCall, DistTestCase + + +class TestHandlerAveragePrecision(unittest.TestCase): + + def test_compute(self): + ap_metric = AveragePrecision() + act = Activations(softmax=True) + to_onehot = AsDiscrete(to_onehot=2) + + y_pred = [torch.Tensor([0.1, 0.9]), torch.Tensor([0.3, 1.4])] + y = [torch.Tensor([0]), torch.Tensor([1])] + y_pred = [act(p) for p in y_pred] + y = [to_onehot(y_) for y_ in y] + ap_metric.update([y_pred, y]) + + y_pred = [torch.Tensor([0.2, 0.1]), torch.Tensor([0.1, 0.5])] + y = [torch.Tensor([0]), torch.Tensor([1])] + y_pred = [act(p) for p in y_pred] + y = [to_onehot(y_) for y_ in y] + + ap_metric.update([y_pred, y]) + + ap = ap_metric.compute() + np.testing.assert_allclose(0.8333333, ap) + + +class DistributedAveragePrecision(DistTestCase): + + @DistCall(nnodes=1, nproc_per_node=2, node_rank=0) + def test_compute(self): + ap_metric = AveragePrecision() + act = Activations(softmax=True) + to_onehot = AsDiscrete(to_onehot=2) + + device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu" + if dist.get_rank() == 0: + y_pred = [torch.tensor([0.1, 0.9], device=device), torch.tensor([0.3, 1.4], device=device)] + y = [torch.tensor([0], device=device), torch.tensor([1], device=device)] + + if dist.get_rank() == 1: + y_pred = [ + torch.tensor([0.2, 0.1], device=device), + torch.tensor([0.1, 0.5], device=device), + torch.tensor([0.3, 0.4], device=device), + ] + y = [torch.tensor([0], device=device), torch.tensor([1], device=device), torch.tensor([1], device=device)] + + y_pred = [act(p) for p in y_pred] + y = [to_onehot(y_) for y_ in y] + ap_metric.update([y_pred, y]) + + result = ap_metric.compute() + np.testing.assert_allclose(0.7778, result, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/inferers/test_controlnet_inferers.py b/tests/inferers/test_controlnet_inferers.py index e3b0aeb5a2..2ab5cec335 100644 --- a/tests/inferers/test_controlnet_inferers.py +++ b/tests/inferers/test_controlnet_inferers.py @@ -722,7 +722,7 @@ def test_prediction_shape( @parameterized.expand(LATENT_CNDM_TEST_CASES) @skipUnless(has_einops, "Requires einops") - def test_sample_shape( + def test_pred_shape( self, ae_model_type, autoencoder_params, @@ -1165,7 +1165,7 @@ def test_sample_shape_conditioned_concat( @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES) @skipUnless(has_einops, "Requires einops") - def test_sample_shape_different_latents( + def test_shape_different_latents( self, ae_model_type, autoencoder_params, @@ -1242,6 +1242,84 @@ def test_sample_shape_different_latents( ) self.assertEqual(prediction.shape, latent_shape) + @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES) + @skipUnless(has_einops, "Requires einops") + def test_sample_shape_different_latents( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + 
latent_shape, + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + # We infer the VAE shape + if ae_model_type == "VQVAE": + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]] + else: + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]] + + inferer = ControlNetLatentDiffusionInferer( + scheduler=scheduler, + scale_factor=1.0, + ldm_latent_shape=list(latent_shape[2:]), + autoencoder_latent_shape=autoencoder_latent_shape, + ) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction, _ = inferer.sample( + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + input_noise=noise, + seg=input_seg, + save_intermediates=True, + ) + else: + prediction = inferer.sample( + autoencoder_model=stage_1, + diffusion_model=stage_2, + input_noise=noise, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=False, + ) + self.assertEqual(prediction.shape, input_shape) + @skipUnless(has_einops, "Requires einops") def test_incompatible_spade_setup(self): stage_1 = SPADEAutoencoderKL( diff --git a/tests/inferers/test_latent_diffusion_inferer.py b/tests/inferers/test_latent_diffusion_inferer.py index 2e04ad6c5c..4f81b96ca1 100644 --- a/tests/inferers/test_latent_diffusion_inferer.py +++ b/tests/inferers/test_latent_diffusion_inferer.py @@ -714,7 +714,7 @@ def test_sample_shape_conditioned_concat( @parameterized.expand(TEST_CASES_DIFF_SHAPES) @skipUnless(has_einops, "Requires einops") - def test_sample_shape_different_latents( + def test_shape_different_latents( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): stage_1 = None @@ -772,6 +772,66 @@ def test_sample_shape_different_latents( ) self.assertEqual(prediction.shape, latent_shape) + @parameterized.expand(TEST_CASES_DIFF_SHAPES) + @skipUnless(has_einops, "Requires einops") + def test_sample_shape_different_latents( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + stage_1 = None + + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = 
SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + # We infer the VAE shape + if ae_model_type == "VQVAE": + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]] + else: + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]] + + inferer = LatentDiffusionInferer( + scheduler=scheduler, + scale_factor=1.0, + ldm_latent_shape=list(latent_shape[2:]), + autoencoder_latent_shape=autoencoder_latent_shape, + ) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction, _ = inferer.sample( + autoencoder_model=stage_1, + diffusion_model=stage_2, + input_noise=noise, + save_intermediates=True, + seg=input_seg, + ) + else: + prediction = inferer.sample( + autoencoder_model=stage_1, diffusion_model=stage_2, input_noise=noise, save_intermediates=False + ) + self.assertEqual(prediction.shape, input_shape) + @skipUnless(has_einops, "Requires einops") def test_incompatible_spade_setup(self): stage_1 = SPADEAutoencoderKL( diff --git a/tests/metrics/test_compute_average_precision.py b/tests/metrics/test_compute_average_precision.py new file mode 100644 index 0000000000..819bb61a42 --- /dev/null +++ b/tests/metrics/test_compute_average_precision.py @@ -0,0 +1,162 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data import decollate_batch +from monai.metrics import AveragePrecisionMetric, compute_average_precision +from monai.transforms import Activations, AsDiscrete, Compose, ToTensor + +_device = "cuda:0" if torch.cuda.is_available() else "cpu" +TEST_CASE_1 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]], device=_device), + torch.tensor([[0], [0], [1], [1]], device=_device), + True, + 2, + "macro", + 0.41667, +] + +TEST_CASE_2 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]], device=_device), + torch.tensor([[1], [1], [0], [0]], device=_device), + True, + 2, + "micro", + 0.85417, +] + +TEST_CASE_3 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]], device=_device), + torch.tensor([[0], [1], [0], [1]], device=_device), + True, + 2, + "macro", + 0.83333, +] + +TEST_CASE_4 = [ + torch.tensor([[0.5], [0.5], [0.2], [8.3]]), + torch.tensor([[0], [1], [0], [1]]), + False, + None, + "macro", + 0.83333, +] + +TEST_CASE_5 = [torch.tensor([[0.5], [0.5], [0.2], [8.3]]), torch.tensor([0, 1, 0, 1]), False, None, "macro", 0.83333] + +TEST_CASE_6 = [torch.tensor([0.5, 0.5, 0.2, 8.3]), torch.tensor([0, 1, 0, 1]), False, None, "macro", 0.83333] + +TEST_CASE_7 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), + torch.tensor([[0], [1], [0], [1]]), + True, + 2, + "none", + [0.83333, 0.83333], +] + +TEST_CASE_8 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]), + torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]), + True, + None, + "weighted", + 0.66667, +] + +TEST_CASE_9 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]), + torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]), + True, + None, + "micro", + 0.71111, +] + +TEST_CASE_10 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), + torch.tensor([[0], [0], [0], [0]]), + True, + 2, + "macro", + float("nan"), +] + +TEST_CASE_11 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), + torch.tensor([[1], [1], [1], [1]]), + True, + 2, + "macro", + float("nan"), +] + +TEST_CASE_12 = [ + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), + torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3]]), + True, + None, + "macro", + float("nan"), +] + +ALL_TESTS = [ + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, + TEST_CASE_10, + TEST_CASE_11, + TEST_CASE_12, +] + + +class TestComputeAveragePrecision(unittest.TestCase): + + @parameterized.expand(ALL_TESTS) + def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value): + y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)]) + y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)]) + y_pred = torch.stack([y_pred_trans(i) for i in decollate_batch(y_pred)], dim=0) + y = torch.stack([y_trans(i) for i in decollate_batch(y)], dim=0) + result = compute_average_precision(y_pred=y_pred, y=y, average=average) + np.testing.assert_allclose(expected_value, result, rtol=1e-5) + + @parameterized.expand(ALL_TESTS) + def test_class_value(self, y_pred, y, softmax, to_onehot, average, expected_value): + y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)]) + y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)]) + y_pred = [y_pred_trans(i) for i in decollate_batch(y_pred)] + y = 
[y_trans(i) for i in decollate_batch(y)] + metric = AveragePrecisionMetric(average=average) + metric(y_pred=y_pred, y=y) + result = metric.aggregate() + np.testing.assert_allclose(expected_value, result, rtol=1e-5) + result = metric.aggregate(average=average) # test optional argument + metric.reset() + np.testing.assert_allclose(expected_value, result, rtol=1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/min_tests.py b/tests/min_tests.py index 049c82d4c2..6e70bb77c0 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -76,6 +76,7 @@ def run_testsuit(): "test_grid_patch", "test_gmm", "test_handler_metrics_reloaded", + "test_handler_average_precision", "test_handler_checkpoint_loader", "test_handler_checkpoint_saver", "test_handler_classification_saver", diff --git a/tests/networks/test_convert_to_onnx.py b/tests/networks/test_convert_to_onnx.py index cfc356d5a4..106f15dc9d 100644 --- a/tests/networks/test_convert_to_onnx.py +++ b/tests/networks/test_convert_to_onnx.py @@ -64,7 +64,7 @@ def test_unet(self, device, use_trace, use_ort): rtol=rtol, atol=atol, ) - self.assertTrue(isinstance(onnx_model, onnx.ModelProto)) + self.assertTrue(isinstance(onnx_model, onnx.ModelProto)) @parameterized.expand(TESTS_ORT) @SkipIfBeforePyTorchVersion((1, 12)) diff --git a/tests/test_utils.py b/tests/test_utils.py index c494bb547c..97a3181c44 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -30,9 +30,10 @@ import warnings from contextlib import contextmanager from functools import partial, reduce +from itertools import product from pathlib import Path from subprocess import PIPE, Popen -from typing import Callable +from typing import Callable, Literal from urllib.error import ContentTooShortError, HTTPError import numpy as np @@ -862,6 +863,21 @@ def equal_state_dict(st_1, st_2): if torch.cuda.is_available(): TEST_DEVICES.append([torch.device("cuda")]) + +def dict_product(trailing=False, format: Literal["list", "dict"] = "dict", **items): + keys = items.keys() + values = items.values() + for pvalues in product(*values): + dict_comb = dict(zip(keys, pvalues)) + if format == "dict": + if trailing: + yield [dict_comb] + list(pvalues) + else: + yield dict_comb + else: + yield pvalues + + if __name__ == "__main__": parser = argparse.ArgumentParser(prog="util") parser.add_argument("-c", "--count", default=2, help="max number of gpus") diff --git a/tests/transforms/test_gibbs_noise.py b/tests/transforms/test_gibbs_noise.py index 2aa2a44d10..1f96595a26 100644 --- a/tests/transforms/test_gibbs_noise.py +++ b/tests/transforms/test_gibbs_noise.py @@ -21,14 +21,12 @@ from monai.transforms import GibbsNoise from monai.utils.misc import set_determinism from monai.utils.module import optional_import -from tests.test_utils import TEST_NDARRAYS, assert_allclose +from tests.test_utils import TEST_NDARRAYS, assert_allclose, dict_product _, has_torch_fft = optional_import("torch.fft", name="fftshift") -TEST_CASES = [] -for shape in ((128, 64), (64, 48, 80)): - for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: - TEST_CASES.append((shape, input_type)) +params = {"shape": ((128, 64), (64, 48, 80)), "input_type": TEST_NDARRAYS if has_torch_fft else [np.array]} +TEST_CASES = list(dict_product(format="list", **params)) class TestGibbsNoise(unittest.TestCase):
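---

Usage sketch for the `AveragePrecisionMetric` and `compute_average_precision` APIs introduced in this diff. This is illustrative only and assumes a MONAI build with this patch applied; the tensor values below are made up for the example.

    import torch

    from monai.metrics import AveragePrecisionMetric, compute_average_precision

    # functional API: 1-D confidence scores with binary 0/1 ground truth
    scores = torch.tensor([0.1, 0.8, 0.35, 0.7])
    labels = torch.tensor([0, 1, 0, 1])
    print(compute_average_precision(y_pred=scores, y=labels))  # a single float

    # cumulative API: batch-first, one-hot inputs, accumulated across iterations
    metric = AveragePrecisionMetric(average="macro")
    metric(y_pred=torch.tensor([[0.9, 0.1], [0.2, 0.8]]), y=torch.tensor([[1, 0], [0, 1]]))
    metric(y_pred=torch.tensor([[0.6, 0.4], [0.3, 0.7]]), y=torch.tensor([[1, 0], [0, 1]]))
    print(metric.aggregate())  # macro-averaged AP over the buffered batches
    metric.reset()

The `AveragePrecision` ignite handler wraps the same cumulative metric, so in an ignite workflow only the `output_transform` mapping of `engine.state.output` to the `(y_pred, y)` pair needs to be supplied.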