Extend automatic segmentation CLI to allow continuing annotation #858

Merged: 4 commits, Feb 11, 2025
Changes from 1 commit
2 changes: 1 addition & 1 deletion examples/annotator_3d.py
@@ -40,6 +40,6 @@ def main():

# The corresponding CLI call for em_3d_annotator:
# (replace with cache directory on your machine)
# $ micro_sam.annotator_3d -i /home/pape/.cache/micro_sam/sample_data/lucchi_pp.zip.unzip/Lucchi++/Test_In -k *.png -e /home/pape/.cache/micro_sam/embeddings/embeddings-lucchi.zarr
# $ micro_sam.annotator_3d -i /home/pape/.cache/micro_sam/sample_data/lucchi_pp.zip.unzip/Lucchi++/Test_In -k *.png -e /home/pape/.cache/micro_sam/embeddings/embeddings-lucchi.zarr # noqa
if __name__ == "__main__":
main()
43 changes: 43 additions & 0 deletions micro_sam/automatic_segmentation.py
@@ -75,6 +75,8 @@ def automatic_instance_segmentation(
halo: Optional[Tuple[int, int]] = None,
verbose: bool = True,
return_embeddings: bool = False,
annotator: bool = False,
view: bool = False,
**generate_kwargs
) -> np.ndarray:
"""Run automatic segmentation for the input image.
@@ -94,6 +96,8 @@ def automatic_instance_segmentation(
halo: Overlap of the tiles for tiled prediction.
verbose: Verbosity flag.
return_embeddings: Whether to return the precomputed image embeddings.
annotator: Whether to activate the annotator to continue the annotation process.
view: Whether to visualize the segmentations corresponding to the original input image.
generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.

Returns:
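
For context, here is a minimal Python sketch of how the new `annotator` and `view` options could be used once this change lands. The `get_predictor_and_segmenter` helper, the `input_path`/`output_path` parameter names, the model type, and the file paths are assumptions drawn from the rest of this module, not part of this diff.

# Hypothetical usage sketch; names outside this diff are assumptions, not part of the change.
from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter

predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm")  # assumed helper from this module
instances = automatic_instance_segmentation(
    predictor=predictor,
    segmenter=segmenter,
    input_path="image.tif",          # hypothetical input image
    output_path="segmentation.tif",  # the instance segmentation is written here as a tif
    ndim=2,
    annotator=True,   # open the result in the 2d annotator to continue annotating
    view=False,       # set to True instead to only inspect the result in napari
)
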
@@ -140,6 +144,17 @@ def automatic_instance_segmentation(
# If (raw) predictions are provided, store them as they are, without further post-processing.
instances = masks

if annotator: # Allow opening the automatic segmentation in the annotator for further annotation, if desired.
from micro_sam.sam_annotator import annotator_2d
annotator_2d(
image=image_data,
model_type=predictor.model_name,
embedding_path=embedding_path,
segmentation_result=instances, # Initializes the annotator with the automatic segmentation.
tile_shape=tile_shape,
halo=halo,
)

else:
if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3):
raise ValueError(f"The input does not match the shape expectation of 3d inputs: {image_data.shape}")
@@ -161,11 +176,30 @@
else:
instances = outputs

if annotator: # Allow opening the automatic segmentation in the annotator for further annotation, if desired.
from micro_sam.sam_annotator import annotator_3d
annotator_3d(
image=image_data,
model_type=predictor.model_name,
embedding_path=embedding_path,
segmentation_result=instances, # Initializes the annotator with the automatic segmentation.
tile_shape=tile_shape,
halo=halo,
)

# Save the instance segmentation, if 'output_path' provided.
if output_path is not None:
output_path = Path(output_path).with_suffix(".tif")
imageio.imwrite(output_path, instances, compression="zlib")

# Visualize the input image and the corresponding segmentation, if requested.
if view:
import napari
v = napari.Viewer()
v.add_image(image_data, name="Image")
v.add_labels(instances.astype(int), name="Instance Segmentation")
napari.run()

if return_embeddings:
return instances, image_embeddings
else:
@@ -221,6 +255,13 @@ def main():
"--mode", type=str, default=None,
help="The choice of automatic segmentation with the Segment Anything models. Either 'amg' or 'ais'."
)
parser.add_argument(
"--view", action="store_true", help="Whether to view the segmentations in napari."
)
parser.add_argument(
"--annotator", action="store_true",
help="Whether to continue annotation after the automatic segmentation is generated."
)
parser.add_argument(
"-d", "--device", default=None,
help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC)."
Expand Down Expand Up @@ -278,6 +319,8 @@ def _convert_argval(value):
ndim=args.ndim,
tile_shape=args.tile_shape,
halo=args.halo,
annotator=args.annotator,
view=args.view,
**generate_kwargs,
)
