@@ -1,4 +1,5 @@
"""Core classifier class and helper functions."""
+import logging
import pickle
import random
import string
@@ -11,9 +12,8 @@
from sklearn.metrics import f1_score
from sklearn.ensemble import RandomForestClassifier

-from napari_feature_classifier.utils import napari_info

-
-# TODO: define an interface for compatible classifiers (m.b. a subset of sklearn Estimators?)
+# TODO: define an interface for compatible classifiers (m.b. a subset of
+# sklearn Estimators?)
class Classifier:
    """Classifier class for napari-feature-classifier.
@@ -23,7 +23,7 @@ class Classifier:
    feature_names: Sequence[str]
        The names of the features that are used for classification
    class_names: Sequence[str]
-        The names of the classes. It's an ordered list that is matched to
+        The names of the classes. It's an ordered list that is matched to
        annotations [1, 2, 3, ...]
    classifier: sklearn classifier
        The classifier that is used for classification. Default is a
@@ -42,7 +42,7 @@ class Classifier:
        The percentage of the data that is used for training. The rest is used
        for testing.
    _index_columns: list[str]
-        The columns that are used for indexing the data.
+        The columns that are used for indexing the data.
        Hard-coded to roi_id and label
    _input_schema: pandera.SchemaModel
        The schema for the input data. It's used for validation.
@@ -51,10 +51,13 @@ class Classifier:
    _predict_schema: pandera.SchemaModel
        The schema for the prediction data.
    _data: pd.DataFrame
-        The internal data storage of the classifier. Contains both annotations
+        The internal data storage of the classifier. Contains both annotations
        as well as feature measurements for all rows (annotated objects)
    """
+
    def __init__(self, feature_names, class_names, classifier=RandomForestClassifier()):
+        self.logger = logging.getLogger("classifier")
+        self.logger.setLevel(logging.INFO)
        self._feature_names: list[str] = list(feature_names)
        self._class_names: list[str] = list(class_names)
        self._classifier = classifier
@@ -79,13 +82,13 @@ def train(self):
        """
        Train the classifier on the data it already has in self._data.
        """
-        napari_info("Training classifier...")
+        self.logger.info("Training classifier...")
        train_data = self._data[self._data.hash < self._training_data_perc]
        test_data = self._data[self._data.hash >= self._training_data_perc]

-        # pylint: disable=C0103
+        # pylint: disable=C0103
        X_train = train_data.drop(["hash", "annotations"], axis=1)
-        # pylint: disable=C0103
+        # pylint: disable=C0103
        X_test = test_data.drop(["hash", "annotations"], axis=1)

        y_train = train_data["annotations"]
@@ -94,8 +97,7 @@ def train(self):
        self._classifier.fit(X_train, y_train)

        f1 = f1_score(y_test, self._classifier.predict(X_test), average="macro")
-        # napari_info("F1 score on test set: {}".format(f1))
-        napari_info(
+        self.logger.info(
            f"F1 score on test set: {f1}\n"
            f"Annotations split into {len(X_train)} training and {len(X_test)} "
            "test samples.\n"
@@ -130,7 +132,6 @@ def predict_on_dict(self, dict_of_dfs):
        # Make a prediction on each of the dataframes provided
        predicted_dicts = {}
        for roi in dict_of_dfs:
-            # napari_info(f"Making a prediction for {roi=}...")
            predicted_dicts[roi] = self.predict(dict_of_dfs[roi])
        return predicted_dicts

@@ -149,12 +150,12 @@ def add_features(self, df_raw: pd.DataFrame):

    def _validate_predict_features(self, df: pd.DataFrame) -> pd.Series:
        """
-        Validate the features that are received for prediction using
+        Validate the features that are received for prediction using
        self._predict_schema.
        """
        df_no_nans = df.dropna(subset=self._feature_names)
        if len(df) != len(df_no_nans):
-            napari_info(
+            self.logger.info(
                f"Could not do predictions for {len(df)-len(df_no_nans)}/{len(df)} "
                "objects because of features that contained `NA`s."
            )
@@ -174,7 +175,7 @@ def _validate_input_features(self, df: pd.DataFrame) -> pd.DataFrame:
        # Drop rows that have features with `NA`s, notify the user.
        df_no_nans = df_annotated.dropna(subset=self._feature_names)
        if len(df_no_nans) != len(df_annotated):
-            napari_info(
+            self.logger.info(
                f"Dropped {len(df_annotated)-len(df_no_nans)}/{len(df_annotated)} "
                "objects because of features that contained `NA`s."
            )
@@ -193,14 +194,14 @@ def add_dict_of_features(self, dict_of_features):
        Parameters
        ----------
        dict_of_features : dict
-            Dictionary with roi as key and dataframe with feature measurements
+            Dictionary with roi as key and dataframe with feature measurements
            and annotations as value
        """
        for roi in dict_of_features:
            if "roi_id" not in dict_of_features[roi]:
                dict_of_features[roi]["roi_id"] = roi
            df = dict_of_features[roi]
-            napari_info(f"Adding features for {roi=}...")
+            self.logger.info(f"Adding features for {roi=}...")
            self.add_features(df)

    def get_class_names(self):
@@ -210,7 +211,7 @@ def get_feature_names(self):
        return self._feature_names

    def save(self, output_path):
-        napari_info(f"Saving classifier at {output_path}...")
+        self.logger.info(f"Saving classifier at {output_path}...")
        with open(output_path, "wb") as f:
            f.write(pickle.dumps(self))
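
For context, a minimal usage sketch of the API this diff touches. It assumes the class is importable as `napari_feature_classifier.classifier.Classifier` (the diff only shows the package name), and that the feature table carries the hard-coded index columns `roi_id` and `label` plus an `annotations` column, as the class docstring describes; the feature columns and class names below are illustrative only:

```python
import pandas as pd

# Assumed import path; only the package name appears in the diff.
from napari_feature_classifier.classifier import Classifier

# Hypothetical feature table: index columns (roi_id, label), two made-up
# features, and integer annotations matched to class_names as 1, 2, ...
df = pd.DataFrame({
    "roi_id": ["site_0"] * 6,
    "label": [1, 2, 3, 4, 5, 6],
    "area": [10.0, 12.5, 9.1, 30.2, 28.7, 11.3],
    "intensity": [0.2, 0.8, 0.4, 0.9, 0.7, 0.3],
    "annotations": [1, 2, 1, 2, 2, 1],
})

clf = Classifier(
    feature_names=["area", "intensity"],
    class_names=["small", "large"],  # matched to annotations 1 and 2
)
clf.add_features(df)   # validated against the input schema; NA rows dropped
clf.train()            # logs the F1 score on the held-out test split
clf.save("my_classifier.clf")  # pickles the classifier to disk
```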
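One consequence of swapping `napari_info` for a named logger: messages such as "Training classifier..." only become visible once a handler is attached to the `"classifier"` logger. A stdlib-only sketch of how a consumer might surface them:

```python
import logging

# The classifier logs via logging.getLogger("classifier") at INFO level;
# attach any handler to surface its messages - here, plain stderr output.
# A GUI frontend could instead install a handler that forwards records
# to its own notification system.
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(name)s - %(levelname)s: %(message)s"))
logging.getLogger("classifier").addHandler(handler)
```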