num_classes should not be a required argument for CropModel #960

Open
bw4sz opened this issue Mar 7, 2025 · 1 comment · May be fixed by #966
Labels: API (used for small improvements to the readability and usability of the Python API), good first issue (good for newcomers)

Comments

bw4sz commented Mar 7, 2025

Title: Refactor CropModel to infer num_classes from checkpoint during loading

num_classes (int): Number of classes for classification

Current Behavior

Currently, CropModel requires num_classes as a mandatory argument during initialization:

model = CropModel(num_classes=2)  # num_classes must be known beforehand

This creates problems when loading from a checkpoint because users need to know the exact number of classes the model was trained with, even though this information is stored in the checkpoint.
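
For example, a user reloading a saved model has to guess the training-time value (placeholder path; a wrong guess fails with the size-mismatch error shown in full in the comment below):

# Fails at state_dict loading time if the checkpoint was trained with a
# different number of classes than the one passed here.
model = CropModel.load_from_checkpoint("path/to/checkpoint.ckpt", num_classes=6)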

Desired Behavior

num_classes should be optional during initialization and automatically loaded from the checkpoint when available:

# Loading from checkpoint without needing to specify num_classes
model = CropModel.load_from_checkpoint("path/to/checkpoint.ckpt")

General idea toward a solution

Working with Copilot, I came up with the following:

Modify CropModel to use PyTorch Lightning's on_load_checkpoint hook to handle the num_classes parameter. Here's the proposed implementation:

import torchmetrics
from pytorch_lightning import LightningModule


class CropModel(LightningModule):
    def __init__(self, num_classes=None, batch_size=4, num_workers=0, lr=0.0001, model=None, label_dict=None):
        super().__init__()
        self.num_classes = num_classes
        self.model = model
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr
        self.label_dict = label_dict

        # Record constructor arguments so that num_classes is available under
        # checkpoint['hyper_parameters'] when on_load_checkpoint runs.
        self.save_hyperparameters(ignore=["model"])

        # Only initialize the model if num_classes is provided; otherwise
        # defer until it can be read from the checkpoint.
        if num_classes is not None:
            self._init_model(num_classes)

    def _init_model(self, num_classes):
        """Initialize the model and metrics with the given num_classes."""
        if self.model is None:
            self.model = simple_resnet_50(num_classes=num_classes)

        self.accuracy = torchmetrics.Accuracy(average="none",
                                              num_classes=num_classes,
                                              task="multiclass")
        self.total_accuracy = torchmetrics.Accuracy(num_classes=num_classes,
                                                    task="multiclass")
        self.precision_metric = torchmetrics.Precision(num_classes=num_classes,
                                                       task="multiclass")
        self.metrics = torchmetrics.MetricCollection({
            "Class Accuracy": self.accuracy,
            "Accuracy": self.total_accuracy,
            "Precision": self.precision_metric
        })

    def on_load_checkpoint(self, checkpoint):
        # Lightning calls this hook before load_state_dict, so the deferred
        # model can be built with the correct output size.
        if self.num_classes is None:
            self.num_classes = checkpoint["hyper_parameters"]["num_classes"]
            self._init_model(self.num_classes)
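
With this in place, loading would look like the earlier example (the checkpoint path is a placeholder):

# num_classes is read from the checkpoint's hyper_parameters, so it no
# longer needs to be passed explicitly.
model = CropModel.load_from_checkpoint("path/to/checkpoint.ckpt")
print(model.num_classes)  # the value stored at training time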

Benefits

  1. More user-friendly API when loading from checkpoints
  2. Follows PyTorch Lightning's best practices for checkpoint handling
  3. Maintains backward compatibility (can still specify num_classes if needed)
  4. Reduces potential errors from mismatched class numbers
  5. Self-documenting code that makes the model's configuration clear from the checkpoint

This needs to be tested and tests added.
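
A minimal pytest sketch for that round trip could look like the following; it builds a Lightning-format checkpoint dict by hand rather than through a Trainer, and the test name is illustrative:

import torch
import pytorch_lightning as pl


def test_load_from_checkpoint_infers_num_classes(tmp_path):
    # Save a checkpoint from a model with a known class count. A real test
    # would save one through a Trainer; this hand-built dict mimics the
    # Lightning checkpoint layout.
    model = CropModel(num_classes=2)
    ckpt_path = str(tmp_path / "crop_model.ckpt")
    torch.save({
        "state_dict": model.state_dict(),
        "hyper_parameters": dict(model.hparams),
        "pytorch-lightning_version": pl.__version__,
    }, ckpt_path)

    # Reload without passing num_classes; it should come from the checkpoint.
    loaded = CropModel.load_from_checkpoint(ckpt_path)
    assert loaded.num_classes == 2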

bw4sz added the API and good first issue labels Mar 7, 2025
bw4sz changed the title from "num_classes should not be a required argument." to "num_classes should not be a required argument for CropModel" Mar 7, 2025
bw4sz closed this as completed Mar 7, 2025

bw4sz commented Mar 7, 2025

Closed by accident, this is still valid.

CropModel.load_from_checkpoint(checkpoint, num_classes=6)
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/blue/ewhite/b.weinstein/miniconda3/envs/BOEM/lib/python3.10/site-packages/pytorch_lightning/utilities/model_helpers.py", line 125, in wrapper
    return self.method(cls, *args, **kwargs)
  File "/blue/ewhite/b.weinstein/miniconda3/envs/BOEM/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1581, in load_from_checkpoint
    loaded = _load_from_checkpoint(
  File "/blue/ewhite/b.weinstein/miniconda3/envs/BOEM/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 91, in _load_from_checkpoint
    model = _load_state(cls, checkpoint, strict=strict, **kwargs)
  File "/blue/ewhite/b.weinstein/miniconda3/envs/BOEM/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 187, in _load_state
    keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)
  File "/blue/ewhite/b.weinstein/miniconda3/envs/BOEM/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2584, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for CropModel:
	size mismatch for model.fc.weight: copying a param with shape torch.Size([7, 2048]) from checkpoint, the shape in current model is torch.Size([6, 2048]).
	size mismatch for model.fc.bias: copying a param with shape torch.Size([7]) from checkpoint, the shape in current model is torch.Size([6]).
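
Until the refactor lands, one workaround is to read the trained class count out of the checkpoint before loading; the model.fc.weight key comes from the size-mismatch messages above (this checkpoint was evidently trained with 7 classes):

import torch

# Recover num_classes from the final layer's weight shape, then pass the
# correct value explicitly.
ckpt = torch.load("path/to/checkpoint.ckpt", map_location="cpu")
num_classes = ckpt["state_dict"]["model.fc.weight"].shape[0]  # 7 here
model = CropModel.load_from_checkpoint("path/to/checkpoint.ckpt", num_classes=num_classes)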
