
Fix torch_dtype in Kolors text encoder with transformers v4.49 #10816

Open · wants to merge 3 commits into base: main

Conversation

@hlky (Collaborator) commented Feb 18, 2025

What does this PR do?

Tests for Kolors are failing. I tracked the issue to a transformers version update: the test model's config contains torch_dtype as a string, which in turn is passed as a str to torch.empty. It seems torch_dtype was previously converted to a torch.dtype or ignored.

TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=str, device=str), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
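
The failure reproduces outside Kolors with a minimal call; a sketch, not code from the PR:

```python
import torch

# Minimal reproduction of the TypeError above: torch.empty rejects a
# string dtype, so a config-provided "float16" must be converted first.
try:
    torch.empty((2, 2), dtype="float16", device="cpu")
except TypeError as e:
    print(e)
```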

Generally, torch_dtype would be passed to the Kolors pipelines' from_pretrained, or to ChatGLMModel when creating it separately, so this should be fine for end users.
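
For illustration, the usual end-user path looks like this (a sketch; the checkpoint id is an assumption, not taken from the PR):

```python
import torch
from diffusers import KolorsPipeline

# Passing an explicit torch.dtype means the text encoder never receives
# the string form from the config. The checkpoint id is an assumption.
pipe = KolorsPipeline.from_pretrained(
    "Kwai-Kolors/Kolors-diffusers", torch_dtype=torch.float16
)
```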

Edit: some tests are still failing.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@hlky hlky requested review from yiyixuxu and sayakpaul February 18, 2025 10:35
@hlky hlky marked this pull request as draft February 18, 2025 10:50
@sayakpaul (Member) commented:

> The test model's config contains torch_dtype as a string, which in turn is passed as a str to torch.empty. It seems torch_dtype was previously converted to a torch.dtype or ignored.

I think transformers has a way to map that to the correct dtype:
https://github.com/huggingface/transformers/blob/e6cc410d5b830e280cdc5097cb6ce6ea6a943e5e/src/transformers/modeling_utils.py#L4081
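
That mapping is essentially a getattr lookup on the torch module; a sketch of the behavior, assuming the linked code is unchanged:

```python
import torch

# A string torch_dtype (e.g. "float16" from config.json) resolves to the
# matching torch attribute; this mirrors the linked transformers code.
torch_dtype = "float16"
if isinstance(torch_dtype, str):
    torch_dtype = getattr(torch, torch_dtype)
print(torch_dtype)  # torch.float16
```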

@hlky (Collaborator, Author) commented Feb 18, 2025

Yeah, I haven't found the exact change, but it was working recently on 4.48.3.

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member) commented:

Let's maybe also update the transformers folks about this.

@hlky hlky marked this pull request as ready for review February 18, 2025 11:46
@hlky (Collaborator, Author) commented Feb 18, 2025

The remaining failing test is just a temporary Hub issue.

To re-summarize, the changes are:

  • Passed torch_dtype to ChatGLMModel in tests (ChatGLMModel's from_pretrained comes from transformers).
  • Set torch.float32 as the default torch_dtype and added a warning when the passed torch_dtype is not a torch.dtype. This is for pipeline-level tests: some of them don't pass torch_dtype, so ChatGLMModel doesn't receive one, which triggers the issue on the latest transformers (see the sketch after this list).
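
A minimal sketch of that second change, with hypothetical names; the actual diff may differ:

```python
import logging
import torch

logger = logging.getLogger(__name__)

# Hypothetical helper illustrating the guard described above: default to
# torch.float32, and warn when the provided value is not a torch.dtype.
def resolve_torch_dtype(torch_dtype=None):
    if torch_dtype is None:
        return torch.float32  # default when nothing is passed
    if not isinstance(torch_dtype, torch.dtype):
        logger.warning(
            f"Expected `torch_dtype` to be a `torch.dtype`, got {torch_dtype!r}; "
            "defaulting to `torch.float32`."
        )
        return torch.float32
    return torch_dtype
```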

@hlky (Collaborator, Author) commented Feb 18, 2025

Marked as draft again; waiting on huggingface/transformers#36262.

@DN6 (Collaborator) commented Feb 21, 2025

@hlky I think this can be reopened and merged. The change is safe to make.

@hlky hlky marked this pull request as ready for review February 21, 2025 07:41
Labels: none yet
Projects: none yet
4 participants