diff --git a/examples/annotator_3d.py b/examples/annotator_3d.py index 06ad93ee..4fecce9f 100644 --- a/examples/annotator_3d.py +++ b/examples/annotator_3d.py @@ -40,6 +40,6 @@ def main(): # The corresponding CLI call for em_3d_annotator: # (replace with cache directory on your machine) -# $ micro_sam.annotator_3d -i /home/pape/.cache/micro_sam/sample_data/lucchi_pp.zip.unzip/Lucchi++/Test_In -k *.png -e /home/pape/.cache/micro_sam/embeddings/embeddings-lucchi.zarr +# $ micro_sam.annotator_3d -i /home/pape/.cache/micro_sam/sample_data/lucchi_pp.zip.unzip/Lucchi++/Test_In -k *.png -e /home/pape/.cache/micro_sam/embeddings/embeddings-lucchi.zarr # noqa if __name__ == "__main__": main() diff --git a/examples/annotator_tracking.py b/examples/annotator_tracking.py index 003da5ce..3652b885 100644 --- a/examples/annotator_tracking.py +++ b/examples/annotator_tracking.py @@ -38,6 +38,6 @@ def main(): # The corresponding CLI call for track_ctc_data: # (replace with cache directory on your machine) -# $ micro_sam.annotator_tracking -i /home/pape/.cache/micro_sam/sample_data/DIC-C2DH-HeLa.zip.unzip/DIC-C2DH-HeLa/01 -k *.tif -e /home/pape/.cache/micro_sam/embeddings/embeddings-ctc.zarr +# $ micro_sam.annotator_tracking -i /home/pape/.cache/micro_sam/sample_data/DIC-C2DH-HeLa.zip.unzip/DIC-C2DH-HeLa/01 -k *.tif -e /home/pape/.cache/micro_sam/embeddings/embeddings-ctc.zarr # noqa if __name__ == "__main__": main() diff --git a/examples/image_series_annotator.py b/examples/image_series_annotator.py index abfe16d0..405e2625 100644 --- a/examples/image_series_annotator.py +++ b/examples/image_series_annotator.py @@ -38,6 +38,6 @@ def main(): # The corresponding CLI call for track_ctc_data: # (replace with cache directory on your machine) -# $ micro_sam.image_series_annotator -i /home/pape/.cache/micro_sam/sample_data/image-series.zip.unzip/series/ -e /home/pape/.cache/micro_sam/embeddings/series-embeddings/ -o segmentation_results +# $ micro_sam.image_series_annotator -i 
/home/pape/.cache/micro_sam/sample_data/image-series.zip.unzip/series/ -e /home/pape/.cache/micro_sam/embeddings/series-embeddings/ -o segmentation_results # noqa if __name__ == "__main__": main() diff --git a/micro_sam/automatic_segmentation.py b/micro_sam/automatic_segmentation.py index 72c1f4eb..9efcbed9 100644 --- a/micro_sam/automatic_segmentation.py +++ b/micro_sam/automatic_segmentation.py @@ -75,6 +75,7 @@ def automatic_instance_segmentation( halo: Optional[Tuple[int, int]] = None, verbose: bool = True, return_embeddings: bool = False, + annotate: bool = False, **generate_kwargs ) -> np.ndarray: """Run automatic segmentation for the input image. @@ -94,6 +95,7 @@ def automatic_instance_segmentation( halo: Overlap of the tiles for tiled prediction. verbose: Verbosity flag. return_embeddings: Whether to return the precomputed image embeddings. + annotate: Whether to activate the annotator to continue the annotation process. generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class. Returns: @@ -161,6 +163,30 @@ def automatic_instance_segmentation( else: instances = outputs + # Allow opening the automatic segmentation in the annotator for further annotation, if desired. + if annotate: + from micro_sam.sam_annotator import annotator_2d, annotator_3d + annotator_function = annotator_2d if ndim == 2 else annotator_3d + + viewer = annotator_function( + image=image_data, + model_type=predictor.model_name, + embedding_path=embedding_path, + segmentation_result=instances, # Initializes the annotator with the automatic segmentation. + tile_shape=tile_shape, + halo=halo, + return_viewer=True, # Returns the viewer, which allows the user to store the updated segmentations. 
+ ) + + # Start the GUI here + import napari + napari.run() + + # We extract the segmentation in the "committed_objects" layer, where the user either: + # a) Performed interactive segmentation / corrections and committed them, OR + # b) Did not do anything and closed the annotator, i.e. keeps the segmentation as it is. + instances = viewer.layers["committed_objects"].data + # Save the instance segmentation, if 'output_path' provided. if output_path is not None: output_path = Path(output_path).with_suffix(".tif") @@ -221,6 +247,10 @@ def main(): "--mode", type=str, default=None, help="The choice of automatic segmentation with the Segment Anything models. Either 'amg' or 'ais'." ) + parser.add_argument( + "--annotate", action="store_true", + help="Whether to continue annotation after the automatic segmentation is generated." + ) parser.add_argument( "-d", "--device", default=None, help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC)." @@ -278,6 +308,7 @@ def _convert_argval(value): ndim=args.ndim, tile_shape=args.tile_shape, halo=args.halo, + annotate=args.annotate, **generate_kwargs, )