
Comprehensive type checking for from_pretrained kwargs #10758

Open
guiyrt wants to merge 13 commits into main
Conversation

@guiyrt (Contributor) commented Feb 10, 2025

What does this PR do?

Changes

  • Moved type checking to just before pipeline instantiation, so all kwargs are checked
  • Full type checking for collections (lists, dicts, ...): every element is checked
  • More detailed warnings for unexpected argument types (e.g. List[ControlNetModel] instead of plain list)

To-do

  • Where should the new functions is_valid_type and get_detailed_type be placed?
  • Depending on where these functions end up, add simple tests for type checking.

These changes are proposed based on testing for #10747.
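For context, here is a rough sketch of what the two new helpers do. This is a simplified approximation for illustration, not the exact code in this PR:

from typing import Any, Dict, List, Tuple, Union, get_args, get_origin


def get_detailed_type(obj: Any) -> Any:
    """Return a detailed type, e.g. List[ControlNetModel] instead of plain list."""
    obj_type = type(obj)
    if isinstance(obj, (list, tuple)):
        # Parametrize the container with the union of its element types.
        element_types = tuple({get_detailed_type(element) for element in obj})
        collection = List if isinstance(obj, list) else Tuple
        return collection[Union[element_types]] if element_types else obj_type
    if isinstance(obj, dict):
        key_types = tuple({get_detailed_type(k) for k in obj})
        value_types = tuple({get_detailed_type(v) for v in obj.values()})
        if key_types and value_types:
            return Dict[Union[key_types], Union[value_types]]
    return obj_type


def is_valid_type(obj: Any, expected: Any) -> bool:
    """Check obj against an annotation, recursing into typing generics."""
    origin, args = get_origin(expected), get_args(expected)
    if origin is None:
        # Plain class: isinstance, so subclasses are accepted.
        return isinstance(obj, expected)
    if origin is Union:
        return any(is_valid_type(obj, arg) for arg in args)
    if not isinstance(obj, origin):
        return False
    if origin in (list, tuple) and args:
        return all(is_valid_type(element, Union[args]) for element in obj)
    if origin is dict and args:
        key_type, value_type = args
        return all(
            is_valid_type(k, key_type) and is_valid_type(v, value_type)
            for k, v in obj.items()
        )
    return True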

Example warning when providing controlnet as List[ControlNetUnionModel] for StableDiffusionXLControlNetPipeline, where List[ControlNetModel] is expected:

Expected types for controlnet: (<class 'diffusers.models.controlnets.controlnet.ControlNetModel'>,
typing.List[diffusers.models.controlnets.controlnet.ControlNetModel], 
typing.Tuple[diffusers.models.controlnets.controlnet.ControlNetModel],
<class 'diffusers.models.controlnets.multicontrolnet.MultiControlNetModel'>),
got typing.List[diffusers.models.controlnets.controlnet_union.ControlNetUnionModel].
Code for warning replication
import torch

from diffusers import StableDiffusionXLControlNetPipeline
from diffusers.models import ControlNetUnionModel, AutoencoderKL
from diffusers.utils import load_image


controlnet = ControlNetUnionModel.from_pretrained(
    "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
)

vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)

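# Passing a list of ControlNetUnionModel here is what triggers the type warning above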
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=[controlnet, controlnet],
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
)

room_seg_img = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/room_seg.png")
pose_img = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png")


pipe.enable_model_cpu_offload()

image = pipe(
    prompt="an astronaut in a space station",
    width=1024,
    height=1024,
    negative_prompt="lowres, low quality, worst quality",
    generator=torch.manual_seed(42),
    guidance_scale=5,
    num_inference_steps=50,
    image=[pose_img, room_seg_img],
).images[0]

image.save("result.jpg")


Who can review?

@hlky

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@hlky (Collaborator) left a comment

Thanks @guiyrt, nice work. Could you take a look through the pipeline test output, searching for Expected types for? (GitHub's built-in search works best.)

There are some easy cases that we could fix like

Expected types for feature_extractor: (<class 'transformers.models.clip.image_processing_clip.CLIPImageProcessor'>,), got <class 'transformers.models.clip.feature_extraction_clip.CLIPFeatureExtractor'>.

For that, we could do a global find+replace: feature_extractor: CLIPImageProcessor -> feature_extractor: CLIPFeatureExtractor.

and some that need investigating

Expected types for unet: (<class 'inspect._empty'>,), got <class 'diffusers.models.unets.unet_2d.UNet2DModel'>.

Type correctness is not strictly enforced, so some warnings are expected, but we should make a best effort to reduce the number of new warnings we're introducing. If we find that a particular component is a problem, we can skip it, like scheduler.
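A minimal sketch of how such a skip could look (the constant name, function name, and arguments here are illustrative assumptions, not the PR's actual code; it reuses the helpers sketched in the PR description):

import logging

logger = logging.getLogger(__name__)

# Components whose types are never checked.
SKIP_TYPE_CHECK = {"scheduler"}


def warn_on_unexpected_types(expected_types: dict, passed_kwargs: dict) -> None:
    for name, value in passed_kwargs.items():
        if name in SKIP_TYPE_CHECK or name not in expected_types:
            continue
        if not is_valid_type(value, expected_types[name]):
            logger.warning(
                f"Expected types for {name}: {expected_types[name]}, got {get_detailed_type(value)}."
            )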

@hlky (Collaborator) commented Feb 10, 2025

Failing tests appear unrelated, will re-run later.

@guiyrt (Contributor, Author) commented Feb 12, 2025

@hlky Here are my findings from looking through the test logs.

TL;DR: tokenizer is the one with the most warnings, for example when T5Tokenizer is annotated but T5TokenizerFast is used. Most of the remaining warnings are smaller things, and most are corrected/addressed in 5ca27aa. Doing a find+replace to annotate Union[BaseTokenizer, FastTokenizer] deals with the tokenizer problem, but it will change many files; is this ok?

1. Using XYZFast tokenizer when only XYZ is annotated (and vice-versa)

We can do a quick search-and-replace to update all tokenizer annotations to Union[XYZBase, XYZFast] (see the sketch after the samples below), but as this is a big change, let me know if you agree.

18 occurrences

Expected types for tokenizer: (<class 'transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer'>,), 
got <class 'transformers.models.xlm_roberta.tokenization_xlm_roberta_fast.XLMRobertaTokenizerFast'>.

68 occurrences

Expected types for tokenizer: (<class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>,),
got <class 'transformers.models.t5.tokenization_t5_fast.T5TokenizerFast'>.

9 occurrences

Expected types for tokenizer: (<class 'transformers.models.bert.tokenization_bert.BertTokenizer'>,),
got <class 'transformers.models.bert.tokenization_bert_fast.BertTokenizerFast'>.

4 occurrences

Expected types for tokenizer: (<class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>,), got <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>.
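Illustrated with T5, the annotation change would look roughly like this (a sketch, not the actual diff):

from typing import Union

from transformers import T5Tokenizer, T5TokenizerFast


class ExamplePipeline:  # illustrative pipeline, for the annotation only
    # Before: `tokenizer: T5Tokenizer` warned whenever the fast variant was loaded.
    def __init__(self, tokenizer: Union[T5Tokenizer, T5TokenizerFast]):
        self.tokenizer = tokenizer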

2. CLIPFeatureExtractor as feature_extractor

This comes from tests that use an hf-internal-testing repo with the legacy CLIPFeatureExtractor instead of CLIPImageProcessor. transformers also throws a warning: FutureWarning: The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use CLIPImageProcessor instead.

15 occurrences

Expected types for feature_extractor: (<class 'transformers.models.clip.image_processing_clip.CLIPImageProcessor'>,),
got <class 'transformers.models.clip.feature_extraction_clip.CLIPFeatureExtractor'>.

def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self):
    repo_id = "hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds"

3. PipelineFastTests::test_optional_components

This test purposefully sets requires_safety_checker to [True, True], safety_checker to a UNet, and feature_extractor to a function, so the warnings here are expected.

2+2+2 occurrences

Expected types for safety_checker: (<class 'diffusers.pipelines.stable_diffusion.safety_checker.StableDiffusionSafetyChecker'>,),
got <class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>.
Expected types for requires_safety_checker: (<class 'bool'>,), got typing.List[bool].
Expected types for feature_extractor: (<class 'transformers.models.clip.image_processing_clip.CLIPImageProcessor'>,),
got <class 'function'>.

# Test that partially loading works
sd = StableDiffusionPipeline.from_pretrained(
    tmpdirname,
    feature_extractor=self.dummy_extractor,
    safety_checker=unet,
    requires_safety_checker=[True, True],
)

4. Missing type hinting

Added the intended types.

6+6 occurrences from the HunyuanDiT pipelines.

Expected types for text_encoder_2: (<class 'inspect._empty'>,),
got <class 'transformers.models.t5.modeling_t5.T5EncoderModel'>.
Expected types for tokenizer_2: (<class 'inspect._empty'>,),
got <class 'transformers.models.t5.tokenization_t5_fast.T5TokenizerFast'>.

text_encoder_2=T5EncoderModel,
tokenizer_2=MT5Tokenizer,

text_encoder_2=T5EncoderModel,
tokenizer_2=MT5Tokenizer,

7+7 occurrences from CustomPipeline tests. Only shown for unet, because scheduler is not checked.

Expected types for unet: (<class 'inspect._empty'>,),
got <class 'diffusers.models.unets.unet_2d.UNet2DModel'>.
Expected types for unet: (<class 'inspect._empty'>,),
got <class 'diffusers.models.unets.unet_1d.UNet1DModel'>.

def __init__(self, unet, scheduler):

def __init__(self, unet, scheduler):

def __init__(self, unet, scheduler):

5. CustomPipelineTests

Not sure what to make of this.

2+2+2 occurrences

Expected types for unet: (<class 'diffusers_modules.local.unet.my_unet_model.MyUNetModel'>,),
got <class 'diffusers_modules.local.my_unet_model.MyUNetModel'>.
Expected types for unet: (<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,),
got <class 'diffusers_modules.local.my_unet_model.MyUNetModel'>.
Expected types for unet: (<class 'diffusers_modules.local.unet.my_unet_model.MyUNetModel'>,),
got <class 'diffusers_modules.local.my_unet_model.MyUNetModel'>.

6. Incomplete unet type hints in AnimateDiffVideoToVideoPipelines

Changed to be unet: Union[UNet2DConditionModel, UNetMotionModel], as in AnimateDiffSDXLPipeline.

8 occurrences

Expected types for unet: (<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,),
got <class 'diffusers.models.unets.unet_motion_model.UNetMotionModel'>.

7. Subclasses not checked in is_valid_type

Subclasses should be considered valid, so I changed is_valid_type to use isinstance, which also allows subclasses. This way, you can annotate with a parent class (see the before/after sketch below).

4 occurrences

Expected types for bert: (<class 'transformers.modeling_utils.PreTrainedModel'>,),
got <class 'transformers.models.clip.modeling_clip.CLIPTextModel'>.
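Roughly, the check changed like this (illustrative before/after for the non-generic case; the function name suffixes are mine):

# Before: exact type matching, so a subclass of the annotated class was rejected.
def is_valid_type_old(obj, expected_type) -> bool:
    return type(obj) is expected_type


# After: isinstance accepts subclasses, so e.g. CLIPTextModel passes
# when the annotation is the parent class PreTrainedModel.
def is_valid_type_new(obj, expected_type) -> bool:
    return isinstance(obj, expected_type)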

8. Type hinting with AutoTokenizer and AutoModel in the Sana pipelines

The proper base classes are PreTrainedModel and PreTrainedTokenizerBase.

7 occurrences

Expected types for tokenizer: (<class 'transformers.models.auto.tokenization_auto.AutoTokenizer'>,),
got <class 'transformers.models.gemma.tokenization_gemma.GemmaTokenizer'>.

3 occurrences

Expected types for tokenizer: (<class 'transformers.models.auto.tokenization_auto.AutoTokenizer'>,),
got <class 'transformers.models.gemma.tokenization_gemma_fast.GemmaTokenizerFast'>.

3 occurrences

Expected types for text_encoder: (<class 'transformers.models.auto.modeling_auto.AutoModel'>,),
got <class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'>.

3 occurrences

Expected types for text_encoder: (<class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'>,),
got <class 'transformers.models.gemma2.modeling_gemma2.Gemma2ForCausalLM'>.

4 occurrences

Expected types for text_encoder: (<class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'>,),
got <class 'transformers.models.gemma2.modeling_gemma2.Gemma2Model'>.

tokenizer: AutoTokenizer,
text_encoder: AutoModelForCausalLM,

tokenizer: AutoTokenizer,
text_encoder: AutoModelForCausalLM,
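With those base classes, the signature would look roughly like this (an illustrative sketch, not the exact diff):

from transformers import PreTrainedModel, PreTrainedTokenizerBase


class SanaLikePipeline:  # illustrative, for the annotations only
    # The base classes cover GemmaTokenizer/GemmaTokenizerFast and the Gemma model
    # variants, since the isinstance-based check from point 7 accepts subclasses.
    def __init__(self, tokenizer: PreTrainedTokenizerBase, text_encoder: PreTrainedModel):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder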

9. Interchanged use of CLIPTextModel and CLIPTextModelWithProjection

Just swapped with the intended type.

4 occurrences

Expected types for text_encoder: (<class 'transformers.models.clip.modeling_clip.CLIPTextModelWithProjection'>,),
got <class 'transformers.models.clip.modeling_clip.CLIPTextModel'>.

text_encoder: CLIPTextModelWithProjection,

8 occurrences

Expected types for text_encoder: (<class 'transformers.models.clip.modeling_clip.CLIPTextModel'>,),
got <class 'transformers.models.clip.modeling_clip.CLIPTextModelWithProjection'>.

4 occurrences

Expected types for prior_text_encoder: (<class 'transformers.models.clip.modeling_clip.CLIPTextModel'>,),
got <class 'transformers.models.clip.modeling_clip.CLIPTextModelWithProjection'>.

@guiyrt (Contributor, Author) commented Feb 12, 2025

Found another warning related to custom pipelines, this time from "hf-internal-testing/diffusers-dummy-pipeline". The fix is to add the correct type hints there.

4 occurrences

Expected types for unet: (<class 'inspect._empty'>,),
got <class 'diffusers.models.unets.unet_2d.UNet2DModel'>

def test_run_custom_pipeline(self):
    pipeline = DiffusionPipeline.from_pretrained(
        "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
    )
    pipeline = pipeline.to(torch_device)
    images, output_str = pipeline(num_inference_steps=2, output_type="np")
    assert images[0].shape == (1, 32, 32, 3)
    # compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
    assert output_str == "This is a test"

@guiyrt (Contributor, Author) commented Feb 12, 2025

I opened PRs on the hf-internal-testing repos that produce warnings:
Replacing deprecated CLIPFeatureExtractor with CLIPImageProcessor (1)
Replacing deprecated CLIPFeatureExtractor with CLIPImageProcessor (2)
Added type annotations for pipeline init args

The last one addresses arguments with no annotations, which might be common in custom pipelines. To keep warnings relevant, it might be a good idea to skip type checking when an argument has no type annotation (see the sketch below).
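Such a skip could be as simple as checking for inspect.Parameter.empty (a sketch under that assumption; the function name is mine):

import inspect


def should_type_check(pipeline_cls, arg_name: str) -> bool:
    """Only type-check __init__ arguments that carry a type annotation."""
    params = inspect.signature(pipeline_cls.__init__).parameters
    param = params.get(arg_name)
    return param is not None and param.annotation is not inspect.Parameter.empty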

@hlky (Collaborator) commented Feb 12, 2025

  1. Maybe we just don't check tokenizer
  2. If it's an internal testing checkpoint the warning is not important
  3. If it's expected it's ok
  4. Thanks
  5. Custom pipeline so not important
  6. Thanks
  7. Annotations of parent class PreTrainedModel are not that useful; changing to the correct type can be left as a todo
  8. As above
  9. Thanks

@guiyrt (Contributor, Author) commented Feb 13, 2025

  1. Maybe we just don't check tokenizer

tokenizer is now skipped.

  7. Annotations of parent class PreTrainedModel are not that useful; changing to the correct type can be left as a todo

This to-do includes:

  • LDMTextToImagePipeline: Supposedly the types in the docs are LDMBertModel and BertTokenizer, but the text_encoder and tokenizer used in the tests are CLIPTextModel and CLIPTokenizer.
  • LuminaText2ImgPipeline and Lumina2Text2ImgPipeline: The docs mention T5, but the tests use Gemma.
  • SanaPAGPipeline and SanaPipeline: From what I understood from the code and the tests, these use GemmaTokenizer[Fast] as tokenizer, with Gemma2Model as text_encoder in the SanaPipeline tests and Gemma2ForCausalLM in the SanaPAGPipeline tests, so I annotated text_encoder with Gemma2PreTrainedModel, which works for both.

@guiyrt (Contributor, Author) commented Feb 13, 2025

@hlky, can we move the functions get_detailed_type and is_valid_type to some utils file? Maybe a new typing_utils.py or something? I don't think they should stay in the middle of DiffusionPipeline::from_pretrained.

@hlky (Collaborator) commented Feb 13, 2025

LDMTextToImagePipeline I'm not sure about; it could be that it supports both types, or the type hint could be incorrect.

LuminaText2ImgPipeline: yeah, Lumina is Gemma; the docstring/type hint would have been copied from some other pipeline.

SanaPAGPipeline and SanaPipeline are Gemma; I'm not sure whether it should be Gemma2Model or Gemma2ForCausalLM, though, and I'm assuming it's supposed to be the same for both. Using Gemma2PreTrainedModel should be ok.

@yiyixuxu WDYT about typing_utils.py? There might be some other code that could be moved there, though these functions probably won't be used elsewhere. IMO I don't mind the functions being in from_pretrained; if I'm working on from_pretrained, I've got all the context of those functions immediately available.

@guiyrt (Contributor, Author) commented Feb 13, 2025

LuminaText2ImgPipeline: yeah, Lumina is Gemma; the docstring/type hint would have been copied from some other pipeline.

I updated the docs. But if I got it right, Lumina v1 uses GemmaModel and v2 uses Gemma2Model; however, the FastTests of both used GemmaForCausalLM. For Lumina v1 we can annotate with GemmaPreTrainedModel, but if we annotated Lumina v2 with Gemma2PreTrainedModel it would produce a warning for the tests. So on top of updating the type annotations, I also updated the tests for Lumina v2 to use Gemma2ForCausalLM; this was easy because no expected output is hardcoded for comparison. Ran it locally and it passed :)

@guiyrt (Contributor, Author) commented Feb 13, 2025

I think the tests failed due to network issues. I noticed very slow download speeds from the Hub yesterday; is that anything you're aware of?

@hlky (Collaborator) commented Feb 13, 2025

Just temporary issues; it happens sometimes. Thanks for all the iterations on this @guiyrt, this should be good to go after @yiyixuxu's comments on whether to add typing_utils.py.

@hlky (Collaborator) commented Feb 20, 2025

Gentle ping @yiyixuxu

@yiyixuxu (Collaborator) commented

Thanks for the PR @guiyrt @hlky! It can go into this util file, I think: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_loading_utils.py
