Skip to content

Commit

Permalink
Merge branch 'dev' into fix-commonkeys-docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
KumoLiu authored Feb 20, 2025
2 parents ac54813 + fb8c5bf commit 2376600
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 16 deletions.
14 changes: 2 additions & 12 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
optional_import,
)
from monai.utils.enums import TransformBackends
from monai.utils.misc import is_module_ver_at_least
from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype

PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
Expand Down Expand Up @@ -939,19 +938,10 @@ def __call__(
data = img[[*select_labels]]
else:
where: Callable = np.where if isinstance(img, np.ndarray) else torch.where # type: ignore
if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)):
data = where(in1d(img, select_labels), True, False).reshape(img.shape)
# pre pytorch 1.8.0, need to use 1/0 instead of True/False
else:
data = where(
in1d(img, select_labels), torch.tensor(1, device=img.device), torch.tensor(0, device=img.device)
).reshape(img.shape)
data = where(in1d(img, select_labels), True, False).reshape(img.shape)

if merge_channels or self.merge_channels:
if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)):
return data.any(0)[None]
# pre pytorch 1.8.0 compatibility
return data.to(torch.uint8).any(0)[None].to(bool) # type: ignore
return data.any(0)[None]

return data

Expand Down
6 changes: 2 additions & 4 deletions monai/transforms/utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import torch

from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
from monai.utils.misc import is_module_ver_at_least
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type

__all__ = [
Expand Down Expand Up @@ -215,10 +214,9 @@ def floor_divide(a: NdarrayOrTensor, b) -> NdarrayOrTensor:
Element-wise floor division between two arrays/tensors.
"""
if isinstance(a, torch.Tensor):
if is_module_ver_at_least(torch, (1, 8, 0)):
return torch.div(a, b, rounding_mode="floor")
return torch.floor_divide(a, b)
return np.floor_divide(a, b)
else:
return np.floor_divide(a, b)


def unravel_index(idx, shape) -> NdarrayOrTensor:
Expand Down

0 comments on commit 2376600

Please sign in to comment.