Skip to content

Commit 9cccfa1

Browse files
committed
added some minor changes
1 parent 47dce92 commit 9cccfa1

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

src/deepforest/callbacks.py

+9
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import glob
1010
import tempfile
1111
from deepforest.visualize import plot_results
12+
from deepforest.utilities import ColorPalette
1213

1314
from pytorch_lightning import Callback
1415
from deepforest import dataset
@@ -59,6 +60,14 @@ def log_images(self, pl_module):
5960
selected_images = df.image_path.unique()[:self.n]
6061
df = df[df.image_path.isin(selected_images)]
6162

63+
# Ensure color is correctly assigned
64+
if self.color is None:
65+
num_classes = len(df["label"].unique()) # Determine number of classes
66+
results_color = ColorPalette(
67+
num_classes) # Generate appropriate color palette
68+
else:
69+
results_color = self.color
70+
6271
plot_results(results=df,
6372
savedir=self.savedir,
6473
results_color=self.color,

0 commit comments

Comments
 (0)