-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Comprehensive type checking for from_pretrained
kwargs
#10758
base: main
Are you sure you want to change the base?
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @guiyrt, nice work. Could you take a look through pipeline test output searching for Expected types for
(GitHub's built in search works best)
There are some easy cases that we could fix like
Expected types for feature_extractor: (<class 'transformers.models.clip.image_processing_clip.CLIPImageProcessor'>,), got <class 'transformers.models.clip.feature_extraction_clip.CLIPFeatureExtractor'>.
For that we could do global find+replace on feature_extractor: CLIPImageProcessor
-> feature_extractor: CLIPFeatureExtractor
.
and some that need investigating
Expected types for unet: (<class 'inspect._empty'>,), got <class 'diffusers.models.unets.unet_2d.UNet2DModel'>.
Type correctness is not strictly enforced so some warnings are expected but we should make a best effort to reduce the number of new warnings that we're introducing. If we find a particular component to be a problem we can skip it like scheduler
.
Failing tests appear unrelated, will re-run later. |
@hlky Findings from looking through the test logs TL;DR 1. Using XYZFast
|
def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self): | |
repo_id = "hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds" |
3. PipelineFastTests::test_optional_components
This test purposefully sets requires_safety_checker
as [True, True]
and safety_checker
as an unet and feature_extractor
as a function, so the warnings here are expected.
2+2+2 occurrences
Expected types for safety_checker: (<class 'diffusers.pipelines.stable_diffusion.safety_checker.StableDiffusionSafetyChecker'>,),
got <class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>.
Expected types for requires_safety_checker: (<class 'bool'>,), got typing.List[bool].
Expected types for feature_extractor: (<class 'transformers.models.clip.image_processing_clip.CLIPImageProcessor'>,),
got <class 'function'>.
diffusers/tests/pipelines/test_pipelines.py
Lines 1695 to 1701 in 8ae8008
# Test that partially loading works | |
sd = StableDiffusionPipeline.from_pretrained( | |
tmpdirname, | |
feature_extractor=self.dummy_extractor, | |
safety_checker=unet, | |
requires_safety_checker=[True, True], | |
) |
4. Missing type hinting
Added the intended types.
6+6 occurrences from HunyuanDiTPipeline
s.
Expected types for text_encoder_2: (<class 'inspect._empty'>,),
got <class 'transformers.models.t5.modeling_t5.T5EncoderModel'>.
Expected types for tokenizer_2: (<class 'inspect._empty'>,),
got <class 'transformers.models.t5.tokenization_t5_fast.T5TokenizerFast'>.
diffusers/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
Lines 210 to 211 in 8ae8008
text_encoder_2=T5EncoderModel, | |
tokenizer_2=MT5Tokenizer, |
diffusers/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
Lines 235 to 236 in 57ac673
text_encoder_2=T5EncoderModel, | |
tokenizer_2=MT5Tokenizer, |
7+7 occurrences from CustomPipeline
tests. Only showed for unet
because scheduler
is not checked.
Expected types for unet: (<class 'inspect._empty'>,),
got <class 'diffusers.models.unets.unet_2d.UNet2DModel'>.
Expected types for unet: (<class 'inspect._empty'>,),
got <class 'diffusers.models.unets.unet_1d.UNet1DModel'>.
def __init__(self, unet, scheduler): |
def __init__(self, unet, scheduler): |
def __init__(self, unet, scheduler): |
def __init__(self, unet, scheduler): |
def __init__(self, unet, scheduler): |
5. CustomPipelineTests
Not sure what to make of this.
2+2+2 occurrences
Expected types for unet: (<class 'diffusers_modules.local.unet.my_unet_model.MyUNetModel'>,),
got <class 'diffusers_modules.local.my_unet_model.MyUNetModel'>.
Expected types for unet: (<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,),
got <class 'diffusers_modules.local.my_unet_model.MyUNetModel'>.
Expected types for unet: (<class 'diffusers_modules.local.unet.my_unet_model.MyUNetModel'>,),
got <class 'diffusers_modules.local.my_unet_model.MyUNetModel'>.
6. Incomplete unet
type hints in AnimateDiffVideoToVideoPipeline
s
Changed to be unet: Union[UNet2DConditionModel, UNetMotionModel]
, as in AnimateDiffSDXLPipeline
.
8 occurrences
Expected types for unet: (<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,),
got <class 'diffusers.models.unets.unet_motion_model.UNetMotionModel'>.
diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
Line 249 in 8ae8008
unet: UNet2DConditionModel, |
diffusers/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
Line 227 in 8ae8008
unet: UNet2DConditionModel, |
7. Subclasses not checked in is_valid_type
This should be correct, so changed is_valid_type
to use isinstance
to also allow subsclasses. This way, you can annotate with a parent class.
4 occurrences
Expected types for bert: (<class 'transformers.modeling_utils.PreTrainedModel'>,),
got <class 'transformers.models.clip.modeling_clip.CLIPTextModel'>.
8. Type hinting with AutoTokenizer
and AutoModel
on SanaPipelines
s
The proper base classes are PreTrainedModel
and PreTrainedTokenizerBase
.
7 occurrences
Expected types for tokenizer: (<class 'transformers.models.auto.tokenization_auto.AutoTokenizer'>,),
got <class 'transformers.models.gemma.tokenization_gemma.GemmaTokenizer'>.
3 occurences
Expected types for tokenizer: (<class 'transformers.models.auto.tokenization_auto.AutoTokenizer'>,),
got <class 'transformers.models.gemma.tokenization_gemma_fast.GemmaTokenizerFast'>.
3 occurrences
Expected types for text_encoder: (<class 'transformers.models.auto.modeling_auto.AutoModel'>,),
got <class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'>.
3 occurrences
Expected types for text_encoder: (<class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'>,),
got <class 'transformers.models.gemma2.modeling_gemma2.Gemma2ForCausalLM'>.
4 occurrences
Expected types for text_encoder: (<class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'>,),
got <class 'transformers.models.gemma2.modeling_gemma2.Gemma2Model'>.
diffusers/src/diffusers/pipelines/sana/pipeline_sana.py
Lines 203 to 204 in 57ac673
tokenizer: AutoTokenizer, | |
text_encoder: AutoModelForCausalLM, |
diffusers/src/diffusers/pipelines/pag/pipeline_pag_sana.py
Lines 163 to 164 in 57ac673
tokenizer: AutoTokenizer, | |
text_encoder: AutoModelForCausalLM, |
9. Interchanged use of CLIPTextModel
and CLIPTextModelWithProjection
Just swapped with the intended type.
4 occurrences
Expected types for text_encoder: (<class 'transformers.models.clip.modeling_clip.CLIPTextModelWithProjection'>,),
got <class 'transformers.models.clip.modeling_clip.CLIPTextModel'>.
text_encoder: CLIPTextModelWithProjection, |
8 occurrences
Expected types for text_encoder: (<class 'transformers.models.clip.modeling_clip.CLIPTextModel'>,),
got <class 'transformers.models.clip.modeling_clip.CLIPTextModelWithProjection'>.
text_encoder: CLIPTextModel, |
diffusers/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
Line 79 in 8ae8008
text_encoder: CLIPTextModel, |
4 occurrences
Expected types for prior_text_encoder: (<class 'transformers.models.clip.modeling_clip.CLIPTextModel'>,),
got <class 'transformers.models.clip.modeling_clip.CLIPTextModelWithProjection'>.
diffusers/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
Line 84 in 8ae8008
prior_text_encoder: CLIPTextModel, |
Found another warning related to custom pipelines, this time on "hf-internal-testing/diffusers-dummy-pipeline". The fix is having the correct type hinting there. 4 occurrences
diffusers/tests/pipelines/test_pipelines.py Lines 1043 to 1053 in 067eab1
|
I opened PRs on the hf-internal-testing repos with warnings: The last one is regarding arguments with no annotations, which might be common on custom pipelines? To keep warnings relevant, it might be a good idea to skip type checking if there is no type annotated for a given argument. |
|
This to-do includes:
|
@hlky can we move the functions |
@yiyixuxu WDYT about |
I updated the docs. But if I got it right, Lumina v1 uses |
Tests failed due to network issues I think. I noticed yesterday very slow download speeds from the hub, anything you are aware? |
Gentle ping @yiyixuxu |
thanks for the PR @guiyrt @hlky |
What does this PR do?
Changes
List[ControlNetModel]
instead oflist
)To-do
is_valid_type
andget_detailed_type
be placed?According to new functions location, add simple tests for type checking.These changes are proposed based on testing for #10747.
Example warning when providing
controlnet
asList[ControlNetUnionModel]
forStableDiffusionXLControlNetPipeline
, whereList[ControlNetModel]
is expected:Code for warning replication
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
@hlky