Skip to content

Commit 37583cf

Browse files
authored
Merge pull request #55 from fractal-napari-plugins-collection/51_prediction_layer_fixes
Fix prediction/annotation layer issues
2 parents 1734442 + 6981f2b commit 37583cf

File tree

2 files changed

+101
-41
lines changed

2 files changed

+101
-41
lines changed

src/napari_feature_classifier/annotator_widget.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -136,16 +136,7 @@ def __init__(
136136
for layer in self._viewer.layers:
137137
if type(layer) == napari.layers.Labels and layer.name == "Annotations":
138138
self._viewer.layers.remove(layer)
139-
self._annotations_layer = self._viewer.add_labels(
140-
self._last_selected_label_layer.data,
141-
scale=self._last_selected_label_layer.scale,
142-
name="Annotations",
143-
translate=self._last_selected_label_layer.translate,
144-
)
145-
self._annotations_layer.editable = False
146-
147-
# Set the label selection to a valid label layer => Running into proxy bug
148-
self._viewer.layers.selection.active = self._last_selected_label_layer
139+
self.add_annotations_layer()
149140

150141
# Class selection
151142
self.ClassSelection = ClassSelection # pylint: disable=C0103
@@ -212,11 +203,26 @@ def selection_changed(self, event):
212203
self._save_destination.enabled = False
213204
self._class_selector.enabled = False
214205

206+
def add_annotations_layer(self):
207+
self._annotations_layer = self._viewer.add_labels(
208+
self._last_selected_label_layer.data,
209+
scale=self._last_selected_label_layer.scale,
210+
name="Annotations",
211+
translate=self._last_selected_label_layer.translate,
212+
)
213+
self._annotations_layer.editable = False
214+
# Set the label selection to a valid label layer
215+
self._viewer.layers.selection.active = self._last_selected_label_layer
216+
215217
def toggle_label(self, labels_layer, event):
216218
"""
217219
Callback for when a label is clicked. It then updates the color of that
218220
label in the annotation layer.
219221
"""
222+
# If the annotations layer is missing, add it back
223+
if "Annotations" not in [x.name for x in self._viewer.layers]:
224+
self.add_annotations_layer()
225+
220226
# Need to translate & scale position that event.position returns by the
221227
# label_layer scale.
222228
# If scale is (1, 1, 1), nothing changes

src/napari_feature_classifier/classifier_widget.py

+85-31
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def __init__(
223223
self._last_selected_label_layer = get_selected_or_valid_label_layer(
224224
viewer=self._viewer
225225
)
226+
226227
# Initialize the classifier
227228
if classifier:
228229
self._classifier = classifier
@@ -248,17 +249,10 @@ def __init__(
248249
self._viewer, get_class_selection(class_names=self.class_names)
249250
)
250251

251-
# Handle existing predictions layer
252252
for layer in self._viewer.layers:
253253
if type(layer) == napari.layers.Labels and layer.name == "Predictions":
254254
self._viewer.layers.remove(layer)
255-
self._prediction_layer = self._viewer.add_labels(
256-
self._last_selected_label_layer.data,
257-
scale=self._last_selected_label_layer.scale,
258-
name="Predictions",
259-
translate=self._last_selected_label_layer.translate,
260-
)
261-
self._prediction_layer.contour = 2
255+
self.add_prediction_layer()
262256

263257
# Set the label selection to a valid label layer => Running into proxy bug
264258
self._viewer.layers.selection.active = self._last_selected_label_layer
@@ -295,11 +289,6 @@ def __init__(
295289
self._export_button.clicked.connect(self.export_results)
296290
self._viewer.layers.selection.events.changed.connect(self.selection_changed)
297291
self._init_prediction_layer(self._last_selected_label_layer)
298-
# Whenever the label layer is clicked, hide the prediction layer
299-
# (e.g. new annotations are made)
300-
# self._last_selected_label_layer.mouse_drag_callbacks.append(
301-
# self.hide_prediction_layer
302-
# )
303292

304293
def run(self):
305294
"""
@@ -349,6 +338,15 @@ def add_features_to_classifier(self):
349338
dict_of_features[layer.name] = layer.features
350339
self._classifier.add_dict_of_features(dict_of_features)
351340

341+
def add_prediction_layer(self):
342+
self._prediction_layer = self._viewer.add_labels(
343+
self._last_selected_label_layer.data,
344+
scale=self._last_selected_label_layer.scale,
345+
name="Predictions",
346+
translate=self._last_selected_label_layer.translate,
347+
)
348+
self._prediction_layer.contour = 2
349+
352350
def make_predictions(self):
353351
"""
354352
Make predictions for all relevant label layers and add them to the
@@ -398,17 +396,72 @@ def selection_changed(self):
398396
viewer=self._viewer
399397
):
400398
self._last_selected_label_layer = self._viewer.layers.selection.active
401-
self._init_prediction_layer(self._viewer.layers.selection.active)
402-
# self._last_selected_label_layer.mouse_drag_callbacks.append(
403-
# self.hide_prediction_layer
404-
# )
399+
self._init_prediction_layer(
400+
self._viewer.layers.selection.active, ensure_layer_presence=False
401+
)
405402
self._update_export_destination(self._last_selected_label_layer)
406403

407-
def _init_prediction_layer(self, label_layer: napari.layers.Labels):
404+
def reorder_layers(self):
405+
"""Reorders layers if needed to ensure Annotation & Prediction layers
406+
are above the currently selected label layer.
407+
"""
408+
# Get the current order of layers
409+
all_layers = list(self._viewer.layers)
410+
411+
# Determine the indices of the layers if they exist
412+
indices_to_move = []
413+
414+
# Find the index of "Prediction" layer if it exists
415+
if "Predictions" in self._viewer.layers:
416+
indices_to_move.append(self._viewer.layers.index("Predictions"))
417+
418+
# Find the index of "Annotation" layer if it exists
419+
if "Annotations" in self._viewer.layers:
420+
indices_to_move.append(self._viewer.layers.index("Annotations"))
421+
422+
# Find the index of the reference_label_layer
423+
if self._last_selected_label_layer.name in self._viewer.layers:
424+
indices_to_move.append(
425+
self._viewer.layers.index(self._last_selected_label_layer.name)
426+
)
427+
428+
# Calculate the new order of layer indices
429+
remaining_indices = [
430+
i for i in range(len(all_layers)) if i not in indices_to_move
431+
]
432+
remaining_indices.reverse()
433+
new_order = indices_to_move + remaining_indices
434+
new_order.reverse()
435+
436+
# Reorder the layers using the move_multiple function
437+
self._viewer.layers.move_multiple(new_order)
438+
439+
def _init_prediction_layer(
440+
self, label_layer: napari.layers.Labels, ensure_layer_presence: bool = True
441+
):
408442
"""
409443
Initialize the prediction layer and reset its data (to fit the input
410-
label_layer) and its colormap
444+
label_layer) and its colormap.
445+
ensure_layer_presence creates the Predictions layer if it doesn't exist
446+
yet and triggers layer reordering.
411447
"""
448+
# Ensure that prediction layer exists
449+
if (
450+
"Predictions" not in [x.name for x in self._viewer.layers]
451+
and ensure_layer_presence
452+
):
453+
self.add_prediction_layer()
454+
if ensure_layer_presence:
455+
# Ensure correct layer order: This sometimes fails with weird
456+
# EmitLoopError & IndexError that should be ignored
457+
try:
458+
self.reorder_layers()
459+
except: # noqa
460+
pass
461+
462+
# Ensure that prediction layer is above the current label layer
463+
self._last_selected_label_layer
464+
412465
# Check if the predict column already exists in the layer.features
413466
if "prediction" not in label_layer.features:
414467
unique_labels = np.unique(label_layer.data)[1:]
@@ -448,12 +501,6 @@ def _init_prediction_layer(self, label_layer: napari.layers.Labels):
448501
cmap=get_colormap(),
449502
)
450503

451-
# def hide_prediction_layer(self, labels_layer, event):
452-
# """
453-
# Hide the prediction layer
454-
# """
455-
# self._prediction_layer.visible = False
456-
457504
def get_relevant_label_layers(self):
458505
relevant_label_layers = []
459506
required_columns = [self._label_column, self._roi_id_colum]
@@ -613,12 +660,19 @@ def load(self):
613660
with open(clf_path, "rb") as f: # pylint: disable=C0103
614661
clf = pickle.load(f)
615662

616-
self._run_container = ClassifierRunContainer(
617-
self._viewer,
618-
clf,
619-
classifier_save_path=clf_path,
620-
auto_save=True,
621-
)
663+
try:
664+
self._run_container = ClassifierRunContainer(
665+
self._viewer,
666+
clf,
667+
classifier_save_path=clf_path,
668+
auto_save=True,
669+
)
670+
except NotImplementedError:
671+
napari_info(
672+
"Create a label layer with a feature dataframe before loading "
673+
"the classifier"
674+
)
675+
return
622676
self.clear()
623677
self.append(self._run_container)
624678

0 commit comments

Comments
 (0)