
[VLM] Implement merged multimodal processor for Mllama #11427

Merged on Feb 13, 2025 (56 commits; diff below shows changes from 49 commits)

Commits
99657bb  draft (Isotr0py, Dec 20, 2024)
76185c5  fix profiling (Isotr0py, Dec 20, 2024)
f3ca433  draft (Isotr0py, Dec 21, 2024)
1da6712  fix processing (Isotr0py, Dec 21, 2024)
7402e62  refactor (Isotr0py, Dec 21, 2024)
0fa61d7  Merge branch 'vllm-project:main' into enc-dec-processor (Isotr0py, Dec 21, 2024)
f9ff072  cleanup (Isotr0py, Dec 21, 2024)
256f85e  cleanup (Isotr0py, Dec 21, 2024)
5708658  refactor (Isotr0py, Dec 22, 2024)
00cf0ef  fix (Isotr0py, Dec 22, 2024)
c3cd4a0  Merge branch 'vllm-project:main' into enc-dec-processor (Isotr0py, Dec 23, 2024)
cd780bc  refactor (Isotr0py, Dec 24, 2024)
ba22a20  code format (Isotr0py, Dec 24, 2024)
2f2584a  cleanup (Isotr0py, Dec 24, 2024)
a65ddce  cleanup (Isotr0py, Dec 24, 2024)
f819d51  refactor (Isotr0py, Dec 25, 2024)
804f76b  fix text-only inputs (Isotr0py, Dec 25, 2024)
a715a97  cleanup (Isotr0py, Dec 25, 2024)
c96fd21  fix text enc-dec model (Isotr0py, Dec 25, 2024)
4fd4204  fix a typo (Isotr0py, Dec 25, 2024)
97f420d  Merge remote-tracking branch 'upstream/main' into enc-dec-processor (Isotr0py, Feb 4, 2025)
d2ab070  reorganize (Isotr0py, Feb 6, 2025)
4ba1bbc  fix mm_data (Isotr0py, Feb 6, 2025)
3f6ebec  refactor hf_processor calling (Isotr0py, Feb 6, 2025)
6928719  Merge branch 'vllm-project:main' into enc-dec-processor (Isotr0py, Feb 6, 2025)
6088118  fix typo (Isotr0py, Feb 6, 2025)
407988d  Merge branch 'vllm-project:main' into enc-dec-processor (Isotr0py, Feb 7, 2025)
6e5cb54  clean up (Isotr0py, Feb 7, 2025)
3c336d9  fix multi-image profiling (Isotr0py, Feb 7, 2025)
bd12b30  cleanup (Isotr0py, Feb 7, 2025)
d47cacb  fix multiimage inference (Isotr0py, Feb 7, 2025)
0f775c4  cleanup (Isotr0py, Feb 7, 2025)
5dacf4d  fix processor test (Isotr0py, Feb 8, 2025)
6bdaf45  add processor test (Isotr0py, Feb 8, 2025)
862f410  update comment (Isotr0py, Feb 8, 2025)
9a7c742  Update vllm/inputs/preprocess.py (Isotr0py, Feb 9, 2025)
b241b6b  fix processor test (Isotr0py, Feb 9, 2025)
1b13d33  refactor token_per_chunk func (Isotr0py, Feb 9, 2025)
a7a6589  refactor create_encoder_prompt and remove hardcoded image_token_id (Isotr0py, Feb 9, 2025)
ea44a41  fix inference (Isotr0py, Feb 9, 2025)
1b2bd03  make mypy happy (Isotr0py, Feb 9, 2025)
9572f1e  fix explicit input (Isotr0py, Feb 9, 2025)
5eafc05  make mypy happy (Isotr0py, Feb 9, 2025)
709017b  fix text-only explicit prompt (Isotr0py, Feb 10, 2025)
611d1d7  fix tokens only inputs (Isotr0py, Feb 10, 2025)
35bcb5e  make mypy happy (Isotr0py, Feb 10, 2025)
98e47ce  make mypy happy (Isotr0py, Feb 10, 2025)
3567c1c  mypy again (Isotr0py, Feb 10, 2025)
5551eae  refactor preprocess func for explicit prompt (Isotr0py, Feb 10, 2025)
edb99ba  add explicit and implicit_prompt test (Isotr0py, Feb 11, 2025)
209253a  fix broken cross attn mask test (Isotr0py, Feb 11, 2025)
3528e61  ooops (Isotr0py, Feb 11, 2025)
f66ba41  fix entrypoint test (Isotr0py, Feb 12, 2025)
7850f07  fix whisper (Isotr0py, Feb 12, 2025)
2621d4a  linting (Isotr0py, Feb 12, 2025)
62215b8  Merge branch 'vllm-project:main' into enc-dec-processor (Isotr0py, Feb 12, 2025)
13 changes: 11 additions & 2 deletions tests/models/multimodal/processing/test_common.py
@@ -85,6 +85,14 @@ def _test_processing_correctness(
partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
}

tokenizer_encode_kwargs = {}
if model_config.hf_config.model_type == "mllama":
# For Mllama, the tokenizer always adds bos_token at the beginning of
# the prompt by default, causing hf_processor to output incorrect token
# ids. So we need to use `add_special_tokens=False` here and leave
# bos_token to be added by the processor.
tokenizer_encode_kwargs = {"add_special_tokens": False}

for batch_idx in range(num_batches):
mm_data = {
k:
@@ -122,7 +130,7 @@ def _test_processing_correctness(
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")

baseline_tokenized_result = baseline_processor.apply(
tokenizer.encode(prompt),
tokenizer.encode(prompt, **tokenizer_encode_kwargs),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
@@ -131,7 +139,7 @@
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")

cached_tokenized_result = cached_processor.apply(
tokenizer.encode(prompt),
tokenizer.encode(prompt, **tokenizer_encode_kwargs),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
@@ -154,6 +162,7 @@
"llava-hf/llava-v1.6-mistral-7b-hf",
"llava-hf/LLaVA-NeXT-Video-7B-hf",
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"TIGER-Lab/Mantis-8B-siglip-llama3",
"mistral-community/pixtral-12b",
"openbmb/MiniCPM-o-2_6",
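
The reason for `add_special_tokens=False` can be made concrete with a minimal sketch. This is not part of the PR; it assumes the public `hf-internal-testing/llama-tokenizer` as a stand-in, since the actual Mllama checkpoint is gated but its tokenizer behaves the same way with respect to `bos_token`:

```python
# Sketch of the double-BOS problem the test change works around.
# Assumption: a Llama-style tokenizer that prepends bos_token on encode;
# the public test tokenizer below stands in for the gated Mllama one.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

prompt = "Describe the image."
with_bos = tokenizer.encode(prompt)                                # [bos, ...]
without_bos = tokenizer.encode(prompt, add_special_tokens=False)   # [...]

assert with_bos[0] == tokenizer.bos_token_id
assert without_bos[0] != tokenizer.bos_token_id
# Feeding `with_bos` to a processor that prepends bos_token again would
# yield [bos, bos, ...] and diverge from processing the raw text, which
# is why the test encodes with add_special_tokens=False for Mllama.
```
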
86 changes: 79 additions & 7 deletions vllm/inputs/preprocess.py
@@ -1,15 +1,16 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
from typing import List, Mapping, Optional, Union
from typing import List, Mapping, Optional, Tuple, Union, cast

from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs)
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

@@ -486,6 +487,51 @@ def _build_enc_dec_llm_inputs(
decoder=decoder_inputs,
)

def _separate_enc_dec_inputs_from_mm_processor_outputs(
self,
inputs: SingletonInputs,
decoder_inputs_to_override: Optional[SingletonInputs] = None,
) -> Tuple[SingletonInputs, SingletonInputs]:
"""
For encoder/decoder models only:
Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
"""
encoder_inputs: SingletonInputs
decoder_inputs: SingletonInputs
if inputs["type"] == "multimodal":
# Multimodal data inputs
assert ("encoder_prompt" in inputs
and "encoder_prompt_token_ids" in inputs)
inputs = cast(MultiModalEncDecInputs, inputs)
encoder_inputs = token_inputs(
prompt=inputs["encoder_prompt"],
prompt_token_ids=inputs["encoder_prompt_token_ids"],
)
if decoder_inputs_to_override is not None:
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=decoder_inputs_to_override.get("prompt", ""),
prompt_token_ids=decoder_inputs_to_override[
"prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_placeholders=inputs["mm_placeholders"],
)
else:
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=inputs["prompt"],
prompt_token_ids=inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_placeholders=inputs["mm_placeholders"],
)
elif inputs["type"] == "token":
# Text-only inputs
encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
decoder_inputs = decoder_inputs_to_override or inputs
else:
assert_never(inputs) # type: ignore[arg-type]
return encoder_inputs, decoder_inputs

def _process_encoder_decoder_prompt(
self,
prompt: PromptType,
@@ -530,21 +576,33 @@ def _process_encoder_decoder_prompt(
prompt["encoder_prompt"],
request_id=request_id,
)

if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_inputs = None
else:
decoder_inputs = self._prompt_to_llm_inputs(
decoder_input,
request_id=request_id,
)
# For multimodal models, override the decoder prompt from the
# processor with the explicit decoder prompt.
if self._can_process_multimodal():
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
else:
encoder_inputs = self._prompt_to_llm_inputs(
inputs = self._prompt_to_llm_inputs(
prompt,
request_id=request_id,
)
if self._can_process_multimodal():
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
inputs))
else:
encoder_inputs = inputs

decoder_inputs = None
decoder_inputs = None

return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

@@ -574,13 +632,27 @@ async def _process_encoder_decoder_prompt_async(

encoder_inputs, decoder_inputs = await asyncio.gather(
encoder_task, decoder_task)

# For multimodal models, override the decoder prompt from the
# processor with the explicit decoder prompt.
if self._can_process_multimodal():
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
else:
encoder_inputs = await self._prompt_to_llm_inputs_async(
inputs = await self._prompt_to_llm_inputs_async(
prompt,
request_id=request_id,
)
if self._can_process_multimodal():
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
inputs))
else:
encoder_inputs = inputs

decoder_inputs = None
decoder_inputs = None

return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

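
For readers following the control flow, here is a simplified stand-alone sketch of the separation step introduced above. Plain dicts stand in for vLLM's `SingletonInputs`/`MultiModalInputs` types, and the `separate_enc_dec` name plus all token ids are illustrative; only the dict keys and branch logic mirror the diff:

```python
# Simplified sketch of _separate_enc_dec_inputs_from_mm_processor_outputs.
from typing import Optional, Tuple


def separate_enc_dec(
    inputs: dict,
    decoder_override: Optional[dict] = None,
) -> Tuple[dict, dict]:
    if inputs["type"] == "multimodal":
        # Merged-processor output: the encoder side keeps only its prompt
        # tokens, while mm_kwargs/mm_placeholders ride on the decoder side.
        encoder = {
            "type": "token",
            "prompt": inputs["encoder_prompt"],
            "prompt_token_ids": inputs["encoder_prompt_token_ids"],
        }
        src = decoder_override if decoder_override is not None else inputs
        decoder = {
            "type": "multimodal",
            "prompt": src.get("prompt", ""),
            "prompt_token_ids": src["prompt_token_ids"],
            "mm_kwargs": inputs["mm_kwargs"],
            "mm_placeholders": inputs["mm_placeholders"],
        }
    else:
        # Text-only: empty encoder prompt; any explicit decoder prompt wins.
        encoder = {"type": "token", "prompt": "", "prompt_token_ids": []}
        decoder = decoder_override if decoder_override is not None else inputs
    return encoder, decoder


# Example: merged-processor output plus an explicit decoder prompt.
mm_out = {
    "type": "multimodal",
    "prompt": "<|image|> describe",
    "prompt_token_ids": [9, 1, 2],
    "encoder_prompt": "<|image|>",
    "encoder_prompt_token_ids": [9],
    "mm_kwargs": {"pixel_values": "..."},
    "mm_placeholders": {"image": "..."},
}
enc, dec = separate_enc_dec(mm_out, {"type": "token", "prompt_token_ids": [5, 6]})
assert enc["prompt_token_ids"] == [9]
assert dec["prompt_token_ids"] == [5, 6]  # explicit decoder prompt overrides
assert "mm_kwargs" in dec                 # multimodal data stays on the decoder
```

Note that the multimodal kwargs and placeholders always come from the processor outputs, even when an explicit decoder prompt overrides the decoder tokens; only the token ids are swapped in.
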
3 changes: 2 additions & 1 deletion vllm/inputs/registry.py
@@ -350,7 +350,8 @@ def dummy_data_for_profiling(
)
processor = mm_registry.create_processor(model_config, tokenizer)
profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_dummy_data(seq_len)
dummy_data = profiler.get_dummy_data(
seq_len, is_encoder_data=is_encoder_data)
else:
model_cls, _ = get_model_architecture(model_config)
if is_encoder_data: