
Refactor MaskedInferenceWSIDataset #4014

Closed · bhashemian opened this issue Mar 28, 2022 · 0 comments · Fixed by #4410
Labels: enhancement (New feature or request), refactor (Non-breaking feature enhancements)

bhashemian (Member) commented:
Refactor and enhance MaskedInferenceWSIDataset in pathology so that it can be blended into core MONAI, as laid out in #4005.

```python
import os
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np

from monai.data import Dataset
from monai.data.image_reader import WSIReader
from monai.utils import ensure_tuple_rep


class MaskedInferenceWSIDataset(Dataset):
    """
    This dataset loads the provided foreground masks at an arbitrary resolution level
    and extracts patches based on those masks from the associated whole slide images.

    Args:
        data: a list of samples, each including the path to the whole slide image and the path to the mask,
            e.g. `[{"image": "path/to/image1.tiff", "mask": "path/to/mask1.npy"}, ...]`.
        patch_size: the size of patches to be extracted from the whole slide image for inference.
        transform: transforms to be executed on the extracted patches.
        image_reader_name: the name of the library to be used for loading whole slide images,
            either CuCIM or OpenSlide. Defaults to CuCIM.

    Note:
        The resulting output (probability map) after performing inference with this dataset is
        expected to have the same size as the foreground mask, not the original WSI size.
    """

    def __init__(
        self,
        data: List[Dict[str, str]],
        patch_size: Union[int, Tuple[int, int]],
        transform: Optional[Callable] = None,
        image_reader_name: str = "cuCIM",
    ) -> None:
        super().__init__(data, transform)
        self.patch_size = ensure_tuple_rep(patch_size, 2)

        # set up whole slide image reader
        self.image_reader_name = image_reader_name.lower()
        self.image_reader = WSIReader(image_reader_name)

        # process data and create a list of dictionaries containing all required data and metadata
        self.data = self._prepare_data(data)

        # calculate cumulative number of patches for all the samples
        self.num_patches_per_sample = [len(d["image_locations"]) for d in self.data]
        self.num_patches = sum(self.num_patches_per_sample)
        self.cum_num_patches = np.cumsum([0] + self.num_patches_per_sample[:-1])

    def _prepare_data(self, input_data: List[Dict[str, str]]) -> List[Dict]:
        prepared_data = []
        for sample in input_data:
            prepared_sample = self._prepare_a_sample(sample)
            prepared_data.append(prepared_sample)
        return prepared_data

    def _prepare_a_sample(self, sample: Dict[str, str]) -> Dict:
        """
        Preprocess input data to load the WSIReader image object and the foreground mask,
        and define the locations where patches need to be extracted.

        Args:
            sample: one sample, a dictionary containing the path to the whole slide image and the foreground mask,
                e.g. `{"image": "path/to/image1.tiff", "mask": "path/to/mask1.npy"}`.

        Return:
            A dictionary containing:
                "name": the base name of the whole slide image,
                "image": the WSIReader image object,
                "mask_shape": the size of the foreground mask,
                "mask_locations": the list of non-zero pixel locations (x, y) on the foreground mask,
                "image_locations": the list of pixel locations (x, y) on the whole slide image
                    where patches are extracted, and
                "level": the resolution level of the mask with respect to the whole slide image.
        """
        image = self.image_reader.read(sample["image"])
        mask = np.load(sample["mask"])
        try:
            level, ratio = self._calculate_mask_level(image, mask)
        except ValueError as err:
            err.args = (sample["mask"],) + err.args
            raise

        # get all indices for non-zero pixels of the foreground mask
        mask_locations = np.vstack(mask.nonzero()).T

        # convert mask locations to image locations to extract patches
        # (map each mask pixel center onto the WSI, then shift back by half a patch so the patch is centered on it)
        image_locations = (mask_locations + 0.5) * ratio - np.array(self.patch_size) // 2

        return {
            "name": os.path.splitext(os.path.basename(sample["image"]))[0],
            "image": image,
            "mask_shape": mask.shape,
            "mask_locations": mask_locations.astype(int).tolist(),
            "image_locations": image_locations.astype(int).tolist(),
            "level": level,
        }

    def _calculate_mask_level(self, image: np.ndarray, mask: np.ndarray) -> Tuple[int, float]:
        """
        Calculate the level of the mask and its ratio with respect to the whole slide image.

        Args:
            image: the original whole slide image.
            mask: a mask that can be down-sampled at an arbitrary level.
                Note that the down-sampling ratio should be 2^N and equal in all dimensions.

        Return:
            tuple: (level, ratio) where ratio is 2^level
        """
        image_shape = image.shape
        mask_shape = mask.shape
        ratios = [image_shape[i] / mask_shape[i] for i in range(2)]
        level = np.log2(ratios[0])

        if ratios[0] != ratios[1]:
            raise ValueError(
                "Image/Mask ratio across dimensions does not match!"
                f" ratio 0: {ratios[0]} ({image_shape[0]} / {mask_shape[0]}),"
                f" ratio 1: {ratios[1]} ({image_shape[1]} / {mask_shape[1]})"
            )
        if not level.is_integer():
            raise ValueError(f"Mask is not at a regular level (ratio not power of 2), image / mask ratio: {ratios[0]}")

        return int(level), ratios[0]

    def _load_a_patch(self, index):
        """
        Load a patch given its index.

        Since the index is sequential and the patches come in a stream from different images,
        this method first finds the whole slide image and the patch that should be extracted,
        then loads the patch and provides it with its image name and the corresponding mask location.
        """
        # np.argmax returns 0 when no cumulative count exceeds `index` (i.e. the patch belongs to
        # the last sample); the subsequent -1 then selects the last sample via negative indexing.
        sample_num = np.argmax(self.cum_num_patches > index) - 1
        sample = self.data[sample_num]
        patch_num = index - self.cum_num_patches[sample_num]
        location_on_image = sample["image_locations"][patch_num]
        location_on_mask = sample["mask_locations"][patch_num]

        image, _ = self.image_reader.get_data(img=sample["image"], location=location_on_image, size=self.patch_size)
        processed_sample = {"image": image, "name": sample["name"], "mask_location": location_on_mask}
        return processed_sample

    def __len__(self):
        return self.num_patches

    def __getitem__(self, index):
        patch = [self._load_a_patch(index)]
        if self.transform:
            patch = self.transform(patch)
        return patch
```
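
For reference while refactoring, here is a minimal sketch of how the class above is consumed at inference time: each returned patch carries its `mask_location`, so per-patch predictions can be written back into a probability map with the same shape as the foreground mask, as the docstring's Note describes. The paths, patch size, and model below are illustrative assumptions only (`torch.nn.Identity` stands in for a trained patch classifier); nothing here is prescribed by this issue.

```python
import numpy as np
import torch

# hypothetical inputs; the mask is a down-sampled foreground mask saved as .npy
data = [{"image": "path/to/image1.tiff", "mask": "path/to/mask1.npy"}]
dataset = MaskedInferenceWSIDataset(data=data, patch_size=256, image_reader_name="cuCIM")

# probability map at the *mask* resolution, per the Note in the class docstring
prob_map = np.zeros(np.load(data[0]["mask"]).shape, dtype=np.float32)

model = torch.nn.Identity()  # placeholder for a trained patch classifier
with torch.no_grad():
    for i in range(len(dataset)):
        for item in dataset[i]:  # __getitem__ returns a list of patch dicts
            patch = torch.as_tensor(item["image"]).float()[None]  # add a batch dimension
            score = torch.sigmoid(model(patch)).mean()  # stand-in for a real foreground probability
            row, col = item["mask_location"]
            prob_map[row, col] = float(score)
```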

bhashemian added the enhancement (New feature or request), refactor (Non-breaking feature enhancements) and WG: Pathology labels on Mar 28, 2022
bhashemian self-assigned this on May 16, 2022
bhashemian moved this from Todo to In Progress in AI in Pathology🔬 on May 26, 2022
This was referenced on May 28, 2022
Repository owner moved this from In Progress to Done in AI in Pathology🔬 on Jun 3, 2022