Skip to content

Commit

Permalink
add ConvertBoxToPoints
Browse files Browse the repository at this point in the history
Signed-off-by: YunLiu <[email protected]>
  • Loading branch information
KumoLiu committed Aug 29, 2024
1 parent 29ce1a7 commit 7c8cfe4
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 1 deletion.
24 changes: 24 additions & 0 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from monai.config import USE_COMPILED, DtypeLike
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.box_utils import BoxMode, StandardMode
from monai.data.meta_obj import get_track_meta, set_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
Expand All @@ -34,6 +35,7 @@
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.spatial.functional import (
affine_func,
convert_box_to_points,
flip,
orientation,
resize,
Expand Down Expand Up @@ -3544,3 +3546,25 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:

else:
return img


class ConvertBoxToPoints(Transform):
"""
Convert boxes to points. It can automatically convert the boxes to the points based on the box mode.
The return points will be in the shape of (N, 4, 2) or (N, 8, 3) based on the box mode.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, mode: str | BoxMode | type[BoxMode] | None = StandardMode) -> None:
"""
Args:
mode: the mode of the box, can be a string, a BoxMode instance or a BoxMode class. Defaults to StandardMode.
"""
super().__init__()
self.mode = mode

def __call__(self, data: Any):
data = convert_to_tensor(data, track_meta=get_track_meta())
points = convert_box_to_points(data, mode=self.mode)
return convert_to_dst_type(points, data)[0]
27 changes: 27 additions & 0 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@

from monai.config import DtypeLike, KeysCollection, SequenceStr
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.box_utils import BoxMode, StandardMode
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.networks.layers.simplelayers import GaussianFilter
from monai.transforms.croppad.array import CenterSpatialCrop
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.spatial.array import (
Affine,
ConvertBoxToPoints,
Flip,
GridDistortion,
GridPatch,
Expand Down Expand Up @@ -2611,6 +2613,31 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class ConvertBoxToPointsd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ConvertBoxToPoints`.
"""

backend = ConvertBoxToPoints.backend

def __init__(
self,
keys: KeysCollection,
point_key="points",
mode: str | BoxMode | type[BoxMode] | None = StandardMode,
allow_missing_keys: bool = False,
):
super().__init__(keys, allow_missing_keys)
self.point_key = point_key
self.converter = ConvertBoxToPoints(mode=mode)

def __call__(self, data):
d = dict(data)
for key in self.key_iterator(d):
data[self.point_key] = self.converter(d[key])
return data


SpatialResampleD = SpatialResampleDict = SpatialResampled
ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd
SpacingD = SpacingDict = Spacingd
Expand Down
53 changes: 52 additions & 1 deletion monai/transforms/spatial/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import monai
from monai.config import USE_COMPILED
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.box_utils import get_boxmode
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd
Expand All @@ -32,7 +33,7 @@
from monai.transforms.intensity.array import GaussianSmooth
from monai.transforms.inverse import TraceableTransform
from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
from monai.transforms.utils_pytorch_numpy_unification import allclose
from monai.transforms.utils_pytorch_numpy_unification import allclose, concatenate, stack
from monai.utils import (
LazyAttr,
TraceKeys,
Expand Down Expand Up @@ -610,3 +611,53 @@ def affine_func(
out = _maybe_new_metatensor(img, dtype=torch.float32, device=resampler.device)
out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
return out if image_only else (out, affine)


def convert_box_to_points(bbox, mode):
"""
Convert bounding box to points.
Args:
mode: The mode specifying how to interpret the bounding box.
bbox: Bounding box in the form of [x1, y1, x2, y2] for 2D or [x1, y1, z1, x2, y2, z2] for 3D.
Return shape will be (N, 4) for 2D or (N, 6) for 3D.
Returns:
sequence of points representing the corners of the bounding box.
"""

mode = get_boxmode(mode)

points_list = []
for _num in range(bbox.shape[0]):
corners = mode.boxes_to_corners(bbox[_num : _num + 1])
if len(corners) == 4:
points_list.append(
concatenate(
[
concatenate([corners[0], corners[1]], axis=1),
concatenate([corners[2], corners[1]], axis=1),
concatenate([corners[2], corners[3]], axis=1),
concatenate([corners[0], corners[3]], axis=1),
],
axis=0,
)
)
else:
points_list.append(
concatenate(
[
concatenate([corners[0], corners[1], corners[2]], axis=1),
concatenate([corners[3], corners[1], corners[2]], axis=1),
concatenate([corners[3], corners[4], corners[2]], axis=1),
concatenate([corners[0], corners[4], corners[2]], axis=1),
concatenate([corners[0], corners[1], corners[5]], axis=1),
concatenate([corners[3], corners[1], corners[5]], axis=1),
concatenate([corners[3], corners[4], corners[5]], axis=1),
concatenate([corners[0], corners[4], corners[5]], axis=1),
],
axis=0,
)
)

return stack(points_list, dim=0)

0 comments on commit 7c8cfe4

Please sign in to comment.