
Generic classifier builder #2356

Merged: 22 commits into pytorch:main from classifier_builder, Mar 27, 2025

Conversation

@SalmanMohammadi (Collaborator) commented Feb 7, 2025

This PR adds support for a generic classifier builder. Currently, we require explicit classifier builders to be defined for every model we wish to use as a classifier. This is costly to maintain.

We can instead define a generic classifier builder which simply adapts the final layer of any decoder-based model (how strict should this contract be?) to the desired number of classes.

This PR does not include support for PEFT-based classifier model builders. This is because there are no use-cases in the library for such model builders at this time.
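(For illustration: the core of the approach is roughly the sketch below. The signature, argument names, and import path are assumptions made for readability here, not the exact API merged in this PR.)

    # Rough sketch of the generic builder idea; names and import path are assumed.
    from torch import nn
    from torchtune.config._utils import _get_component_from_path  # import location assumed

    def classifier_model(num_classes: int, base_model_path: str, **base_model_kwargs) -> nn.Module:
        # Build the underlying decoder from its builder path,
        # e.g. "torchtune.models.mistral.mistral_7b".
        model = _get_component_from_path(base_model_path)(**base_model_kwargs)
        # Swap the language-modelling head for a num_classes-wide projection.
        model.output = nn.Linear(model.head_dim * model.num_heads, num_classes, bias=False)
        return model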


pytorch-bot bot commented Feb 7, 2025

🔗 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2356

As of commit c96a871 with merge base 32d195c: ⏳ no failures, 6 pending. 💚 Looks good so far!

@facebook-github-bot added the CLA Signed label Feb 7, 2025
    Returns:
        nn.Module: The classifier model.
    """
    model = _get_component_from_path(base_model)(**base_model_kwargs)
@RdoubleA (Collaborator) commented:

If this is bound by the models in our library, then we would need to test that we can successfully create a classifier model for every model in our library, and any future models that are added. Can users pass in a custom model path? If that's not the intention, then we'll need to add some more gating. For example, if we check whether it's a TransformerDecoder instance, that would easily keep this constrained to models that we know with confidence can be made into classifiers.

@SalmanMohammadi (Collaborator, Author) replied:

Great points, @RdoubleA.

Can users pass in a custom model path?

I think we only ever claim to support models in our library, right? So this would follow the same contract.

we would need to test that we can successfully create a classifier model for every model in our library, and any future models that are added

I'm curious why it wouldn't be sufficient to write an equally generic test against TransformerDecoder. The idea behind this builder is that we shouldn't have to define (or test) classifier builders for every model in our library; rather, testing it against the generic cases we think could come up might be sufficient: model.output: TiedLinear, model.output: Linear, and maybe (model.output: TransformerDecoder.output). What do you think?

For example, if we check for if it's a TransformerDecoder instance that would easily keep this constrained to models that we know with confidence can be made into classifiers.

I think we agree here, if we can rely on any TransformerDecoder satisfying the requirements for this builder i.e. that it uses some linear output projection (which is a pretty flexible condition), then we should be able to use it as a classifier model. I think this would include vision models which use a TransformerDecoder as a decoder.

@RdoubleA (Collaborator) replied:

Yep, agreed; if we test against TransformerDecoder, that should satisfy all the conditions I mentioned.
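(For illustration, a generic test over those output-layer variants might look roughly like the sketch below. TinyDecoder and replace_output_head are hypothetical stand-ins to show the parametrisation idea, not torchtune's actual builder or test code.)

    # Hypothetical sketch: parametrise one test over the output-layer variants.
    # TinyDecoder and replace_output_head are stand-ins, not torchtune code.
    import pytest
    import torch.nn as nn

    def replace_output_head(model: nn.Module, num_classes: int) -> nn.Module:
        # Mirrors the builder's core behaviour: drop whatever head exists,
        # then attach a num_classes-wide linear projection.
        if hasattr(model, "output"):
            del model.output
        model.output = nn.Linear(model.head_dim * model.num_heads, num_classes, bias=False)
        return model

    class TinyDecoder(nn.Module):
        # Minimal stand-in exposing the attributes the builder relies on.
        def __init__(self, output: nn.Module):
            super().__init__()
            self.head_dim, self.num_heads = 8, 4
            self.output = output

    @pytest.mark.parametrize(
        "output_layer",
        [nn.Linear(32, 100, bias=False), nn.Identity()],  # Identity stands in for a tied head
    )
    def test_output_head_is_replaced(output_layer):
        model = replace_output_head(TinyDecoder(output_layer), num_classes=5)
        assert isinstance(model.output, nn.Linear)
        assert model.output.out_features == 5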

@codecov-commenter commented:

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 63.21%. Comparing base (67a8706) to head (ecf2e72).
Report is 1 commit behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2356      +/-   ##
==========================================
- Coverage   65.47%   63.21%   -2.27%     
==========================================
  Files         374      376       +2     
  Lines       22280    22316      +36     
==========================================
- Hits        14587    14106     -481     
- Misses       7693     8210     +517     

☔ View full report in Codecov by Sentry.

@SalmanMohammadi SalmanMohammadi marked this pull request as ready for review March 15, 2025 19:07
Comment on lines 38 to 43
if isinstance(model.output, nn.Linear):
    del model.output.weight
    if hasattr(model.output, "bias"):
        del model.output.bias
model.output = nn.Linear(model.head_dim * model.num_heads, num_classes, bias=False)
return model
@felipemello1 (Contributor) commented Mar 17, 2025:

This won't work for the TiedEmbedding layer.

Should we use

if hasattr(model, 'output')

Instead of:

if isinstance(model.output, nn.Linear)

Also, for multimodal models, it would have to be model.decoder.output instead of model.output.

@SalmanMohammadi (Collaborator, Author) replied Mar 17, 2025:

Should we use
if hasattr(model, 'output')

Hmm, so the logic would be something like:

    if hasattr(model, 'output'):
        del model.output.weight
        if hasattr(model.output, "bias"):
            del model.output.bias
            

In the TiedEmbedding case, we will have model.output: TiedLinear, but model.output will have no weight/bias parameters to delete, so we won't need to enter this block, right? Let me know if I'm missing something or misunderstanding here : )

also, for multimodal, it would have to be model.decoder.output, instead of model.output

I agree, but similar to PEFT, we don't support any use-cases for MM classifier models at the moment, so I thought I'd keep the design super simple and directed. What do you think?

@felipemello1 (Contributor) replied:

Sorry, it's still not clear to me why we don't just replace the output head if there is one, instead of deleting the weights first.

Just do:

if hasattr(model, 'output'):
     model.output = new_head()

I agree, but similar to PEFT, we don't support any use-cases for MM classifier models at the moment, so I thought I'd keep the design super simple and directed. What do you think?

I don't think it hurts to add an extra "else" to cover this case, in case someone tries to experiment with it. If you are strongly against it, maybe we could raise a warning?

if hasattr(model, 'output'):
    do_something()
else:
    raise ValueError("Could not find model.output")

Comment on lines 49 to 50
_component_: torchtune.models.common.classifier_model
base_model_path: torchtune.models.mistral.mistral_7b
@felipemello1 (Contributor) commented Mar 17, 2025:

This works, but if we have nested instantiation, maybe a better design would be:

reward_and_value_model:
  _component_: torchtune.models.common.classifier_model
  model:
    _component_: torchtune.models.mistral.mistral_7b
    # ...model args

@SalmanMohammadi (Collaborator, Author) replied:

Do you mean changing base_model_path to model here? This sounds good to me : )

@felipemello1 (Contributor) replied:

Edited so that it renders a bit better.
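(For illustration, a config stanza under this nested design might end up looking roughly like the following. The num_classes key, the concrete values, and the final component path are assumptions here; see the merged builder for the exact argument names.)

    reward_and_value_model:
      _component_: torchtune.modules.classifier.classifier_model
      num_classes: 1
      model:
        _component_: torchtune.models.mistral.mistral_7b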

@pbontrager (Contributor) commented:

Instead of adding a common.py function here, could this be a classifier.py file in torchtune/modules?

@@ -46,17 +46,18 @@ policy_model:
 # we need to manually build the mistral classifier model
 # because our reward model checkpoint has a larger vocabulary size (due to an added padding token)
 reward_and_value_model:
-  _component_: torchtune.models.mistral._component_builders.mistral_classifier
+  _component_: torchtune.modules.classifier.classifier_model
Contributor:

NEAT

Comment on lines +43 to +55
if hasattr(model, "output"):
del model.output
model.output = nn.Linear(
model.head_dim * model.num_heads, num_classes, bias=False
)
elif hasattr(model, "decoder") and hasattr(model.decoder, "output"):
del model.decoder.output
model.decoder.output = nn.Linear(
model.decoder.head_dim * model.decoder.num_heads, num_classes, bias=False
)
else:
raise ValueError("Could not find a valid output layer to adapt.")
return model
@felipemello1 (Contributor) commented Mar 26, 2025:

Looks good! Not for this PR, but this won't work with PEFT. I don't think we want to go there, though. Maybe we could add a warning: "found LoRA module, this is not supported, replacing it with a new head", but this might be overkill. Any thoughts?

@SalmanMohammadi (Collaborator, Author) replied:

I can update the docstring but I think adding checks to cover PEFT cases might be overkill : )

@felipemello1 (Contributor) left a review:

Will approve after CI is green. Thanks for this PR!!

@joecummings (Contributor) left a review:

This looks good! Can we file a TODO to add a small tutorial or something?

) -> Union[TransformerDecoder, nn.Module]:
    """
    Create a classifier model from a base model by adapting the output layer.
    This builder does not support models which apply PEFT to the output layer.
Contributor:

nit: make this a Note:

@SalmanMohammadi SalmanMohammadi merged commit 0afea1a into pytorch:main Mar 27, 2025
17 checks passed
@SalmanMohammadi SalmanMohammadi deleted the classifier_builder branch March 27, 2025 16:07