VLM: compile compatibility #35724

Open
zucchini-nlp wants to merge 11 commits into main
Conversation

@zucchini-nlp commented Jan 16, 2025

What does this PR do?

As per the title, this PR adds flags in VLMs where needed, removes test skips, and makes sure VLMs are compile compatible. For BLIP models, it also adds the new cache format in OPT, which is one of the backbones. All official BLIP models can now use a static cache and therefore compile.
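
For example, a BLIP-2 checkpoint can now run generation with a static cache directly (a minimal sketch, assuming the Salesforce/blip2-opt-2.7b checkpoint; not taken from this PR):

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration

blip_id = "Salesforce/blip2-opt-2.7b"
blip_processor = AutoProcessor.from_pretrained(blip_id)
blip_model = Blip2ForConditionalGeneration.from_pretrained(blip_id, torch_dtype=torch.float16, device_map="cuda:0")

image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(image_url, stream=True).raw)
blip_inputs = blip_processor(images=image, text="Question: what is in the image? Answer:", return_tensors="pt").to("cuda:0", torch.float16)

# StaticCache support in the OPT backbone is what this PR enables
blip_out = blip_model.generate(**blip_inputs, max_new_tokens=20, cache_implementation="static")
print(blip_processor.decode(blip_out[0], skip_special_tokens=True))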

NOTE:

  • all VLMs have dynamic control flow in prepare_inputs_for_generation and thus skip test_compile_forward, which compiles the model for the prefill phase. The test that compiles the decoding stage is green, therefore I'm leaving the flag as True
  • Otherwise, tests with -k compile_forward and -k static_ were run for all models and are passing (see the sketch after this list). Some models needed the flag turned off, because MoE models cannot currently be compiled (dynamic control flow)
  • Regarding executorch, which I also checked: the model can be exported and runs a forward pass, but generation won't work and would probably need something similar to what we do when exporting VLMs to ONNX. I still need to dig more into that in later PRs
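
For reference, those test selections can also be run from Python (a minimal sketch, assuming the usual transformers test layout; the llava test file below is only an example):

import pytest

# Run the compile / static-cache tests for one model; swap in any other model's test file
pytest.main(["tests/models/llava/test_modeling_llava.py", "-k", "compile_forward or static_", "-v"])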

How to run compile and export for VLMs:

import requests
from PIL import Image

import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers.generation import GenerationConfig
from transformers.cache_utils import StaticCache
from transformers.integrations.executorch import (
    TorchExportableModuleWithStaticCache,
    convert_and_export_with_cache,
)

model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, 
    torch_dtype="float16", 
    device_map="cuda:0",
)
processor = AutoProcessor.from_pretrained(model_id)

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What are these?"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(0, torch.float16)

# Run with static cache which compiles the forward in decoding phase for you
output = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(processor.decode(output[0][2:], skip_special_tokens=True))
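
# Alternative (a hedged sketch, not from this PR): compile the forward explicitly instead of
# relying on the static-cache path to do it; "reduce-overhead" wraps the decode steps in CUDA graphs
model.forward = torch.compile(model.forward, mode="reduce-overhead")
output = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(processor.decode(output[0][2:], skip_special_tokens=True))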



# Try to export with `torch.export`. NOTE: TorchExportableModuleWithStaticCache is not ready for VLMs
# and, as mentioned above, VLMs might need to export 3 different modules as in ONNX: one for the text
# embeddings, one for the vision backbone, and one for the LM backbone with simple token-by-token decoding
max_generation_length = 1000
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    device_map="cuda:0",
    torch_dtype="float16",
    attn_implementation="sdpa",
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        max_length=max_generation_length,
        cache_config={
            "batch_size": 1,
            "max_cache_len": max_generation_length,
        },
    ),
)

# Adapted from `TorchExportableModuleWithStaticCache` with minor changes
class TorchExportableModuleWithStaticCacheForVLM(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.static_cache = StaticCache(
            config=self.model.config.get_text_config(),
            batch_size=self.model.generation_config.cache_config.batch_size,
            max_cache_len=self.model.generation_config.cache_config.max_cache_len,
            dtype=self.model.dtype,
            device=model.device,
        )
        self.is_causal = any(("CausalLM" in arch or "ConditionalGeneration" in arch) for arch in self.model.config.architectures)
        if self.is_causal:
            # Additive causal mask; diagonal=1 keeps each token able to attend to itself
            causal_mask = torch.triu(
                torch.full(
                    (
                        self.model.generation_config.cache_config.batch_size,
                        1,
                        self.static_cache.max_cache_len,
                        self.static_cache.max_cache_len,
                    ),
                    fill_value=torch.finfo(self.model.dtype).min,
                    dtype=self.model.dtype,
                    device=model.device,
                ),
                diagonal=1,
            )
            self.register_buffer("mask", causal_mask, persistent=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        cache_position: torch.Tensor,
        pixel_values: torch.Tensor,
    ):
        # Slice the full causal mask down to the rows for the current cache positions
        attn_mask = self.mask[:, :, cache_position, :] if self.is_causal else None
        outs = self.model(
            input_ids=input_ids,
            attention_mask=attn_mask,
            position_ids=cache_position.unsqueeze(0),
            pixel_values=pixel_values,
            cache_position=cache_position,
            past_key_values=self.static_cache,
            use_cache=True,
        )
        return outs.logits

cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.long, device=model.device)
export_inputs = {"input_ids": inputs.input_ids, "cache_position": cache_position, "pixel_values": inputs.pixel_values}

with torch.no_grad():
    exported_program = torch.export.export(
        TorchExportableModuleWithStaticCacheForVLM(model),
        args=(),
        kwargs=export_inputs,
        strict=True,
    )

torch.export.save(exported_program, "exported_llava.pt2")
exported_program = torch.export.load("exported_llava.pt2")
out = exported_program.module().forward(
    input_ids=inputs.input_ids,
    pixel_values=inputs.pixel_values,
    cache_position=cache_position,
)
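
# The exported forward returns the prefill logits; with greedy decoding, the argmax at the last
# position is the first generated token (a quick sanity check, not part of the original snippet)
next_token = out[:, -1, :].argmax(dim=-1)
print(processor.decode(next_token, skip_special_tokens=True))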

Comment on lines +162 to +165
if past_key_value is not None:
    if not isinstance(past_key_value, EncoderDecoderCache):
        curr_past_key_value = past_key_value
    else:
I don't know why, but the OPT model works as decoder-only while its attention is written as cross-attention (not used anywhere in the codebase). So we need to somehow keep BC while using the new DynamicCache.

As a workaround I simply added a check on the cache instance. Another possibility is to accept and return only the correct cache (self- or cross-attention), but that would require a change in all encoder-decoder models, thus breaking BC.
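
A hedged sketch of the check described above (pick_layer_cache is a hypothetical helper; the sub-cache attributes come from EncoderDecoderCache, but this is not the exact diff):

from transformers.cache_utils import EncoderDecoderCache

def pick_layer_cache(past_key_value, is_cross_attention: bool):
    """Return the cache object this attention layer should read/write (hypothetical helper)."""
    if past_key_value is None:
        return None
    if not isinstance(past_key_value, EncoderDecoderCache):
        # decoder-only usage: a plain DynamicCache/StaticCache is used directly
        return past_key_value
    # encoder-decoder usage: the combined cache wraps a self-attention and a cross-attention cache
    return past_key_value.cross_attention_cache if is_cross_attention else past_key_value.self_attention_cache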

@zucchini-nlp zucchini-nlp changed the title [WIP] VLM: compile compatibility VLM: compile compatibility Jan 16, 2025
@zucchini-nlp zucchini-nlp changed the title VLM: compile compatibility [WIP] VLM: compile compatibility Jan 16, 2025
Comment on lines -40 to -41
ignore_index (`int`, *optional*, defaults to -100):
    The ignore index for the loss function.
I believe this can be removed as it is not used anymore when merging inputs. We could also deprecate it properly, but I don't think anyone uses it.

Comment on lines +2069 to +2078
half_batch_size = self.model_tester.batch_size // 2
inputs_dict_1 = {k: v[:half_batch_size, ...] for k, v in inputs_dict.items() if "head_mask" not in k}
inputs_dict_2 = {
    k: v[half_batch_size : half_batch_size * 2, ...]
    for k, v in inputs_dict.items()
    if "head_mask" not in k
}
self.assertTrue(
    inputs_dict_1[model_class.main_input_name].shape == inputs_dict_2[model_class.main_input_name].shape
)
Some models cannot generate from input_ids alone, so we pass the whole dict except for the head_mask (we're removing head mask soon anyway).

@@ -83,14 +83,14 @@ def __init__(
    moe_intermediate_size=4,
    moe_num_experts=4,
    moe_topk=2,
-   num_attention_heads=20,
+   num_attention_heads=8,
Aria's tester had a hidden size of 32 and 20 attention heads, which caused division problems in tests when inferring head_dim.
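
A quick check of the arithmetic (a sketch; head_dim is assumed to be inferred as hidden_size // num_attention_heads):

hidden_size = 32
assert hidden_size % 8 == 0   # new value: head_dim = 32 // 8 = 4
assert hidden_size % 20 != 0  # old value: 32 is not divisible by 20, so inferring head_dim breaks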

@zucchini-nlp zucchini-nlp changed the title [WIP] VLM: compile compatibility VLM: compile compatibility Jan 17, 2025
@zucchini-nlp
Ready for review. The failing test is flaky; otherwise everything is passing on my end, including the slow tests for compile/StaticCache.
