Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GridPatch to Extract Tiles/Patches #4321

Merged
merged 72 commits into from
May 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
c96dd21
Implement GridPatch and RandGridPatch
bhashemian May 22, 2022
3df8dbb
Add unittests for GridPatch
bhashemian May 22, 2022
ec2e6cd
Update pad mode
bhashemian May 23, 2022
8cae266
Add several test cases
bhashemian May 23, 2022
232fd6e
Implement GridPatchd
bhashemian May 23, 2022
d66377a
Update module imports
bhashemian May 23, 2022
cf6c72a
Add docs
bhashemian May 23, 2022
a01faa7
Deprecate SplitOnGrid and TileOnGrid
bhashemian May 23, 2022
f5bf4bf
Merge branch 'dev' into grid-patch
bhashemian May 23, 2022
c2afba0
Fix formatting
bhashemian May 23, 2022
a8e68cc
Update WSIReader value error message
bhashemian May 23, 2022
13fbd8d
Implement RandGridPatchd
bhashemian May 23, 2022
4a1b8da
Update init
bhashemian May 23, 2022
0b6fe58
Add unittests for GridPatchd
bhashemian May 23, 2022
4105006
Add unittests for RandGridPatch
bhashemian May 23, 2022
b7e2259
Add unittests for RandGridPatchd
bhashemian May 23, 2022
e7d8d25
Merge branch 'dev' into grid-patch
bhashemian May 23, 2022
173f616
Gen to list
bhashemian May 23, 2022
3600bf0
Fix List
bhashemian May 23, 2022
55cf543
Convert lambda to method
bhashemian May 23, 2022
65a3a15
change array to patch
bhashemian May 23, 2022
f948afd
Remove first patch_size dim
bhashemian May 23, 2022
60deeea
Remove first patch size dim and return gen
bhashemian May 23, 2022
69c6e79
Separate overlap for each dimension
bhashemian May 24, 2022
2b506a5
Remove trailing comma
bhashemian May 24, 2022
35c747e
Add num_patches
bhashemian May 24, 2022
770fa11
Remove trailing comma
bhashemian May 24, 2022
fe04994
Add seed and update docstring
bhashemian May 26, 2022
b1e942c
Merge branch 'dev' into grid-patch
bhashemian May 26, 2022
cfc7b1e
Merge branch 'dev' into grid-patch
bhashemian May 26, 2022
148243f
Add required files form #4239
bhashemian May 26, 2022
77f4157
Merge branch 'grid-patch' of github.com:behxyz/MONAI into grid-patch
bhashemian May 26, 2022
77930a8
Merge branch 'dev' into grid-patch
bhashemian May 26, 2022
d984eb3
Update docs
bhashemian May 26, 2022
4d789cd
Merge branch 'dev' into grid-patch
bhashemian May 26, 2022
a5b0e0b
Merge branch 'grid-patch' of github.com:behxyz/MONAI into grid-patch
bhashemian May 26, 2022
62f542c
Update init
bhashemian May 26, 2022
7ea478c
Convert to tensor
bhashemian May 26, 2022
f789599
Update docs
bhashemian May 26, 2022
83809e3
Fix number of patches
bhashemian May 27, 2022
9b52c7b
Merge branch 'dev' of github.com:Project-MONAI/MONAI into grid-patch
bhashemian May 27, 2022
5eca9a6
Make RandGridPatchd to inherit GridPatchd
bhashemian May 27, 2022
4af3bbe
Change the location type
bhashemian May 27, 2022
c31c546
Update fix num patches and separate randgridpatchd
bhashemian May 27, 2022
76d509c
Update branch
bhashemian May 27, 2022
fd71c90
Minor fixes
bhashemian May 27, 2022
fa965a2
Update docs
bhashemian May 27, 2022
37ae211
Merge branch 'dev' into grid-patch
Nic-Ma May 27, 2022
a809445
Address review comments
bhashemian May 27, 2022
037ac0a
Few minor fixes
bhashemian May 28, 2022
c9a40f1
Merge branch 'dev' into grid-patch
bhashemian May 28, 2022
6888665
Add PytorchPadMode and change sort_key to sort_fn
bhashemian May 28, 2022
6dc4b46
Merge branch 'grid-patch' of github.com:behxyz/MONAI into grid-patch
bhashemian May 28, 2022
02a1104
Update pad_mode
bhashemian May 28, 2022
76338b4
Remove seed
bhashemian May 28, 2022
9cf0b0c
Make it thread safe
bhashemian May 28, 2022
6879290
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 28, 2022
3d0d961
Merge branch 'dev' into grid-patch
bhashemian May 28, 2022
c85138b
Merge branch 'dev' into grid-patch
bhashemian May 29, 2022
dfb15b3
Address reviews
bhashemian May 30, 2022
337ca76
Merge branch 'grid-patch' of github.com:behxyz/MONAI into grid-patch
bhashemian May 30, 2022
1ba9022
Merge branch 'dev' into grid-patch
bhashemian May 30, 2022
3d00f87
Update additional patches
bhashemian May 30, 2022
a4a531f
Merge branch 'grid-patch' of github.com:behxyz/MONAI into grid-patch
bhashemian May 30, 2022
46b2069
Add GridPatchSort
bhashemian May 30, 2022
cb0cc93
Merge branch 'dev' into grid-patch
bhashemian May 30, 2022
2b58ba8
Merge branch 'dev' into grid-patch
bhashemian May 30, 2022
3c70c65
Fix an issue
bhashemian May 31, 2022
4c1c543
Merge branch 'dev' into grid-patch
bhashemian May 31, 2022
5e8632e
Update randomize
bhashemian May 31, 2022
6969968
Merge branch 'grid-patch' of github.com:behxyz/MONAI into grid-patch
bhashemian May 31, 2022
b3cfe1c
Merge branch 'dev' into grid-patch
bhashemian May 31, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,18 @@ Spatial
:members:
:special-members: __call__

`GridPatch`
"""""""""""
.. autoclass:: GridPatch
:members:
:special-members: __call__

`RandGridPatch`
"""""""""""""""
.. autoclass:: RandGridPatch
:members:
:special-members: __call__

`GridSplit`
"""""""""""
.. autoclass:: GridSplit
Expand Down Expand Up @@ -1513,6 +1525,18 @@ Spatial (Dict)
:members:
:special-members: __call__

`GridPatchd`
""""""""""""
.. autoclass:: GridPatchd
:members:
:special-members: __call__

`RandGridPatchd`
""""""""""""""""
.. autoclass:: RandGridPatchd
:members:
:special-members: __call__

`GridSplitd`
""""""""""""
.. autoclass:: GridSplitd
Expand Down
4 changes: 3 additions & 1 deletion monai/apps/pathology/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@

from monai.config.type_definitions import NdarrayOrTensor
from monai.transforms.transform import Randomizable, Transform
from monai.utils import convert_data_type, convert_to_dst_type
from monai.utils import convert_data_type, convert_to_dst_type, deprecated
from monai.utils.enums import TransformBackends

__all__ = ["SplitOnGrid", "TileOnGrid"]


@deprecated(since="0.8", msg_suffix="use `monai.transforms.GridSplit` instead.")
class SplitOnGrid(Transform):
"""
Split the image into patches based on the provided grid shape.
Expand Down Expand Up @@ -107,6 +108,7 @@ def get_params(self, image_size):
return patch_size, steps


@deprecated(since="0.8", msg_suffix="use `monai.transforms.GridPatch` or `monai.transforms.RandGridPatch` instead.")
class TileOnGrid(Randomizable, Transform):
"""
Tile the 2D image into patches on a grid and maintain a subset of it.
Expand Down
3 changes: 3 additions & 0 deletions monai/apps/pathology/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
from monai.config import KeysCollection
from monai.config.type_definitions import NdarrayOrTensor
from monai.transforms.transform import MapTransform, Randomizable
from monai.utils import deprecated

from .array import SplitOnGrid, TileOnGrid

__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict", "TileOnGridd", "TileOnGridD", "TileOnGridDict"]


@deprecated(since="0.8", msg_suffix="use `monai.transforms.GridSplitd` instead.")
class SplitOnGridd(MapTransform):
"""
Split the image into patches based on the provided grid shape.
Expand Down Expand Up @@ -55,6 +57,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
return d


@deprecated(since="0.8", msg_suffix="use `monai.transforms.GridPatchd` or `monai.transforms.RandGridPatchd` instead.")
class TileOnGridd(Randomizable, MapTransform):
"""
Tile the 2D image into patches on a grid and maintain a subset of it.
Expand Down
10 changes: 5 additions & 5 deletions monai/data/wsi_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,13 @@ def __init__(
elif isinstance(offset_limits[0], tuple):
self.offset_limits = offset_limits
else:
ValueError(
raise ValueError(
"The offset limits should be either a tuple of integers or tuple of tuple of integers."
)
else:
ValueError("The offset limits should be a tuple.")
raise ValueError("The offset limits should be a tuple.")
else:
ValueError(
raise ValueError(
f'Invalid string for offset "{offset}". It should be either "random" as a string,'
"an integer, or a tuple of integers defining the offset."
)
Expand All @@ -238,15 +238,15 @@ def _evaluate_patch_coordinates(self, sample):
"""Define the location for each patch based on sliding-window approach"""
patch_size = self._get_size(sample)
level = self._get_level(sample)
start_pos = self._get_offset(sample)
offset = self._get_offset(sample)

wsi_obj = self._get_wsi_object(sample)
wsi_size = self.wsi_reader.get_size(wsi_obj, 0)
downsample = self.wsi_reader.get_downsample_ratio(wsi_obj, level)
patch_size_ = tuple(p * downsample for p in patch_size) # patch size at level 0
locations = list(
iter_patch_position(
image_size=wsi_size, patch_size=patch_size_, start_pos=start_pos, overlap=self.overlap, padded=False
image_size=wsi_size, patch_size=patch_size_, start_pos=offset, overlap=self.overlap, padded=False
)
)
sample["size"] = patch_size
Expand Down
2 changes: 1 addition & 1 deletion monai/data/wsi_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def __init__(self, backend="cucim", level: int = 0, **kwargs):
elif self.backend == "openslide":
self.reader = OpenSlideWSIReader(level=level, **kwargs)
else:
raise ValueError("The supported backends are: cucim")
raise ValueError(f"The supported backends are cucim and openslide, '{self.backend}' was given.")
self.supported_suffixes = self.reader.supported_suffixes

def get_level_count(self, wsi) -> int:
Expand Down
8 changes: 8 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@
AffineGrid,
Flip,
GridDistortion,
GridPatch,
GridSplit,
Orientation,
Rand2DElastic,
Expand All @@ -321,6 +322,7 @@
RandDeformGrid,
RandFlip,
RandGridDistortion,
RandGridPatch,
RandRotate,
RandRotate90,
RandZoom,
Expand All @@ -343,6 +345,9 @@
GridDistortiond,
GridDistortionD,
GridDistortionDict,
GridPatchd,
GridPatchD,
GridPatchDict,
GridSplitd,
GridSplitD,
GridSplitDict,
Expand All @@ -367,6 +372,9 @@
RandGridDistortiond,
RandGridDistortionD,
RandGridDistortionDict,
RandGridPatchd,
RandGridPatchD,
RandGridPatchDict,
RandRotate90d,
RandRotate90D,
RandRotate90Dict,
Expand Down
176 changes: 172 additions & 4 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,21 @@

from monai.config import USE_COMPILED, DtypeLike
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.utils import AFFINE_TOL, compute_shape_offset, reorient_spatial_axes, to_affine_nd, zoom_affine
from monai.data.utils import (
AFFINE_TOL,
compute_shape_offset,
iter_patch,
reorient_spatial_axes,
to_affine_nd,
zoom_affine,
)
from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
from monai.networks.utils import meshgrid_ij, normalize_transform
from monai.transforms.croppad.array import CenterSpatialCrop, Pad
from monai.transforms.intensity.array import GaussianSmooth
from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe, Transform
from monai.transforms.utils import (
convert_pad_mode,
create_control_grid,
create_grid,
create_rotate,
Expand All @@ -44,6 +52,7 @@
InterpolateMode,
NumpyPadMode,
PytorchPadMode,
convert_to_dst_type,
ensure_tuple,
ensure_tuple_rep,
ensure_tuple_size,
Expand All @@ -53,10 +62,10 @@
pytorch_after,
)
from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.enums import TransformBackends
from monai.utils.enums import GridPatchSort, TransformBackends
from monai.utils.misc import ImageMetaKey as Key
from monai.utils.module import look_up_option
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type
from monai.utils.type_conversion import convert_data_type

nib, has_nib = optional_import("nibabel")

Expand All @@ -68,6 +77,8 @@
"Flip",
"GridDistortion",
"GridSplit",
"GridPatch",
"RandGridPatch",
"Resize",
"Rotate",
"Zoom",
Expand Down Expand Up @@ -2577,7 +2588,6 @@ def __call__(
image,
shape=(*self.grid, n_channels, split_size[0], split_size[1]),
strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride),
writeable=False,
)
# Flatten the first two dimensions
strided_image = strided_image.reshape(-1, *strided_image.shape[2:])
Expand Down Expand Up @@ -2609,3 +2619,161 @@ def _get_params(
)

return size, steps


class GridPatch(Transform):
"""
Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps.
It can sort the patches and return all or a subset of them.

Args:
patch_size: size of patches to generate slices for, 0 or None selects whole dimension
offset: offset of starting position in the array, default is 0 for each dimension.
num_patches: number of patches to return. Defaults to None, which returns all the available patches.
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
sort_fn: a callable or string that defines the order of the patches to be returned. If it is a callable, it
will be passed directly to the `key` argument of `sorted` function. The string can be "min" or "max",
which are, respectively, the minimum and maximum of the sum of intensities of a patch across all dimensions
and channels. Also "random" creates a random order of patches.
By default no sorting is being done and patches are returned in a row-major order.
pad_mode: refer to NumpyPadMode and PytorchPadMode. Defaults to ``"constant"``.
pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.

"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
patch_size: Sequence[int],
offset: Sequence[int] = (),
num_patches: Optional[int] = None,
overlap: Union[Sequence[float], float] = 0.0,
sort_fn: Optional[Union[Callable, str]] = None,
pad_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT,
**pad_kwargs,
):
self.patch_size = ensure_tuple(patch_size)
self.offset = ensure_tuple(offset)
self.pad_mode: NumpyPadMode = convert_pad_mode(dst=np.zeros(1), mode=pad_mode)
self.pad_kwargs = pad_kwargs
self.overlap = overlap
self.num_patches = num_patches
self.sort_fn: Optional[Callable]
if isinstance(sort_fn, str):
if sort_fn == GridPatchSort.RANDOM.value:
self.sort_fn = np.random.random
elif sort_fn == GridPatchSort.MIN.value:
self.sort_fn = self.get_patch_sum
elif sort_fn == GridPatchSort.MAX.value:
self.sort_fn = self.get_negative_patch_sum
else:
raise ValueError(
f'sort_fn should be one of the following values, "{sort_fn}" was given:',
[enum.value for enum in GridPatchSort],
)
else:
self.sort_fn = sort_fn

@staticmethod
def get_patch_sum(x):
return x[0].sum()

@staticmethod
def get_negative_patch_sum(x):
return -x[0].sum()

def __call__(self, array: NdarrayOrTensor):
# create the patch iterator which sweeps the image row-by-row
array_np, *_ = convert_data_type(array, np.ndarray)
patch_iterator = iter_patch(
array_np,
patch_size=(None,) + self.patch_size, # expand to have the channel dim
start_pos=(0,) + self.offset, # expand to have the channel dim
overlap=self.overlap,
copy_back=False,
mode=self.pad_mode,
**self.pad_kwargs,
)
if self.sort_fn is not None:
output = sorted(patch_iterator, key=self.sort_fn)
else:
output = list(patch_iterator)
if self.num_patches:
output = output[: self.num_patches]
if len(output) < self.num_patches:
patch = np.full((array.shape[0], *self.patch_size), self.pad_kwargs.get("constant_values", 0))
slices = np.zeros((3, len(self.patch_size)))
output += [(patch, slices)] * (self.num_patches - len(output))

output = [convert_to_dst_type(src=patch, dst=array)[0] for patch in output]

return output


class RandGridPatch(GridPatch, RandomizableTransform):
"""
Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps,
and with random offset for the minimal corner of the image, (0,0) for 2D and (0,0,0) for 3D.
It can sort the patches and return all or a subset of them.

Args:
patch_size: size of patches to generate slices for, 0 or None selects whole dimension
min_offset: the minimum range of offset to be selected randomly. Defaults to 0.
max_offset: the maximum range of offset to be selected randomly.
Defaults to image size modulo patch size.
num_patches: number of patches to return. Defaults to None, which returns all the available patches.
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
sort_fn: a callable or string that defines the order of the patches to be returned. If it is a callable, it
will be passed directly to the `key` argument of `sorted` function. The string can be "min" or "max",
which are, respectively, the minimum and maximum of the sum of intensities of a patch across all dimensions
and channels. Also "random" creates a random order of patches.
By default no sorting is being done and patches are returned in a row-major order.
pad_mode: refer to NumpyPadMode and PytorchPadMode. Defaults to ``"constant"``.
pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.

"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
patch_size: Sequence[int],
min_offset: Optional[Union[Sequence[int], int]] = None,
max_offset: Optional[Union[Sequence[int], int]] = None,
num_patches: Optional[int] = None,
overlap: Union[Sequence[float], float] = 0.0,
sort_fn: Optional[Union[Callable, str]] = None,
pad_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT,
**pad_kwargs,
):
super().__init__(
patch_size=patch_size,
offset=(),
num_patches=num_patches,
overlap=overlap,
sort_fn=sort_fn,
pad_mode=pad_mode,
**pad_kwargs,
)
self.min_offset = min_offset
self.max_offset = max_offset

def randomize(self, array):
if self.min_offset is None:
min_offset = (0,) * len(self.patch_size)
else:
min_offset = ensure_tuple_rep(self.min_offset, len(self.patch_size))
if self.max_offset is None:
max_offset = tuple(s % p for s, p in zip(array.shape[1:], self.patch_size))
else:
max_offset = ensure_tuple_rep(self.max_offset, len(self.patch_size))

self.offset = tuple(self.R.randint(low=low, high=high + 1) for low, high in zip(min_offset, max_offset))

def __call__(self, array: NdarrayOrTensor, randomize: bool = True):
if randomize:
self.randomize(array)
return super().__call__(array)
Loading