Skip to content

Commit 5495e66

Browse files
committed
Improve logging: Classifier module now no longer imports napari-related & qt-dependent things
1 parent b43b5e0 commit 5495e66

File tree

3 files changed

+47
-20
lines changed

3 files changed

+47
-20
lines changed

src/napari_feature_classifier/classifier.py

+19-18
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Core classifier class and helper functions."""
2+
import logging
23
import pickle
34
import random
45
import string
@@ -11,10 +12,9 @@
1112
from sklearn.metrics import f1_score
1213
from sklearn.ensemble import RandomForestClassifier
1314

14-
from napari_feature_classifier.utils import napari_info
1515

16-
17-
# TODO: define an interface for compatible classifiers (m.b. a subset of sklearn Estimators?)
16+
# TODO: define an interface for compatible classifiers (m.b. a subset of
17+
# sklearn Estimators?)
1818
class Classifier:
1919
"""Classifier class for napari-feature-classifier.
2020
@@ -23,7 +23,7 @@ class Classifier:
2323
feature_names: Sequence[str]
2424
The names of the features that are used for classification
2525
class_names: Sequence[str]
26-
The names of the classes. It's an ordered list that is matched to
26+
The names of the classes. It's an ordered list that is matched to
2727
annotations [1, 2, 3, ...]
2828
classifier: sklearn classifier
2929
The classifier that is used for classification. Default is a
@@ -42,7 +42,7 @@ class Classifier:
4242
The percentage of the data that is used for training. The rest is used
4343
for testing.
4444
_index_columns: list[str]
45-
The columns that are used for indexing the data.
45+
The columns that are used for indexing the data.
4646
Hard-coded to roi_id and label
4747
_input_schema: pandera.SchemaModel
4848
The schema for the input data. It's used for validation.
@@ -51,10 +51,13 @@ class Classifier:
5151
_predict_schema: pandera.SchemaModel
5252
The schema for the prediction data.
5353
_data: pd.DataFrame
54-
The internal data storage of the classifier. Contains both annotations
54+
The internal data storage of the classifier. Contains both annotations
5555
as well as feature measurements for all rows (annotated objects)
5656
"""
57+
5758
def __init__(self, feature_names, class_names, classifier=RandomForestClassifier()):
59+
self.logger = logging.getLogger("classifier")
60+
self.logger.setLevel(logging.INFO)
5861
self._feature_names: list[str] = list(feature_names)
5962
self._class_names: list[str] = list(class_names)
6063
self._classifier = classifier
@@ -79,13 +82,13 @@ def train(self):
7982
"""
8083
Train the classifier on the data it already has in self._data.
8184
"""
82-
napari_info("Training classifier...")
85+
self.logger.info("Training classifier...")
8386
train_data = self._data[self._data.hash < self._training_data_perc]
8487
test_data = self._data[self._data.hash >= self._training_data_perc]
8588

86-
# pylint: disable=C0103
89+
# pylint: disable=C0103
8790
X_train = train_data.drop(["hash", "annotations"], axis=1)
88-
# pylint: disable=C0103
91+
# pylint: disable=C0103
8992
X_test = test_data.drop(["hash", "annotations"], axis=1)
9093

9194
y_train = train_data["annotations"]
@@ -94,8 +97,7 @@ def train(self):
9497
self._classifier.fit(X_train, y_train)
9598

9699
f1 = f1_score(y_test, self._classifier.predict(X_test), average="macro")
97-
# napari_info("F1 score on test set: {}".format(f1))
98-
napari_info(
100+
self.logger.info(
99101
f"F1 score on test set: {f1} \n"
100102
f"Annotations split into {len(X_train)} training and {len(X_test)} "
101103
"test samples. \n"
@@ -130,7 +132,6 @@ def predict_on_dict(self, dict_of_dfs):
130132
# Make a prediction on each of the dataframes provided
131133
predicted_dicts = {}
132134
for roi in dict_of_dfs:
133-
# napari_info(f"Making a prediction for {roi=}...")
134135
predicted_dicts[roi] = self.predict(dict_of_dfs[roi])
135136
return predicted_dicts
136137

@@ -149,12 +150,12 @@ def add_features(self, df_raw: pd.DataFrame):
149150

150151
def _validate_predict_features(self, df: pd.DataFrame) -> pd.Series:
151152
"""
152-
Validate the features that are received for prediction using
153+
Validate the features that are received for prediction using
153154
self._predict_schema.
154155
"""
155156
df_no_nans = df.dropna(subset=self._feature_names)
156157
if len(df) != len(df_no_nans):
157-
napari_info(
158+
self.logger.info(
158159
f"Could not do predictions for {len(df)-len(df_no_nans)}/{len(df)} "
159160
"objects because of features that contained `NA`s."
160161
)
@@ -174,7 +175,7 @@ def _validate_input_features(self, df: pd.DataFrame) -> pd.DataFrame:
174175
# Drop rows that have features with `NA`s, notify the user.
175176
df_no_nans = df_annotated.dropna(subset=self._feature_names)
176177
if len(df_no_nans) != len(df_annotated):
177-
napari_info(
178+
self.logger.info(
178179
f"Dropped {len(df_annotated)-len(df_no_nans)}/{len(df_annotated)} "
179180
"objects because of features that contained `NA`s."
180181
)
@@ -193,14 +194,14 @@ def add_dict_of_features(self, dict_of_features):
193194
Parameters
194195
----------
195196
dict_of_features : dict
196-
Dictionary with roi as key and dataframe with feature measurements
197+
Dictionary with roi as key and dataframe with feature measurements
197198
and annotations as value
198199
"""
199200
for roi in dict_of_features:
200201
if "roi_id" not in dict_of_features[roi]:
201202
dict_of_features[roi]["roi_id"] = roi
202203
df = dict_of_features[roi]
203-
napari_info(f"Adding features for {roi=}...")
204+
self.logger.info(f"Adding features for {roi=}...")
204205
self.add_features(df)
205206

206207
def get_class_names(self):
@@ -210,7 +211,7 @@ def get_feature_names(self):
210211
return self._feature_names
211212

212213
def save(self, output_path):
213-
napari_info(f"Saving classifier at {output_path}...")
214+
self.logger.info(f"Saving classifier at {output_path}...")
214215
with open(output_path, "wb") as f:
215216
f.write(pickle.dumps(self))
216217

src/napari_feature_classifier/classifier_widget.py

+18
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Classifier container widget for napari"""
2+
import logging
23
import pickle
34

45
from pathlib import Path
@@ -32,6 +33,7 @@
3233
napari_info,
3334
overwrite_check_passed,
3435
add_annotation_names,
36+
NapariHandler,
3537
)
3638

3739

@@ -636,11 +638,27 @@ def __init__(self, viewer: napari.viewer.Viewer):
636638
self._init_container = None
637639
self._run_container = None
638640
self._init_container = None
641+
self.setup_logging()
639642

640643
super().__init__(widgets=[])
641644

642645
self.initialize_init_widget()
643646

647+
def setup_logging(self):
    """Show messages from the ``"classifier"`` logger as napari notifications.

    Attaches a ``NapariHandler`` with level ``logging.INFO`` to the logger
    named ``"classifier"``. Loggers are process-wide singletons looked up
    by name, so the handler is attached only if no ``NapariHandler`` is
    registered yet; otherwise every additional widget instance would
    cause each message to be displayed one more time.

    No formatter is set on the handler, so records are rendered with the
    logging default (the bare message). Call ``setFormatter`` on the
    handler to add e.g. timestamps or the logger name.
    """
    classifier_logger = logging.getLogger("classifier")

    # Guard against stacking duplicate handlers on repeated calls.
    already_attached = any(
        isinstance(handler, NapariHandler)
        for handler in classifier_logger.handlers
    )
    if already_attached:
        return

    napari_handler = NapariHandler()
    napari_handler.setLevel(logging.INFO)
    classifier_logger.addHandler(napari_handler)
661+
644662
def initialize_init_widget(self):
645663
self._init_container = ClassifierInitContainer(self._viewer)
646664
self.append(self._init_container)

src/napari_feature_classifier/utils.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Utils function for the classifier"""
22
from functools import lru_cache
3+
import logging
34
import math
45
from pathlib import Path
56

@@ -118,14 +119,21 @@ def napari_info(message):
118119
"""
119120
try:
120121
show_info(message)
121-
except: # pylint: disable=bare-except
122+
except: # pylint: disable=bare-except # noqa #E722
122123
print(message)
123124
# TODO: Would be better to check if it's running in napari and print in all
124125
# other cases (e.g. if someone runs the classifier from a script).
125126
# But can't make that work at the moment
126127
if in_notebook():
127128
print(message)
128129

130+
131+
class NapariHandler(logging.Handler):
    """Logging handler that forwards formatted records to ``napari_info``."""

    def emit(self, record):
        """Format ``record`` and display it via ``napari_info``.

        Any exception raised while formatting or displaying the record is
        routed to ``handleError`` instead of propagating, so a failing log
        call cannot break the code that issued it.
        """
        try:
            log_entry = self.format(record)
            napari_info(log_entry)
        except Exception:  # pylint: disable=broad-except
            self.handleError(record)
135+
136+
129137
def get_valid_label_layers(viewer) -> list[str]:
130138
"""
131139
Get a list of label layers that are not `Annotations` or `Predictions`.
@@ -183,7 +191,7 @@ def add_annotation_names(df, ClassSelection):
183191
Dataframe with annotations column.
184192
ClassSelection : Enum
185193
Enum with the class names.
186-
194+
187195
Returns
188196
-------
189197
pd.DataFrame

0 commit comments

Comments
 (0)