Skip to content

Commit

Permalink
Removed outdated torch version checks from transform functions (#8359)
Browse files Browse the repository at this point in the history
Fixes #8348

### Description

Support for `torch` versions prior to `1.13` has been dropped, so those
`1.8` version checks are not required anymore. Furthermore, as reported
in the issue description, those checks led to unstable behaviour when
using certain transforms in data pipelines.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Nicolas Kaenzig <[email protected]>
  • Loading branch information
nkaenzig authored Feb 19, 2025
1 parent 960c59b commit fb8c5bf
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 16 deletions.
14 changes: 2 additions & 12 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
optional_import,
)
from monai.utils.enums import TransformBackends
from monai.utils.misc import is_module_ver_at_least
from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype

PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
Expand Down Expand Up @@ -939,19 +938,10 @@ def __call__(
data = img[[*select_labels]]
else:
where: Callable = np.where if isinstance(img, np.ndarray) else torch.where # type: ignore
if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)):
data = where(in1d(img, select_labels), True, False).reshape(img.shape)
# pre pytorch 1.8.0, need to use 1/0 instead of True/False
else:
data = where(
in1d(img, select_labels), torch.tensor(1, device=img.device), torch.tensor(0, device=img.device)
).reshape(img.shape)
data = where(in1d(img, select_labels), True, False).reshape(img.shape)

if merge_channels or self.merge_channels:
if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)):
return data.any(0)[None]
# pre pytorch 1.8.0 compatibility
return data.to(torch.uint8).any(0)[None].to(bool) # type: ignore
return data.any(0)[None]

return data

Expand Down
6 changes: 2 additions & 4 deletions monai/transforms/utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import torch

from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
from monai.utils.misc import is_module_ver_at_least
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type

__all__ = [
Expand Down Expand Up @@ -215,10 +214,9 @@ def floor_divide(a: NdarrayOrTensor, b) -> NdarrayOrTensor:
Element-wise floor division between two arrays/tensors.
"""
if isinstance(a, torch.Tensor):
if is_module_ver_at_least(torch, (1, 8, 0)):
return torch.div(a, b, rounding_mode="floor")
return torch.floor_divide(a, b)
return np.floor_divide(a, b)
else:
return np.floor_divide(a, b)


def unravel_index(idx, shape) -> NdarrayOrTensor:
Expand Down

0 comments on commit fb8c5bf

Please sign in to comment.