Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor and Simplify Tile transform #4012

Closed
bhashemian opened this issue Mar 28, 2022 · 0 comments · Fixed by #4321
Closed

Refactor and Simplify Tile transform #4012

bhashemian opened this issue Mar 28, 2022 · 0 comments · Fixed by #4321
Assignees
Labels
refactor Non-breaking feature enhancements

Comments

@bhashemian
Copy link
Member

bhashemian commented Mar 28, 2022

Refactor, simplify, and possibly generalize the `TileOnGrid` transform in pathology so that it can be blended into core MONAI, as laid out in #4005. The current implementation is:

class TileOnGrid(Randomizable, Transform):
    """
    Tile the 2D image into patches on a grid and maintain a subset of it.
    This transform works only with np.ndarray inputs for 2D images.

    The image is expected in channel-first layout ``(C, H, W)``; the result has
    shape ``(N, C, tile_size, tile_size)`` where ``N`` is the number of tiles kept.

    Args:
        tile_count: number of tiles to extract, if None extracts all non-background tiles
            Defaults to ``None``.
        tile_size: size of the square tile
            Defaults to ``256``.
        step: step size
            Defaults to ``None`` (same as tile_size)
        random_offset: Randomize position of the grid, instead of starting from the top-left corner
            Defaults to ``False``.
        pad_full: pad image to the size evenly divisible by tile_size
            Defaults to ``False``.
        background_val: the background constant (e.g. 255 for white background)
            Defaults to ``255``.
        filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is more than tile_count,
            then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for random)
            subset. If there are fewer tiles than tile_count, extra tiles filled with ``background_val`` are
            appended. If tile_count is None, "min" keeps the tiles whose intensity sum is below 99.9% of an
            all-background tile, "max" keeps the remaining ones, and "random" keeps every tile.
            Defaults to ``min`` (which assumes background is high value)

    Raises:
        ValueError: When ``filter_mode`` is not one of "min", "max" or "random".
    """

    backend = [TransformBackends.NUMPY]

    def __init__(
        self,
        tile_count: Optional[int] = None,
        tile_size: int = 256,
        step: Optional[int] = None,
        random_offset: bool = False,
        pad_full: bool = False,
        background_val: int = 255,
        filter_mode: str = "min",
    ):
        self.tile_count = tile_count
        self.tile_size = tile_size
        self.random_offset = random_offset
        self.pad_full = pad_full
        self.background_val = background_val
        self.filter_mode = filter_mode

        # a missing step means a non-overlapping grid
        self.step = self.tile_size if step is None else step

        # state refreshed by `randomize` on every call
        self.offset = (0, 0)
        self.random_idxs = np.array((0,))

        if self.filter_mode not in ["min", "max", "random"]:
            raise ValueError("Unsupported filter_mode, must be [min, max or random]: " + str(self.filter_mode))

    def randomize(self, img_size: Sequence[int]) -> None:
        """
        Draw the random grid offset and the random tile subset for an image.

        Updates ``self.offset`` and ``self.random_idxs`` using the random state ``self.R``.

        Args:
            img_size: shape of the image to be tiled, as ``(C, H, W)``.
        """
        _, h, w = img_size

        self.offset = (0, 0)
        if self.random_offset:
            # the offset is limited to the remainder that does not fill a whole tile
            pad_h = h % self.tile_size
            pad_w = w % self.tile_size
            self.offset = (self.R.randint(pad_h) if pad_h > 0 else 0, self.R.randint(pad_w) if pad_w > 0 else 0)
            h = h - self.offset[0]
            w = w - self.offset[1]

        if self.pad_full:
            # mirror the padding applied in `__call__` so the tile count matches
            pad_h = (self.tile_size - h % self.tile_size) % self.tile_size
            pad_w = (self.tile_size - w % self.tile_size) % self.tile_size
            h = h + pad_h
            w = w + pad_w

        # number of grid positions along each axis
        h_n = (h - self.tile_size + self.step) // self.step
        w_n = (w - self.tile_size + self.step) // self.step
        tile_total = h_n * w_n

        if self.tile_count is not None and tile_total > self.tile_count:
            self.random_idxs = self.R.choice(range(tile_total), self.tile_count, replace=False)
        else:
            self.random_idxs = np.array((0,))

    def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Extract tiles from ``image`` and keep a subset according to the configuration.

        Args:
            image: channel-first 2D image of shape ``(C, H, W)``.

        Returns:
            The selected tiles stacked along a new first axis, ``(N, C, tile_size, tile_size)``,
            converted back with ``convert_to_dst_type`` using ``image`` as the destination
            and ``image.dtype`` as the dtype.
        """
        img_np, *_ = convert_data_type(image, np.ndarray)

        # add random offset
        self.randomize(img_size=img_np.shape)

        if self.random_offset and (self.offset[0] > 0 or self.offset[1] > 0):
            img_np = img_np[:, self.offset[0] :, self.offset[1] :]

        # pad to full size, divisible by tile_size
        if self.pad_full:
            _, h, w = img_np.shape
            pad_h = (self.tile_size - h % self.tile_size) % self.tile_size
            pad_w = (self.tile_size - w % self.tile_size) % self.tile_size
            # split the padding as evenly as possible between both sides
            img_np = np.pad(  # type: ignore
                img_np,
                [[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]],
                constant_values=self.background_val,
            )

        # extract tiles as a strided (read-only) view, then flatten the grid axes
        step = self.step
        tile = self.tile_size
        c_image, h_image, w_image = img_np.shape
        c_stride, h_stride, w_stride = img_np.strides
        grid_view = as_strided(
            img_np,
            shape=((h_image - tile) // step + 1, (w_image - tile) // step + 1, c_image, tile, tile),
            strides=(h_stride * step, w_stride * step, c_stride, h_stride, w_stride),
            writeable=False,
        )
        img_np = grid_view.reshape(-1, c_image, tile, tile)  # type: ignore

        if self.tile_count is None:
            # if keeping all patches,
            # retain only patches with significant foreground content to speed up inference
            # FYI, this returns a variable number of tiles, so the batch_size must be 1 (per gpu), e.g during inference
            if self.filter_mode in ("min", "max"):
                # 99.9% of the intensity sum of an all-background tile; uses the actual
                # channel count so the threshold is also valid for non-RGB images
                thresh = 0.999 * c_image * self.background_val * tile * tile
                tile_sums = img_np.sum(axis=(1, 2, 3))
                if self.filter_mode == "min":
                    # default, keep non-background tiles (small values)
                    img_np = img_np[tile_sums < thresh]
                else:
                    img_np = img_np[tile_sums >= thresh]
        elif len(img_np) > self.tile_count:
            if self.filter_mode == "min":
                # default, keep non-background tiles (smallest values)
                order = np.argsort(img_np.sum(axis=(1, 2, 3)))
                img_np = img_np[order[: self.tile_count]]
            elif self.filter_mode == "max":
                order = np.argsort(img_np.sum(axis=(1, 2, 3)))
                # explicit start index: a `-tile_count` slice would keep everything for tile_count == 0
                img_np = img_np[order[len(order) - self.tile_count :]]
            else:
                # random subset (more appropriate for WSIs without distinct background)
                img_np = img_np[self.random_idxs]
        elif len(img_np) < self.tile_count:
            # not enough tiles: append background-only tiles up to tile_count
            img_np = np.pad(  # type: ignore
                img_np,
                [[0, self.tile_count - len(img_np)], [0, 0], [0, 0], [0, 0]],
                constant_values=self.background_val,
            )

        image, *_ = convert_to_dst_type(src=img_np, dst=image, dtype=image.dtype)

        return image

@bhashemian bhashemian added refactor Non-breaking feature enhancements WG: Pathology labels Mar 28, 2022
@bhashemian bhashemian changed the title Generalized Tile transform Refactor and Simplify Tile transform Mar 28, 2022
@bhashemian bhashemian self-assigned this May 16, 2022
@bhashemian bhashemian moved this from Todo to In Progress in AI in Pathology🔬 May 19, 2022
@bhashemian bhashemian moved this from In Progress to Under Review in AI in Pathology🔬 May 27, 2022
Repository owner moved this from Under Review to Done in AI in Pathology🔬 May 31, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
refactor Non-breaking feature enhancements
Projects
Status: 💯 Complete
Development

Successfully merging a pull request may close this issue.

1 participant