
Commit

Merge branch 'dev' into 8085-r2-score
thibaultdvx authored Feb 20, 2025
2 parents 10d2423 + d98f348 commit 4445bad
Showing 21 changed files with 685 additions and 43 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -18,12 +18,13 @@

MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE) framework for deep learning in healthcare imaging, part of the [PyTorch Ecosystem](https://pytorch.org/ecosystem/).
Its ambitions are as follows:

- Developing a community of academic, industrial and clinical researchers collaborating on a common foundation;
- Creating state-of-the-art, end-to-end training workflows for healthcare imaging;
- Providing researchers with the optimized and standardized way to create and evaluate deep learning models.


## Features

> _Please see [the technical highlights](https://docs.monai.io/en/latest/highlights.html) and [What's New](https://docs.monai.io/en/latest/whatsnew.html) of the milestone releases._
- flexible pre-processing for multi-dimensional medical imaging data;
6 changes: 6 additions & 0 deletions docs/source/handlers.rst
@@ -53,6 +53,12 @@ ROC AUC metrics handler
     :members:


+Average Precision metric handler
+--------------------------------
+.. autoclass:: AveragePrecision
+    :members:
+
+
 Confusion matrix metrics handler
 --------------------------------
 .. autoclass:: ConfusionMatrix
7 changes: 7 additions & 0 deletions docs/source/metrics.rst
@@ -80,6 +80,13 @@ Metrics
 .. autoclass:: ROCAUCMetric
     :members:

+`Average Precision`
+-------------------
+.. autofunction:: compute_average_precision
+
+.. autoclass:: AveragePrecisionMetric
+    :members:
+
 `Confusion matrix`
 ------------------
 .. autofunction:: get_confusion_matrix
1 change: 1 addition & 0 deletions monai/handlers/__init__.py
@@ -11,6 +11,7 @@

 from __future__ import annotations

+from .average_precision import AveragePrecision
 from .checkpoint_loader import CheckpointLoader
 from .checkpoint_saver import CheckpointSaver
 from .classification_saver import ClassificationSaver
53 changes: 53 additions & 0 deletions monai/handlers/average_precision.py
@@ -0,0 +1,53 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import AveragePrecisionMetric
from monai.utils import Average


class AveragePrecision(IgniteMetricHandler):
    """
    Computes Average Precision (AP) by accumulating predictions and the ground truth during an
    epoch and then applying `compute_average_precision`.

    Args:
        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
            Type of averaging performed if not binary classification. Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
              This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
              weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
              indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.

        output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output`,
            then construct the `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first`
            Tensors or lists of `channel-first` Tensors. The form `(y_pred, y)` is required by
            `update()`. `engine.state` and `output_transform` inherit from the ignite concept:
            https://pytorch.org/ignite/concepts.html#state; an explanation and usage example are in the tutorial:
            https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.

    Note:
        Average Precision expects `y` to consist of 0s and 1s.
        `y_pred` must be either probability estimates or confidence values.

    """

    def __init__(self, average: Average | str = Average.MACRO, output_transform: Callable = lambda x: x) -> None:
        metric_fn = AveragePrecisionMetric(average=Average(average))
        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False)
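A minimal usage sketch for the new handler (editor's illustration, not part of this commit): the engine, data, and the "AP" metric name are invented for the example, and it assumes an ignite installation. Since `IgniteMetricHandler` subclasses ignite's `Metric`, the handler attaches like any other ignite metric.

import torch
from ignite.engine import Engine

from monai.handlers import AveragePrecision

def eval_step(engine, batch):
    # Stand-in for a real model: forward (y_pred, y) pairs to the metric.
    y_pred, y = batch
    return y_pred, y

evaluator = Engine(eval_step)
# Registers the metric under evaluator.state.metrics["AP"].
AveragePrecision(average="macro").attach(evaluator, "AP")

data = [(torch.tensor([0.2, 0.9, 0.6, 0.3]), torch.tensor([0, 1, 1, 0]))]
print(evaluator.run(data).metrics["AP"])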
17 changes: 10 additions & 7 deletions monai/inferers/inferer.py
Original file line number Diff line number Diff line change
@@ -1202,15 +1202,16 @@ def sample(  # type: ignore[override]

         if self.autoencoder_latent_shape is not None:
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
-            latent_intermediates = [
-                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
-            ]
+            if save_intermediates:
+                latent_intermediates = [
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
+                    for l in latent_intermediates
+                ]

         decode = autoencoder_model.decode_stage_2_outputs
         if isinstance(autoencoder_model, SPADEAutoencoderKL):
             decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
         image = decode(latent / self.scale_factor)

         if save_intermediates:
             intermediates = []
             for latent_intermediate in latent_intermediates:
@@ -1727,9 +1728,11 @@ def sample(  # type: ignore[override]

         if self.autoencoder_latent_shape is not None:
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
-            latent_intermediates = [
-                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
-            ]
+            if save_intermediates:
+                latent_intermediates = [
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
+                    for l in latent_intermediates
+                ]

         decode = autoencoder_model.decode_stage_2_outputs
         if isinstance(autoencoder_model, SPADEAutoencoderKL):
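A note on the two hunks above (editor's reading of the diff, not text from the commit): when `sample` is called with `save_intermediates=False`, the upstream sampling call never binds `latent_intermediates`, so resizing it unconditionally failed; the new guard only resizes the intermediates when they exist. A reduced, self-contained sketch of the failure mode, with hypothetical names:

import torch

def resize(t: torch.Tensor) -> torch.Tensor:
    return t  # stand-in for self.autoencoder_resizer

def sample(save_intermediates: bool = False):
    latent = torch.zeros(1, 4, 8, 8)
    if save_intermediates:
        latent_intermediates = [latent.clone()]
    # Before the fix, the resize below ran unconditionally and raised
    # UnboundLocalError whenever save_intermediates was False.
    if save_intermediates:
        latent_intermediates = [resize(t) for t in latent_intermediates]
        return latent, latent_intermediates
    return latent

sample(save_intermediates=False)  # works only because of the guard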
1 change: 1 addition & 0 deletions monai/metrics/__init__.py
@@ -12,6 +12,7 @@
 from __future__ import annotations

 from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score
+from .average_precision import AveragePrecisionMetric, compute_average_precision
 from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
 from .cumulative_average import CumulativeAverage
 from .f_beta_score import FBetaScore
187 changes: 187 additions & 0 deletions monai/metrics/average_precision.py
@@ -0,0 +1,187 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, cast

import numpy as np

if TYPE_CHECKING:
    import numpy.typing as npt

import torch

from monai.utils import Average, look_up_option

from .metric import CumulativeIterationMetric


class AveragePrecisionMetric(CumulativeIterationMetric):
    """
    Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are
    imbalanced. It can take values between 0.0 and 1.0, 1.0 being the best possible score.
    It summarizes a Precision-Recall curve as the weighted mean of precisions achieved at each
    threshold, with the increase in recall from the previous threshold used as the weight:

    .. math::
        :label: ap

        \\text{AP} = \\sum_n (R_n - R_{n-1}) P_n

    where :math:`P_n` and :math:`R_n` are the precision and recall at the :math:`n^{th}` threshold.

    Referring to: `sklearn.metrics.average_precision_score
    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html>`_.

    The input `y_pred` and `y` can be a list of `channel-first` Tensors or a `batch-first` Tensor.

    An example of the typical execution steps of this metric class follows
    :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
            Type of averaging performed if not binary classification.
            Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
              This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
              weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
              indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.

    """

    def __init__(self, average: Average | str = Average.MACRO) -> None:
        super().__init__()
        self.average = average

    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:  # type: ignore[override]
        return y_pred, y

    def aggregate(self, average: Average | str | None = None) -> np.ndarray | float | npt.ArrayLike:
        """
        Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration.
        This function reads the buffers and computes the Average Precision.

        Args:
            average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
                Type of averaging performed if not binary classification. Defaults to `self.average`.

        """
        y_pred, y = self.get_buffer()
        # compute final value and do metric reduction
        if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):
            raise ValueError("y_pred and y must be PyTorch Tensors.")

        return compute_average_precision(y_pred=y_pred, y=y, average=average or self.average)


def _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float:
    if not (y.ndimension() == y_pred.ndimension() == 1 and len(y) == len(y_pred)):
        raise AssertionError("y and y_pred must be 1-dimensional data with the same length.")
    y_unique = y.unique()
    if len(y_unique) == 1:
        warnings.warn(f"y values can not be all {y_unique.item()}, skip AP computation and return `Nan`.")
        return float("nan")
    if not y_unique.equal(torch.tensor([0, 1], dtype=y.dtype, device=y.device)):
        warnings.warn(f"y values must be 0 or 1, but in {y_unique.tolist()}, skip AP computation and return `Nan`.")
        return float("nan")

    n = len(y)
    # sort samples by descending confidence so precision can be accumulated rank by rank
    indices = y_pred.argsort(descending=True)
    y = y[indices].cpu().numpy()  # type: ignore[assignment]
    y_pred = y_pred[indices].cpu().numpy()  # type: ignore[assignment]
    npos = ap = tmp_pos = 0.0

    for i in range(n):
        y_i = cast(float, y[i])
        if i + 1 < n and y_pred[i] == y_pred[i + 1]:
            # tied scores: defer the precision update until the end of the tie group
            tmp_pos += y_i
        else:
            tmp_pos += y_i
            npos += tmp_pos
            # precision at this threshold, weighted by the number of positives gained
            ap += tmp_pos * npos / (i + 1)
            tmp_pos = 0

    return ap / npos


def compute_average_precision(
    y_pred: torch.Tensor, y: torch.Tensor, average: Average | str = Average.MACRO
) -> np.ndarray | float | npt.ArrayLike:
    """Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are
    imbalanced. It summarizes a Precision-Recall curve according to equation :eq:`ap`.
    Referring to: `sklearn.metrics.average_precision_score
    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html>`_.

    Args:
        y_pred: input data to compute, typically classification model output.
            The first dim must be batch; if multi-class, it must be in One-Hot format.
            For example: shape `[16]` or `[16, 1]` for binary data, shape `[16, 2]` for 2-class data.
        y: ground truth to compute the AP metric; the first dim must be batch.
            If multi-class, it must be in One-Hot format.
            For example: shape `[16]` or `[16, 1]` for binary data, shape `[16, 2]` for 2-class data.
        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
            Type of averaging performed if not binary classification.
            Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
              This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
              weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
              indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.

    Raises:
        ValueError: When ``y_pred`` dimension is not one of [1, 2].
        ValueError: When ``y`` dimension is not one of [1, 2].
        ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"].

    Note:
        Average Precision expects `y` to consist of 0s and 1s. `y_pred` must be either probability
        estimates or confidence values.

    """
    y_pred_ndim = y_pred.ndimension()
    y_ndim = y.ndimension()
    if y_pred_ndim not in (1, 2):
        raise ValueError(
            f"Predictions should be of shape (batch_size, num_classes) or (batch_size, ), got {y_pred.shape}."
        )
    if y_ndim not in (1, 2):
        raise ValueError(f"Targets should be of shape (batch_size, num_classes) or (batch_size, ), got {y.shape}.")
    # collapse trailing singleton class dimensions so binary inputs reduce to the 1-D case
    if y_pred_ndim == 2 and y_pred.shape[1] == 1:
        y_pred = y_pred.squeeze(dim=-1)
        y_pred_ndim = 1
    if y_ndim == 2 and y.shape[1] == 1:
        y = y.squeeze(dim=-1)

    if y_pred_ndim == 1:
        return _calculate(y_pred, y)

    if y.shape != y_pred.shape:
        raise ValueError(f"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.")

    average = look_up_option(average, Average)
    if average == Average.MICRO:
        return _calculate(y_pred.flatten(), y.flatten())
    # per-class AP: iterate over the class dimension
    y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1)
    ap_values = [_calculate(y_pred_, y_) for y_pred_, y_ in zip(y_pred, y)]
    if average == Average.NONE:
        return ap_values
    if average == Average.MACRO:
        return np.mean(ap_values)
    if average == Average.WEIGHTED:
        weights = [sum(y_) for y_ in y]
        return np.average(ap_values, weights=weights)  # type: ignore[no-any-return]
    raise ValueError(f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].')
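A quick sanity check of the new metric (editor's illustration; the numbers are not from the commit). Ranking by descending score places the positives at ranks 1 and 3, with precisions 1/1 and 2/3, so AP = (1.0 + 2/3) / 2 ≈ 0.8333, matching `sklearn.metrics.average_precision_score` on the same inputs:

import torch

from monai.metrics import AveragePrecisionMetric, compute_average_precision

y_pred = torch.tensor([0.1, 0.4, 0.35, 0.8])
y = torch.tensor([0, 0, 1, 1])

# Functional form.
print(compute_average_precision(y_pred, y))  # ~0.8333

# Cumulative class form, accumulating across iterations before aggregating.
metric = AveragePrecisionMetric()
metric(y_pred[:2, None], y[:2, None])  # first "batch"
metric(y_pred[2:, None], y[2:, None])  # second "batch"
print(metric.aggregate())  # ~0.8333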
14 changes: 2 additions & 12 deletions monai/transforms/utility/array.py
@@ -66,7 +66,6 @@
     optional_import,
 )
 from monai.utils.enums import TransformBackends
-from monai.utils.misc import is_module_ver_at_least
 from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype

PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
@@ -939,19 +938,10 @@ def __call__(
             data = img[[*select_labels]]
         else:
             where: Callable = np.where if isinstance(img, np.ndarray) else torch.where  # type: ignore
-            if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)):
-                data = where(in1d(img, select_labels), True, False).reshape(img.shape)
-            # pre pytorch 1.8.0, need to use 1/0 instead of True/False
-            else:
-                data = where(
-                    in1d(img, select_labels), torch.tensor(1, device=img.device), torch.tensor(0, device=img.device)
-                ).reshape(img.shape)
+            data = where(in1d(img, select_labels), True, False).reshape(img.shape)

         if merge_channels or self.merge_channels:
-            if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)):
-                return data.any(0)[None]
-            # pre pytorch 1.8.0 compatibility
-            return data.to(torch.uint8).any(0)[None].to(bool)  # type: ignore
+            return data.any(0)[None]

         return data

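For reference, a standalone sketch of the masking pattern this hunk keeps (editor's illustration using plain `torch.isin`, which behaves like MONAI's `in1d` helper here; the tensors are invented):

import torch

img = torch.tensor([[0, 1, 2], [2, 3, 1]])
select_labels = torch.tensor([1, 3])

# Boolean mask of elements whose label is in select_labels, reshaped to the
# image shape -- the torch >= 1.8 branch retained by this commit.
data = torch.where(torch.isin(img, select_labels), True, False).reshape(img.shape)
print(data)
# tensor([[False,  True, False],
#         [False,  True,  True]])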
6 changes: 2 additions & 4 deletions monai/transforms/utils_pytorch_numpy_unification.py
@@ -18,7 +18,6 @@
 import torch

 from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
-from monai.utils.misc import is_module_ver_at_least
 from monai.utils.type_conversion import convert_data_type, convert_to_dst_type

 __all__ = [
@@ -215,10 +214,9 @@ def floor_divide(a: NdarrayOrTensor, b) -> NdarrayOrTensor:
     Element-wise floor division between two arrays/tensors.
     """
     if isinstance(a, torch.Tensor):
-        if is_module_ver_at_least(torch, (1, 8, 0)):
-            return torch.div(a, b, rounding_mode="floor")
-        return torch.floor_divide(a, b)
-    return np.floor_divide(a, b)
+        return torch.div(a, b, rounding_mode="floor")
+    else:
+        return np.floor_divide(a, b)


 def unravel_index(idx, shape) -> NdarrayOrTensor:
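A small check of the simplified branch (editor's illustration): `torch.div` with `rounding_mode="floor"` rounds toward negative infinity, matching NumPy's floor division even for negative operands, which is why the pre-1.8 fallback could be dropped once MONAI's minimum torch version passed 1.8:

import numpy as np
import torch

a = torch.tensor([7, -7])
print(torch.div(a, 2, rounding_mode="floor"))  # tensor([ 3, -4])
print(np.floor_divide(np.array([7, -7]), 2))   # [ 3 -4]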
