Second attempt to fix issue #30: Classifier widget restarts #46

Merged
merged 8 commits on Jul 19, 2024
4 changes: 2 additions & 2 deletions setup.cfg
@@ -36,13 +36,13 @@ package_dir =

# add your package requirements here
install_requires =
numpy
numpy < 2.0
napari < 0.4.19
matplotlib
magicgui
pandas
scikit-learn >= 1.2.2
pandera
pandera < 0.20.0
xxhash
hypothesis

40 changes: 30 additions & 10 deletions src/napari_feature_classifier/annotator_widget.py
@@ -112,7 +112,8 @@ class LabelAnnotator(Container):
Can also be controlled via the number keys.
"""

# TODO: Do we need to keep the annotation layer on top when new annotations are made?
# TODO: Do we need to keep the annotation layer on top when new
# annotations are made?
def __init__(
self,
viewer: napari.viewer.Viewer,
@@ -129,10 +130,15 @@ def __init__(
label="Last selected label layer:", value=self._last_selected_label_layer
)

# Handle existing annotations layer
for layer in self._viewer.layers:
if type(layer) == napari.layers.Labels and layer.name == "Annotations":
self._viewer.layers.remove(layer)
self._annotations_layer = self._viewer.add_labels(
self._last_selected_label_layer.data,
scale=self._last_selected_label_layer.scale,
name="Annotations",
translate=self._last_selected_label_layer.translate,
)
self._annotations_layer.editable = False
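This loop is the core of the restart fix from issue #30: any "Annotations" labels layer left behind by a previous widget instance is removed before a fresh one is added, so reopening the widget does not stack stale layers. A minimal sketch of the same pattern as a standalone helper (the helper name and setup are illustrative, not part of this PR):

import napari

def remove_layer_if_present(viewer: napari.Viewer, name: str) -> None:
    # Iterate over a copy so removing entries is safe while looping.
    for layer in list(viewer.layers):
        if isinstance(layer, napari.layers.Labels) and layer.name == name:
            viewer.layers.remove(layer)

# Usage sketch, before the widget re-adds its own layers:
# remove_layer_if_present(viewer, "Annotations")
# remove_layer_if_present(viewer, "Predictions")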

@@ -209,22 +215,34 @@ def toggle_label(self, labels_layer, event):
Callback for when a label is clicked. It then updates the color of that
label in the annotation layer.
"""
# Need to scale position that event.position returns by the
# Need to translate & scale position that event.position returns by the
# label_layer scale.
# If scale is (1, 1, 1), nothing changes
# If translate is (0, 0, 0), nothing changes either
# If scale is anything else, this makes the click still match the
# correct label
# translate before scale
scaled_position = tuple(
pos / scale for pos, scale in zip(event.position, labels_layer.scale)
(pos - trans) / scale
for pos, trans, scale in zip(
event.position, labels_layer.translate, labels_layer.scale
)
)
label = labels_layer.get_value(scaled_position)
if label == 0 or not label:
napari_info("No label clicked.")
napari_info(f"No label clicked on the {labels_layer} label layer.")
return

labels_layer.features.loc[
labels_layer.features[self._label_column] == label, "annotations"
] = self._class_selector.value.value
# Left click: add annotation
if event.button == 1:
labels_layer.features.loc[
labels_layer.features[self._label_column] == label, "annotations"
] = self._class_selector.value.value
# Right click: Remove annotation
elif event.button == 2:
labels_layer.features.loc[
labels_layer.features[self._label_column] == label, "annotations"
] = np.NaN

# Update only the single color value that changed
self.update_single_color(labels_layer, label)
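The comments above describe mapping event.position (a world coordinate) back into the label layer's data coordinates: the layer's translate offset must be subtracted before dividing by its scale. A short numeric sketch (values are made up) shows why the order matters; napari layers also provide a world_to_data() method for the same mapping.

# Illustrative values: a 2D labels layer scaled by 2 and shifted by 10.
scale = (2.0, 2.0)
translate = (10.0, 10.0)
world_click = (14.0, 18.0)  # what event.position would report

data_coords = tuple(
    (pos - trans) / s
    for pos, trans, s in zip(world_click, translate, scale)
)
assert data_coords == (2.0, 4.0)  # pixel index in the raw label image

# Scaling first and subtracting translate afterwards would give
# (14 / 2 - 10, 18 / 2 - 10) == (-3.0, -1.0), i.e. a position outside the image.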
@@ -253,9 +271,11 @@ def _init_annotation(self, label_layer: napari.layers.Labels):
[label_layer.features, annotation_df], axis=1
)

label_layer.opacity = 0.4
# label_layer.opacity = 0.4
self._annotations_layer.data = label_layer.data
self._annotations_layer.scale = label_layer.scale
self._annotations_layer.translate = label_layer.translate

reset_display_colormaps(
label_layer,
feature_col="annotations",
@@ -265,7 +285,7 @@ def _init_annotation(self, label_layer: napari.layers.Labels):
)
label_layer.mouse_drag_callbacks.append(self.toggle_label)

# # keybindings for the available classes (0 = deselect)
# keybindings for the available classes (0 = deselect)
for i in range(len(self.ClassSelection)):
set_class = partial(self.set_class_n, n=i)
set_class.__name__ = f"set_class_{i}"
@@ -275,7 +295,7 @@ def _update_save_destination(self, label_layer: napari.layers.Labels):
"""
Update the default save destination to the name of the label layer.
If a base_path was already set, keep it on that base path.

"""
base_path = Path(self._save_destination.value).parent
self._save_destination.value = base_path / f"{label_layer.name}_annotation.csv"
37 changes: 19 additions & 18 deletions src/napari_feature_classifier/classifier.py
@@ -1,4 +1,5 @@
"""Core classifier class and helper functions."""
import logging
import pickle
import random
import string
@@ -11,10 +12,9 @@
from sklearn.metrics import f1_score
from sklearn.ensemble import RandomForestClassifier

from napari_feature_classifier.utils import napari_info


# TODO: define an interface for compatible classifiers (m.b. a subset of sklearn Estimators?)
# TODO: define an interface for compatible classifiers (m.b. a subset of
# sklearn Estimators?)
class Classifier:
"""Classifier class for napari-feature-classifier.

@@ -23,7 +23,7 @@ class Classifier:
feature_names: Sequence[str]
The names of the features that are used for classification
class_names: Sequence[str]
The names of the classes. It's an ordered list that is matched to
The names of the classes. It's an ordered list that is matched to
annotations [1, 2, 3, ...]
classifier: sklearn classifier
The classifier that is used for classification. Default is a
@@ -42,7 +42,7 @@ class Classifier:
The percentage of the data that is used for training. The rest is used
for testing.
_index_columns: list[str]
The columns that are used for indexing the data.
The columns that are used for indexing the data.
Hard-coded to roi_id and label
_input_schema: pandera.SchemaModel
The schema for the input data. It's used for validation.
@@ -51,10 +51,13 @@
_predict_schema: pandera.SchemaModel
The schema for the prediction data.
_data: pd.DataFrame
The internal data storage of the classifier. Contains both annotations
The internal data storage of the classifier. Contains both annotations
as well as feature measurements for all rows (annotated objects)
"""

def __init__(self, feature_names, class_names, classifier=RandomForestClassifier()):
self.logger = logging.getLogger("classifier")
self.logger.setLevel(logging.INFO)
self._feature_names: list[str] = list(feature_names)
self._class_names: list[str] = list(class_names)
self._classifier = classifier
@@ -79,13 +82,13 @@ def train(self):
"""
Train the classifier on the data it already has in self._data.
"""
napari_info("Training classifier...")
self.logger.info("Training classifier...")
train_data = self._data[self._data.hash < self._training_data_perc]
test_data = self._data[self._data.hash >= self._training_data_perc]

# pylint: disable=C0103
# pylint: disable=C0103
X_train = train_data.drop(["hash", "annotations"], axis=1)
# pylint: disable=C0103
# pylint: disable=C0103
X_test = test_data.drop(["hash", "annotations"], axis=1)

y_train = train_data["annotations"]
@@ -94,8 +97,7 @@ def train(self):
self._classifier.fit(X_train, y_train)

f1 = f1_score(y_test, self._classifier.predict(X_test), average="macro")
# napari_info("F1 score on test set: {}".format(f1))
napari_info(
self.logger.info(
f"F1 score on test set: {f1} \n"
f"Annotations split into {len(X_train)} training and {len(X_test)} "
"test samples. \n"
@@ -130,7 +132,6 @@ def predict_on_dict(self, dict_of_dfs):
# Make a prediction on each of the dataframes provided
predicted_dicts = {}
for roi in dict_of_dfs:
# napari_info(f"Making a prediction for {roi=}...")
predicted_dicts[roi] = self.predict(dict_of_dfs[roi])
return predicted_dicts
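For context on the train() hunk above: each object carries a normalized hash, and that hash is compared against the training percentage, so a given object always falls on the same side of the train/test split across sessions. A rough sketch of the idea (the hashing details below are assumptions, not taken from the repo):

import pandas as pd
import xxhash

def normalized_hash(key: str) -> float:
    # Stable 64-bit hash mapped into [0, 1) so the split is deterministic.
    return xxhash.xxh64_intdigest(key.encode()) / 2**64

df = pd.DataFrame({"roi_id": ["site0"] * 5, "label": [1, 2, 3, 4, 5]})
df["hash"] = [
    normalized_hash(f"{roi}_{label}")
    for roi, label in zip(df["roi_id"], df["label"])
]

training_data_perc = 0.8  # assumed default split
train_data = df[df["hash"] < training_data_perc]
test_data = df[df["hash"] >= training_data_perc]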

@@ -149,12 +150,12 @@ def add_features(self, df_raw: pd.DataFrame):

def _validate_predict_features(self, df: pd.DataFrame) -> pd.Series:
"""
Validate the features that are received for prediction using
Validate the features that are received for prediction using
self._predict_schema.
"""
df_no_nans = df.dropna(subset=self._feature_names)
if len(df) != len(df_no_nans):
napari_info(
self.logger.info(
f"Could not do predictions for {len(df)-len(df_no_nans)}/{len(df)} "
"objects because of features that contained `NA`s."
)
@@ -174,7 +175,7 @@ def _validate_input_features(self, df: pd.DataFrame) -> pd.DataFrame:
# Drop rows that have features with `NA`s, notify the user.
df_no_nans = df_annotated.dropna(subset=self._feature_names)
if len(df_no_nans) != len(df_annotated):
napari_info(
self.logger.info(
f"Dropped {len(df_annotated)-len(df_no_nans)}/{len(df_annotated)} "
"objects because of features that contained `NA`s."
)
@@ -193,14 +194,14 @@ def add_dict_of_features(self, dict_of_features):
Parameters
----------
dict_of_features : dict
Dictionary with roi as key and dataframe with feature measurements
Dictionary with roi as key and dataframe with feature measurements
and annotations as value
"""
for roi in dict_of_features:
if "roi_id" not in dict_of_features[roi]:
dict_of_features[roi]["roi_id"] = roi
df = dict_of_features[roi]
napari_info(f"Adding features for {roi=}...")
self.logger.info(f"Adding features for {roi=}...")
self.add_features(df)

def get_class_names(self):
Expand All @@ -210,7 +211,7 @@ def get_feature_names(self):
return self._feature_names

def save(self, output_path):
napari_info(f"Saving classifier at {output_path}...")
self.logger.info(f"Saving classifier at {output_path}...")
with open(output_path, "wb") as f:
f.write(pickle.dumps(self))

24 changes: 24 additions & 0 deletions src/napari_feature_classifier/classifier_widget.py
@@ -1,4 +1,5 @@
"""Classifier container widget for napari"""
import logging
import pickle

from pathlib import Path
@@ -32,6 +33,7 @@
napari_info,
overwrite_check_passed,
add_annotation_names,
NapariHandler,
)


@@ -244,10 +246,15 @@ def __init__(
self._viewer, get_class_selection(class_names=self.class_names)
)

# Handle existing predictions layer
for layer in self._viewer.layers:
if type(layer) == napari.layers.Labels and layer.name == "Predictions":
self._viewer.layers.remove(layer)
self._prediction_layer = self._viewer.add_labels(
self._last_selected_label_layer.data,
scale=self._last_selected_label_layer.scale,
name="Predictions",
translate=self._last_selected_label_layer.translate,
)

# Set the label selection to a valid label layer => Running into proxy bug
@@ -417,6 +424,7 @@ def _init_prediction_layer(self, label_layer: napari.layers.Labels):
# Update the label data in the prediction layer
self._prediction_layer.data = label_layer.data
self._prediction_layer.scale = label_layer.scale
self._prediction_layer.translate = label_layer.translate

# Update the colormap of the prediction layer
reset_display_colormaps(
@@ -632,11 +640,27 @@ def __init__(self, viewer: napari.viewer.Viewer):
self._init_container = None
self._run_container = None
self._init_container = None
self.setup_logging()

super().__init__(widgets=[])

self.initialize_init_widget()

def setup_logging(self):
# Create a custom handler for napari
napari_handler = NapariHandler()
napari_handler.setLevel(logging.INFO)

# Optionally, set a formatter for the handler
# formatter = logging.Formatter(
# '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
# )
# napari_handler.setFormatter(formatter)

# Get the classifier's logger and add the napari handler to it
classifier_logger = logging.getLogger("classifier")
classifier_logger.addHandler(napari_handler)

def initialize_init_widget(self):
self._init_container = ClassifierInitContainer(self._viewer)
self.append(self._init_container)
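setup_logging() above attaches the NapariHandler (added to utils.py in this PR, see below) to the "classifier" logger, so everything the Classifier emits via self.logger.info(...) surfaces as a napari notification without the classifier module importing napari utilities itself. A stripped-down sketch of the flow outside the widget:

import logging

from napari_feature_classifier.utils import NapariHandler

logger = logging.getLogger("classifier")
logger.setLevel(logging.INFO)
logger.addHandler(NapariHandler())

# Each record is routed through NapariHandler.emit(), which formats it and
# hands it to napari_info() / show_info().
logger.info("Training classifier...")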
12 changes: 10 additions & 2 deletions src/napari_feature_classifier/utils.py
@@ -1,5 +1,6 @@
"""Utils function for the classifier"""
from functools import lru_cache
import logging
import math
from pathlib import Path

@@ -118,14 +119,21 @@ def napari_info(message):
"""
try:
show_info(message)
except: # pylint: disable=bare-except
except: # pylint: disable=bare-except # noqa #E722
print(message)
# TODO: Would be better to check if it's running in napari and print in all
# other cases (e.g. if someone runs the classifier from a script).
# But can't make that work at the moment
if in_notebook():
print(message)


class NapariHandler(logging.Handler):
def emit(self, record):
log_entry = self.format(record)
napari_info(log_entry)


def get_valid_label_layers(viewer) -> list[str]:
"""
Get a list of label layers that are not `Annotations` or `Predictions`.
@@ -183,7 +191,7 @@ def add_annotation_names(df, ClassSelection):
Dataframe with annotations column.
ClassSelection : Enum
Enum with the class names.

Returns
-------
pd.DataFrame
Expand Down