Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix prediction/annotation layer issues #55

Merged
merged 4 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions src/napari_feature_classifier/annotator_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,16 +136,7 @@ def __init__(
for layer in self._viewer.layers:
if type(layer) == napari.layers.Labels and layer.name == "Annotations":
self._viewer.layers.remove(layer)
self._annotations_layer = self._viewer.add_labels(
self._last_selected_label_layer.data,
scale=self._last_selected_label_layer.scale,
name="Annotations",
translate=self._last_selected_label_layer.translate,
)
self._annotations_layer.editable = False

# Set the label selection to a valid label layer => Running into proxy bug
self._viewer.layers.selection.active = self._last_selected_label_layer
self.add_annotations_layer()

# Class selection
self.ClassSelection = ClassSelection # pylint: disable=C0103
Expand Down Expand Up @@ -212,11 +203,26 @@ def selection_changed(self, event):
self._save_destination.enabled = False
self._class_selector.enabled = False

def add_annotations_layer(self):
self._annotations_layer = self._viewer.add_labels(
self._last_selected_label_layer.data,
scale=self._last_selected_label_layer.scale,
name="Annotations",
translate=self._last_selected_label_layer.translate,
)
self._annotations_layer.editable = False
# Set the label selection to a valid label layer
self._viewer.layers.selection.active = self._last_selected_label_layer

def toggle_label(self, labels_layer, event):
"""
Callback for when a label is clicked. It then updates the color of that
label in the annotation layer.
"""
# If the annotations layer is missing, add it back
if "Annotations" not in [x.name for x in self._viewer.layers]:
self.add_annotations_layer()

# Need to translate & scale position that event.position returns by the
# label_layer scale.
# If scale is (1, 1, 1), nothing changes
Expand Down
116 changes: 85 additions & 31 deletions src/napari_feature_classifier/classifier_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def __init__(
self._last_selected_label_layer = get_selected_or_valid_label_layer(
viewer=self._viewer
)

# Initialize the classifier
if classifier:
self._classifier = classifier
Expand All @@ -248,17 +249,10 @@ def __init__(
self._viewer, get_class_selection(class_names=self.class_names)
)

# Handle existing predictions layer
for layer in self._viewer.layers:
if type(layer) == napari.layers.Labels and layer.name == "Predictions":
self._viewer.layers.remove(layer)
self._prediction_layer = self._viewer.add_labels(
self._last_selected_label_layer.data,
scale=self._last_selected_label_layer.scale,
name="Predictions",
translate=self._last_selected_label_layer.translate,
)
self._prediction_layer.contour = 2
self.add_prediction_layer()

# Set the label selection to a valid label layer => Running into proxy bug
self._viewer.layers.selection.active = self._last_selected_label_layer
Expand Down Expand Up @@ -295,11 +289,6 @@ def __init__(
self._export_button.clicked.connect(self.export_results)
self._viewer.layers.selection.events.changed.connect(self.selection_changed)
self._init_prediction_layer(self._last_selected_label_layer)
# Whenever the label layer is clicked, hide the prediction layer
# (e.g. new annotations are made)
# self._last_selected_label_layer.mouse_drag_callbacks.append(
# self.hide_prediction_layer
# )

def run(self):
"""
Expand Down Expand Up @@ -349,6 +338,15 @@ def add_features_to_classifier(self):
dict_of_features[layer.name] = layer.features
self._classifier.add_dict_of_features(dict_of_features)

def add_prediction_layer(self):
self._prediction_layer = self._viewer.add_labels(
self._last_selected_label_layer.data,
scale=self._last_selected_label_layer.scale,
name="Predictions",
translate=self._last_selected_label_layer.translate,
)
self._prediction_layer.contour = 2

def make_predictions(self):
"""
Make predictions for all relevant label layers and add them to the
Expand Down Expand Up @@ -398,17 +396,72 @@ def selection_changed(self):
viewer=self._viewer
):
self._last_selected_label_layer = self._viewer.layers.selection.active
self._init_prediction_layer(self._viewer.layers.selection.active)
# self._last_selected_label_layer.mouse_drag_callbacks.append(
# self.hide_prediction_layer
# )
self._init_prediction_layer(
self._viewer.layers.selection.active, ensure_layer_presence=False
)
self._update_export_destination(self._last_selected_label_layer)

def _init_prediction_layer(self, label_layer: napari.layers.Labels):
def reorder_layers(self):
"""Reorders layers if needed to ensure Annotation & Prediction layers
are above the currently selected label layer.
"""
# Get the current order of layers
all_layers = list(self._viewer.layers)

# Determine the indices of the layers if they exist
indices_to_move = []

# Find the index of "Prediction" layer if it exists
if "Predictions" in self._viewer.layers:
indices_to_move.append(self._viewer.layers.index("Predictions"))

# Find the index of "Annotation" layer if it exists
if "Annotations" in self._viewer.layers:
indices_to_move.append(self._viewer.layers.index("Annotations"))

# Find the index of the reference_label_layer
if self._last_selected_label_layer.name in self._viewer.layers:
indices_to_move.append(
self._viewer.layers.index(self._last_selected_label_layer.name)
)

# Calculate the new order of layer indices
remaining_indices = [
i for i in range(len(all_layers)) if i not in indices_to_move
]
remaining_indices.reverse()
new_order = indices_to_move + remaining_indices
new_order.reverse()

# Reorder the layers using the move_multiple function
self._viewer.layers.move_multiple(new_order)

def _init_prediction_layer(
self, label_layer: napari.layers.Labels, ensure_layer_presence: bool = True
):
"""
Initialize the prediction layer and reset its data (to fit the input
label_layer) and its colormap
label_layer) and its colormap.
ensure_layer_presence creates the Predictions layer if it doesn't exist
yet and triggers layer reordering.
"""
# Ensure that prediction layer exists
if (
"Predictions" not in [x.name for x in self._viewer.layers]
and ensure_layer_presence
):
self.add_prediction_layer()
if ensure_layer_presence:
# Ensure correct layer order: This sometimes fails with weird
# EmitLoopError & IndexError that should be ignored
try:
self.reorder_layers()
except: # noqa
pass

# Ensure that prediction layer is above the current label layer
self._last_selected_label_layer

# Check if the predict column already exists in the layer.features
if "prediction" not in label_layer.features:
unique_labels = np.unique(label_layer.data)[1:]
Expand Down Expand Up @@ -448,12 +501,6 @@ def _init_prediction_layer(self, label_layer: napari.layers.Labels):
cmap=get_colormap(),
)

# def hide_prediction_layer(self, labels_layer, event):
# """
# Hide the prediction layer
# """
# self._prediction_layer.visible = False

def get_relevant_label_layers(self):
relevant_label_layers = []
required_columns = [self._label_column, self._roi_id_colum]
Expand Down Expand Up @@ -613,12 +660,19 @@ def load(self):
with open(clf_path, "rb") as f: # pylint: disable=C0103
clf = pickle.load(f)

self._run_container = ClassifierRunContainer(
self._viewer,
clf,
classifier_save_path=clf_path,
auto_save=True,
)
try:
self._run_container = ClassifierRunContainer(
self._viewer,
clf,
classifier_save_path=clf_path,
auto_save=True,
)
except NotImplementedError:
napari_info(
"Create a label layer with a feature dataframe before loading "
"the classifier"
)
return
self.clear()
self.append(self._run_container)

Expand Down
Loading