Skip to content

Commit 69a1524

Browse files
authored
Merge pull request #46 from fractal-napari-plugins-collection/30_closing_issue
Second attempt to fix issue #30: Classifier widget restarts
2 parents 738060b + ab869bb commit 69a1524

File tree

5 files changed

+85
-32
lines changed

5 files changed

+85
-32
lines changed

setup.cfg

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ package_dir =
3636

3737
# add your package requirements here
3838
install_requires =
39-
numpy
39+
numpy < 2.0
4040
napari < 0.4.19
4141
matplotlib
4242
magicgui
4343
pandas
4444
scikit-learn >= 1.2.2
45-
pandera
45+
pandera < 0.20.0
4646
xxhash
4747
hypothesis
4848

src/napari_feature_classifier/annotator_widget.py

+30-10
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ class LabelAnnotator(Container):
112112
Can also be controlled via the number keys.
113113
"""
114114

115-
# TODO: Do we need to keep the annotation layer on top when new annotations are made?
115+
# TODO: Do we need to keep the annotation layer on top when new
116+
# annotations are made?
116117
def __init__(
117118
self,
118119
viewer: napari.viewer.Viewer,
@@ -129,10 +130,15 @@ def __init__(
129130
label="Last selected label layer:", value=self._last_selected_label_layer
130131
)
131132

133+
# Handle existing predictions layer
134+
for layer in self._viewer.layers:
135+
if type(layer) == napari.layers.Labels and layer.name == "Annotations":
136+
self._viewer.layers.remove(layer)
132137
self._annotations_layer = self._viewer.add_labels(
133138
self._last_selected_label_layer.data,
134139
scale=self._last_selected_label_layer.scale,
135140
name="Annotations",
141+
translate=self._last_selected_label_layer.translate,
136142
)
137143
self._annotations_layer.editable = False
138144

@@ -209,22 +215,34 @@ def toggle_label(self, labels_layer, event):
209215
Callback for when a label is clicked. It then updates the color of that
210216
label in the annotation layer.
211217
"""
212-
# Need to scale position that event.position returns by the
218+
# Need to translate & scale position that event.position returns by the
213219
# label_layer scale.
214220
# If scale is (1, 1, 1), nothing changes
221+
# If translate is (0, 0, 0)
215222
# If scale is anything else, this makes the click still match the
216223
# correct label
224+
# translate before scale
217225
scaled_position = tuple(
218-
pos / scale for pos, scale in zip(event.position, labels_layer.scale)
226+
(pos - trans) / scale
227+
for pos, trans, scale in zip(
228+
event.position, labels_layer.translate, labels_layer.scale
229+
)
219230
)
220231
label = labels_layer.get_value(scaled_position)
221232
if label == 0 or not label:
222-
napari_info("No label clicked.")
233+
napari_info(f"No label clicked on the {labels_layer} label layer.")
223234
return
224235

225-
labels_layer.features.loc[
226-
labels_layer.features[self._label_column] == label, "annotations"
227-
] = self._class_selector.value.value
236+
# Left click: add annotation
237+
if event.button == 1:
238+
labels_layer.features.loc[
239+
labels_layer.features[self._label_column] == label, "annotations"
240+
] = self._class_selector.value.value
241+
# Right click: Remove annotation
242+
elif event.button == 2:
243+
labels_layer.features.loc[
244+
labels_layer.features[self._label_column] == label, "annotations"
245+
] = np.NaN
228246

229247
# Update only the single color value that changed
230248
self.update_single_color(labels_layer, label)
@@ -253,9 +271,11 @@ def _init_annotation(self, label_layer: napari.layers.Labels):
253271
[label_layer.features, annotation_df], axis=1
254272
)
255273

256-
label_layer.opacity = 0.4
274+
# label_layer.opacity = 0.4
257275
self._annotations_layer.data = label_layer.data
258276
self._annotations_layer.scale = label_layer.scale
277+
self._annotations_layer.translate = label_layer.translate
278+
259279
reset_display_colormaps(
260280
label_layer,
261281
feature_col="annotations",
@@ -265,7 +285,7 @@ def _init_annotation(self, label_layer: napari.layers.Labels):
265285
)
266286
label_layer.mouse_drag_callbacks.append(self.toggle_label)
267287

268-
# # keybindings for the available classes (0 = deselect)
288+
# keybindings for the available classes (0 = deselect)
269289
for i in range(len(self.ClassSelection)):
270290
set_class = partial(self.set_class_n, n=i)
271291
set_class.__name__ = f"set_class_{i}"
@@ -275,7 +295,7 @@ def _update_save_destination(self, label_layer: napari.layers.Labels):
275295
"""
276296
Update the default save destination to the name of the label layer.
277297
If a base_path was already set, keep it on that base path.
278-
298+
279299
"""
280300
base_path = Path(self._save_destination.value).parent
281301
self._save_destination.value = base_path / f"{label_layer.name}_annotation.csv"

src/napari_feature_classifier/classifier.py

+19-18
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Core classifier class and helper functions."""
2+
import logging
23
import pickle
34
import random
45
import string
@@ -11,10 +12,9 @@
1112
from sklearn.metrics import f1_score
1213
from sklearn.ensemble import RandomForestClassifier
1314

14-
from napari_feature_classifier.utils import napari_info
1515

16-
17-
# TODO: define an interface for compatible classifiers (m.b. a subset of sklearn Estimators?)
16+
# TODO: define an interface for compatible classifiers (m.b. a subset of
17+
# sklearn Estimators?)
1818
class Classifier:
1919
"""Classifier class for napari-feature-classifier.
2020
@@ -23,7 +23,7 @@ class Classifier:
2323
feature_names: Sequence[str]
2424
The names of the features that are used for classification
2525
class_names: Sequence[str]
26-
The names of the classes. It's an ordered list that is matched to
26+
The names of the classes. It's an ordered list that is matched to
2727
annotations [1, 2, 3, ...]
2828
classifier: sklearn classifier
2929
The classifier that is used for classification. Default is a
@@ -42,7 +42,7 @@ class Classifier:
4242
The percentage of the data that is used for training. The rest is used
4343
for testing.
4444
_index_columns: list[str]
45-
The columns that are used for indexing the data.
45+
The columns that are used for indexing the data.
4646
Hard-coded to roi_id and label
4747
_input_schema: pandera.SchemaModel
4848
The schema for the input data. It's used for validation.
@@ -51,10 +51,13 @@ class Classifier:
5151
_predict_schema: pandera.SchemaModel
5252
The schema for the prediction data.
5353
_data: pd.DataFrame
54-
The internal data storage of the classifier. Contains both annotations
54+
The internal data storage of the classifier. Contains both annotations
5555
as well as feature measurements for all rows (annotated objects)
5656
"""
57+
5758
def __init__(self, feature_names, class_names, classifier=RandomForestClassifier()):
59+
self.logger = logging.getLogger("classifier")
60+
self.logger.setLevel(logging.INFO)
5861
self._feature_names: list[str] = list(feature_names)
5962
self._class_names: list[str] = list(class_names)
6063
self._classifier = classifier
@@ -79,13 +82,13 @@ def train(self):
7982
"""
8083
Train the classifier on the data it already has in self._data.
8184
"""
82-
napari_info("Training classifier...")
85+
self.logger.info("Training classifier...")
8386
train_data = self._data[self._data.hash < self._training_data_perc]
8487
test_data = self._data[self._data.hash >= self._training_data_perc]
8588

86-
# pylint: disable=C0103
89+
# pylint: disable=C0103
8790
X_train = train_data.drop(["hash", "annotations"], axis=1)
88-
# pylint: disable=C0103
91+
# pylint: disable=C0103
8992
X_test = test_data.drop(["hash", "annotations"], axis=1)
9093

9194
y_train = train_data["annotations"]
@@ -94,8 +97,7 @@ def train(self):
9497
self._classifier.fit(X_train, y_train)
9598

9699
f1 = f1_score(y_test, self._classifier.predict(X_test), average="macro")
97-
# napari_info("F1 score on test set: {}".format(f1))
98-
napari_info(
100+
self.logger.info(
99101
f"F1 score on test set: {f1} \n"
100102
f"Annotations split into {len(X_train)} training and {len(X_test)} "
101103
"test samples. \n"
@@ -130,7 +132,6 @@ def predict_on_dict(self, dict_of_dfs):
130132
# Make a prediction on each of the dataframes provided
131133
predicted_dicts = {}
132134
for roi in dict_of_dfs:
133-
# napari_info(f"Making a prediction for {roi=}...")
134135
predicted_dicts[roi] = self.predict(dict_of_dfs[roi])
135136
return predicted_dicts
136137

@@ -149,12 +150,12 @@ def add_features(self, df_raw: pd.DataFrame):
149150

150151
def _validate_predict_features(self, df: pd.DataFrame) -> pd.Series:
151152
"""
152-
Validate the features that are received for prediction using
153+
Validate the features that are received for prediction using
153154
self._predict_schema.
154155
"""
155156
df_no_nans = df.dropna(subset=self._feature_names)
156157
if len(df) != len(df_no_nans):
157-
napari_info(
158+
self.logger.info(
158159
f"Could not do predictions for {len(df)-len(df_no_nans)}/{len(df)} "
159160
"objects because of features that contained `NA`s."
160161
)
@@ -174,7 +175,7 @@ def _validate_input_features(self, df: pd.DataFrame) -> pd.DataFrame:
174175
# Drop rows that have features with `NA`s, notify the user.
175176
df_no_nans = df_annotated.dropna(subset=self._feature_names)
176177
if len(df_no_nans) != len(df_annotated):
177-
napari_info(
178+
self.logger.info(
178179
f"Dropped {len(df_annotated)-len(df_no_nans)}/{len(df_annotated)} "
179180
"objects because of features that contained `NA`s."
180181
)
@@ -193,14 +194,14 @@ def add_dict_of_features(self, dict_of_features):
193194
Parameters
194195
----------
195196
dict_of_features : dict
196-
Dictionary with roi as key and dataframe with feature measurements
197+
Dictionary with roi as key and dataframe with feature measurements
197198
and annotations as value
198199
"""
199200
for roi in dict_of_features:
200201
if "roi_id" not in dict_of_features[roi]:
201202
dict_of_features[roi]["roi_id"] = roi
202203
df = dict_of_features[roi]
203-
napari_info(f"Adding features for {roi=}...")
204+
self.logger.info(f"Adding features for {roi=}...")
204205
self.add_features(df)
205206

206207
def get_class_names(self):
@@ -210,7 +211,7 @@ def get_feature_names(self):
210211
return self._feature_names
211212

212213
def save(self, output_path):
213-
napari_info(f"Saving classifier at {output_path}...")
214+
self.logger.info(f"Saving classifier at {output_path}...")
214215
with open(output_path, "wb") as f:
215216
f.write(pickle.dumps(self))
216217

src/napari_feature_classifier/classifier_widget.py

+24
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Classifier container widget for napari"""
2+
import logging
23
import pickle
34

45
from pathlib import Path
@@ -32,6 +33,7 @@
3233
napari_info,
3334
overwrite_check_passed,
3435
add_annotation_names,
36+
NapariHandler,
3537
)
3638

3739

@@ -244,10 +246,15 @@ def __init__(
244246
self._viewer, get_class_selection(class_names=self.class_names)
245247
)
246248

249+
# Handle existing predictions layer
250+
for layer in self._viewer.layers:
251+
if type(layer) == napari.layers.Labels and layer.name == "Predictions":
252+
self._viewer.layers.remove(layer)
247253
self._prediction_layer = self._viewer.add_labels(
248254
self._last_selected_label_layer.data,
249255
scale=self._last_selected_label_layer.scale,
250256
name="Predictions",
257+
translate=self._last_selected_label_layer.translate,
251258
)
252259

253260
# Set the label selection to a valid label layer => Running into proxy bug
@@ -417,6 +424,7 @@ def _init_prediction_layer(self, label_layer: napari.layers.Labels):
417424
# Update the label data in the prediction layer
418425
self._prediction_layer.data = label_layer.data
419426
self._prediction_layer.scale = label_layer.scale
427+
self._prediction_layer.translate = label_layer.translate
420428

421429
# Update the colormap of the prediction layer
422430
reset_display_colormaps(
@@ -632,11 +640,27 @@ def __init__(self, viewer: napari.viewer.Viewer):
632640
self._init_container = None
633641
self._run_container = None
634642
self._init_container = None
643+
self.setup_logging()
635644

636645
super().__init__(widgets=[])
637646

638647
self.initialize_init_widget()
639648

649+
def setup_logging(self):
650+
# Create a custom handler for napari
651+
napari_handler = NapariHandler()
652+
napari_handler.setLevel(logging.INFO)
653+
654+
# Optionally, set a formatter for the handler
655+
# formatter = logging.Formatter(
656+
# '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
657+
# )
658+
# napari_handler.setFormatter(formatter)
659+
660+
# Get the classifier's logger and add the napari handler to it
661+
classifier_logger = logging.getLogger("classifier")
662+
classifier_logger.addHandler(napari_handler)
663+
640664
def initialize_init_widget(self):
641665
self._init_container = ClassifierInitContainer(self._viewer)
642666
self.append(self._init_container)

src/napari_feature_classifier/utils.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Utils function for the classifier"""
22
from functools import lru_cache
3+
import logging
34
import math
45
from pathlib import Path
56

@@ -118,14 +119,21 @@ def napari_info(message):
118119
"""
119120
try:
120121
show_info(message)
121-
except: # pylint: disable=bare-except
122+
except: # pylint: disable=bare-except # noqa #E722
122123
print(message)
123124
# TODO: Would be better to check if it's running in napari and print in all
124125
# other cases (e.g. if someone runs the classifier form a script).
125126
# But can't make that work at the moment
126127
if in_notebook():
127128
print(message)
128129

130+
131+
class NapariHandler(logging.Handler):
132+
def emit(self, record):
133+
log_entry = self.format(record)
134+
napari_info(log_entry)
135+
136+
129137
def get_valid_label_layers(viewer) -> list[str]:
130138
"""
131139
Get a list of label layers that are not `Annotations` or `Predictions`.
@@ -183,7 +191,7 @@ def add_annotation_names(df, ClassSelection):
183191
Dataframe with annotations column.
184192
ClassSelection : Enum
185193
Enum with the class names.
186-
194+
187195
Returns
188196
-------
189197
pd.DataFrame

0 commit comments

Comments
 (0)