num_classes should not be a required argument for CropModel #960
Labels: API, good first issue
Title: Refactor CropModel to infer num_classes from checkpoint during loading
Relevant code: DeepForest/src/deepforest/model.py, line 87 (commit 171f55f)
Current Behavior
Currently, CropModel requires num_classes as a mandatory argument during initialization. This creates problems when loading from a checkpoint, because users need to know the exact number of classes the model was trained with, even though this information is stored in the checkpoint.
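To illustrate that the value is recoverable, here is a small sketch of reading it from a Lightning-style checkpoint dictionary. It assumes CropModel calls self.save_hyperparameters() in __init__, so Lightning saves the init arguments under checkpoint["hyper_parameters"]; the helper name num_classes_from_checkpoint is hypothetical and not part of DeepForest.

```python
from typing import Any, Mapping, Optional


def num_classes_from_checkpoint(checkpoint: Mapping[str, Any]) -> Optional[int]:
    """Hypothetical helper: read num_classes from a Lightning-style checkpoint dict.

    Assumes the model called self.save_hyperparameters() in __init__, so its
    init arguments were saved under checkpoint["hyper_parameters"].
    Returns None when the value is not there.
    """
    hparams = checkpoint.get("hyper_parameters") or {}
    value = hparams.get("num_classes")
    return None if value is None else int(value)


# With a real file the dictionary would come from something like
# torch.load(path, map_location="cpu"); a literal dict stands in here.
example = {"hyper_parameters": {"num_classes": 3}, "state_dict": {}}
print(num_classes_from_checkpoint(example))  # 3
```

If the model does not save its hyperparameters, the helper returns None, and the number of classes would have to be inferred some other way (for example from the shape of the final layer's weights in checkpoint["state_dict"]).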
Desired Behavior
num_classes should be optional during initialization and automatically loaded from the checkpoint when available.
General idea toward Solution
Working with Copilot, I came up with the following: modify CropModel to use PyTorch Lightning's on_load_checkpoint hook to handle the num_classes parameter.
Benefits
This still needs to be tested, and tests need to be added.
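As a starting point for both the on_load_checkpoint approach and its tests, here is a minimal sketch of the control flow. CropModelSketch and _build_head are hypothetical names; in DeepForest this logic would live on the LightningModule subclass and the head would be a torch layer, but a plain class stands in here so the logic can be checked without torch installed. Raising on a mismatch between an explicit num_classes and the checkpoint value is a design choice of this sketch, not something the issue specifies.

```python
from typing import Any, Dict, List, Optional


class CropModelSketch:
    """Hypothetical stand-in for CropModel; not DeepForest's actual class."""

    def __init__(self, num_classes: Optional[int] = None) -> None:
        self.num_classes = num_classes
        self.hparams: Dict[str, Any] = {"num_classes": num_classes}
        self.head: Optional[List[float]] = None
        if num_classes is not None:
            self._build_head(num_classes)

    def _build_head(self, num_classes: int) -> None:
        # Placeholder for something like torch.nn.Linear(in_features, num_classes).
        self.num_classes = num_classes
        self.hparams["num_classes"] = num_classes
        self.head = [0.0] * num_classes

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Same name and argument as Lightning's hook.
        stored = (checkpoint.get("hyper_parameters") or {}).get("num_classes")
        if self.num_classes is None:
            if stored is None:
                raise ValueError("num_classes was not given and is not in the checkpoint")
            # Build the head now, before the state dict is copied into it.
            self._build_head(int(stored))
        elif stored is not None and int(stored) != self.num_classes:
            raise ValueError(
                f"num_classes={self.num_classes} does not match checkpoint value {stored}"
            )


# Usage: no num_classes is needed up front.
model = CropModelSketch()
model.on_load_checkpoint({"hyper_parameters": {"num_classes": 4}, "state_dict": {}})
print(model.num_classes)  # 4
```

In Lightning, load_from_checkpoint constructs the module, calls on_load_checkpoint, and only then loads checkpoint["state_dict"], so a head built inside the hook exists before its weights are copied in. If CropModel already calls save_hyperparameters(), load_from_checkpoint should pass the stored num_classes back to __init__ by itself, so the hook mainly matters for an instance created without num_classes, for example when resuming with trainer.fit(model, ckpt_path=...).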