[V1] Clarify input processing and multimodal feature caching logic #13211

Merged · 7 commits · Feb 13, 2025
16 changes: 8 additions & 8 deletions vllm/v1/engine/core.py
@@ -20,7 +20,7 @@
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
@@ -65,7 +65,7 @@ def __init__(
log_stats=self.log_stats,
)

self.mm_input_mapper_server = MMInputMapperServer(
self.mm_input_cache_server = MMInputCacheServer(
vllm_config.model_config)

def _initialize_kv_caches(self,
@@ -97,13 +97,13 @@ def add_request(self, request: EngineCoreRequest):
"""Add request to the scheduler."""

if request.mm_hashes is not None:
# Here, if hash exists for an image, then it will be fetched
# from the cache, else it will be added to the cache.
# Note that the cache here is mirrored with the client side of the
# MM mapper, so anything that has a hash must have a HIT cache
# entry here as well.
# Here, if hash exists for a multimodal input, then it will be
# fetched from the cache, else it will be added to the cache.
# Note that the cache here is mirrored with the client cache, so
# anything that has a hash must have a HIT cache entry here
# as well.
assert request.mm_inputs is not None
request.mm_inputs = self.mm_input_mapper_server.process_inputs(
request.mm_inputs = self.mm_input_cache_server.get_and_update(
request.mm_inputs, request.mm_hashes)

req = Request.from_engine_core_request(request)
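
For reference, here is a minimal Python sketch of the get-or-update behavior the new comment describes. It illustrates the semantics only and is not the actual `MMInputCacheServer` code: a `None` entry means the client hit its own cache and skipped serialization, so the mirrored server cache must already hold the item; anything else is inserted.

```python
# Illustrative mirrored server-side cache (not vLLM's MMInputCacheServer).
from collections import OrderedDict
from typing import Any, List, Optional


class ToyMMCacheServer:

    def __init__(self, capacity: int = 256):
        self.capacity = capacity
        self.cache: "OrderedDict[str, Any]" = OrderedDict()

    def get_and_update(self, mm_inputs: List[Optional[Any]],
                       mm_hashes: List[str]) -> List[Any]:
        full_inputs = []
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if mm_input is None:
                # Client cache hit: the kwargs were never serialized, so the
                # mirrored cache here must already hold them.
                mm_input = self.cache[mm_hash]
            else:
                # Client cache miss: store the freshly received kwargs.
                self.cache[mm_hash] = mm_input
                if len(self.cache) > self.capacity:
                    self.cache.popitem(last=False)  # evict least recently used
            self.cache.move_to_end(mm_hash)
            full_inputs.append(mm_input)
        return full_inputs
```
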
vllm/v1/engine/mm_input_mapper.py → vllm/v1/engine/mm_input_cache.py
@@ -10,12 +10,18 @@

logger = init_logger(__name__)

# The idea of MM preprocessor caching is based on having a client and a server,
# where the client executes in the frontend process (=P0) and the server in the
# core process (=P1).
# The idea of multimodal preprocessing caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
# server in the core process (=P1).
#
# -- Client: Executes the MM mapper and performs caching of the results.
# -- Server: Performs caching of the results
# -- Client:
# - Apply legacy input_mapper (if one exists) to generate MultiModalKwargs.
# - Perform caching of the generated MultiModalKwargs.
# - This client can be deprecated once all multimodal models migrate to use
# merged preprocessor with built-in caching functionality.
#
# -- Server:
# - Perform caching of the received MultiModalKwargs.
#
# The caching for both client and server is mirrored/similar, and this allows us
# to avoid the serialization of "mm_inputs" (like pixel values) between
@@ -27,7 +33,9 @@
MM_CACHE_SIZE = 256


class MMInputMapperClient:
# TODO(ywang96): Deprecate this class once all multimodal models migrate to use
# merged preprocessor with built-in caching functionality.
class MMInputCacheClient:

def __init__(
self,
@@ -54,7 +62,8 @@ def cache_hit_ratio(self, steps):
logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
self.mm_cache_hits / self.mm_cache_total)

# TODO: Support modalities beyond image.
# NOTE: process_inputs only supports image inputs since all multimodal
# models with other modalities have migrated to use merged preprocessor.
def process_inputs(
self,
mm_data: MultiModalDataDict,
@@ -95,7 +104,7 @@ def process_inputs(
# Reuse precomputed input (for merged preprocessor)
mm_input = precomputed_mm_inputs[input_id]
else:
# Apply MM mapper
# Apply legacy input_mapper
mm_input = self.multi_modal_input_mapper(
{"image": [image_inputs[input_id]]},
mm_processor_kwargs=mm_processor_kwargs,
@@ -114,13 +123,13 @@
return ret_inputs
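
To make the client/server comment at the top of this file concrete, here is a hedged sketch of the client-side decision; the names and signature are illustrative, not the actual `MMInputCacheClient` API. On a hit only the hash crosses the P0→P1 boundary; on a miss the client builds `MultiModalKwargs` via the legacy input mapper (or reuses precomputed inputs from a merged preprocessor) and sends them in full.

```python
# Illustrative client-side half of the mirrored cache (not the actual
# MMInputCacheClient API): one entry per image, None on a cache hit.
from typing import Any, Callable, Dict, List, Optional


def client_side_process(
    images: List[Any],
    mm_hashes: List[str],
    cache: Dict[str, Any],
    legacy_input_mapper: Callable[[Any], Any],
    precomputed_mm_inputs: Optional[List[Any]] = None,
) -> List[Optional[Any]]:
    outputs: List[Optional[Any]] = []
    for i, (image, mm_hash) in enumerate(zip(images, mm_hashes)):
        if mm_hash in cache:
            # Hit: skip serializing the heavy kwargs; the server's mirrored
            # cache holds the same entry.
            outputs.append(None)
            continue
        if precomputed_mm_inputs is not None:
            # Reuse the output of a merged preprocessor.
            mm_input = precomputed_mm_inputs[i]
        else:
            # Legacy path: run the per-model input mapper on this image.
            mm_input = legacy_input_mapper(image)
        cache[mm_hash] = mm_input
        outputs.append(mm_input)
    return outputs
```
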


class MMInputMapperServer:
class MMInputCacheServer:

def __init__(self, model_config):
self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)

def process_inputs(
def get_and_update(
self,
mm_inputs: List[Optional[MultiModalKwargs]],
mm_hashes: List[str],
20 changes: 14 additions & 6 deletions vllm/v1/engine/processor.py
@@ -17,7 +17,7 @@
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.engine.mm_input_cache import MMInputCacheClient


class Processor:
@@ -46,7 +46,7 @@ def __init__(
model_config)

# Multi-modal (huggingface) input mapper
self.mm_input_mapper_client = MMInputMapperClient(model_config)
self.mm_input_cache_client = MMInputCacheClient(model_config)

# Multi-modal hasher (for images)
self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
@@ -106,16 +106,24 @@ def process_inputs(
assert priority == 0, "vLLM V1 does not support priority at the moment."
assert trace_headers is None, "vLLM V1 does not support tracing yet."

# Process inputs.
# Process inputs, which includes:
# 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess
# multimodal data and expand prompt token ids accordingly.
# 3. Apply prompt adapter to prompt token ids if one exists.
preprocessed_inputs = self.input_preprocessor.preprocess(
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

# Process prompt and prompt token ids.
# Only applicable to multimodal models with legacy input processor.
processed_inputs = self.input_processor(preprocessed_inputs)

self._validate_model_inputs(processed_inputs)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

if is_encoder_decoder_inputs(processed_inputs):
decoder_inputs = SingletonInputsAdapter(
@@ -200,8 +208,8 @@ def process_inputs(
key=lambda mm_input: modality_order_dict[list(
mm_input.modalities)[0]])

# Apply mm input cache update (and input mapper if necessary).
sorted_mm_inputs = self.mm_input_mapper_client.process_inputs(
# Apply mm input cache update and legacy input mapper if one exists.
sorted_mm_inputs = self.mm_input_cache_client.process_inputs(
mm_data=decoder_mm_data,
mm_hashes=sorted_mm_hashes,
mm_processor_kwargs=decoder_inputs.mm_processor_kwargs,
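
The new comments above spell out the order of operations in `Processor.process_inputs`. The following self-contained outline (stub callables, not vLLM's actual `Processor`) mirrors those documented steps:

```python
# Toy outline of the documented processing order (illustrative stubs, not
# vLLM's Processor): preprocess/tokenize, legacy input processor, validate,
# then the multimodal cache-client update.
from typing import Any, Callable, Dict, List


def process_inputs_outline(
    prompt: Dict[str, Any],
    preprocess: Callable[[Dict[str, Any]], Dict[str, Any]],
    legacy_input_processor: Callable[[Dict[str, Any]], Dict[str, Any]],
    validate: Callable[[Dict[str, Any]], None],
    mm_cache_client_update: Callable[[List[Any], List[str]], List[Any]],
) -> Dict[str, Any]:
    # 1. Tokenize the text prompt (with LoRA if present), run the merged
    #    multimodal preprocessor when the model has one, and apply any
    #    prompt adapter to the token ids.
    inputs = preprocess(prompt)

    # 2. Legacy input processor: effectively a no-op except for multimodal
    #    models that still rely on the old per-model input mapper.
    inputs = legacy_input_processor(inputs)
    validate(inputs)

    # 3. Hand hashed multimodal items to the cache client so anything already
    #    cached is replaced by None before crossing the process boundary.
    mm_inputs = inputs.get("mm_inputs", [])
    mm_hashes = inputs.get("mm_hashes", [])
    inputs["mm_inputs"] = mm_cache_client_update(mm_inputs, mm_hashes)
    return inputs
```
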
7 changes: 4 additions & 3 deletions vllm/v1/worker/gpu_model_runner.py
@@ -26,7 +26,7 @@
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
@@ -94,9 +94,10 @@ def __init__(
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope

# NOTE: Initialized input mapper is only used for processing dummy
# NOTE: Initialized client is only used for processing dummy
# multimodal data into multimodal kwargs for GPU memory profiling.
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
# Only applicable to multimodal models with legacy input mapper.
self.mm_input_mapper_profiling = MMInputCacheClient(self.model_config)
self.mm_input_mapper_profiling.use_cache = False

encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
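
A tiny illustration of the note above (toy code, not the GPU model runner): the profiling client keeps caching disabled so dummy multimodal data is processed once to measure peak memory and never retained afterwards.

```python
# Toy illustration (not vLLM code) of why the profiling client runs with
# use_cache = False: dummy inputs are needed once for memory profiling and
# should not be kept around.
from typing import Any, Callable, Dict


class ToyProfilingClient:

    def __init__(self, use_cache: bool = True):
        self.use_cache = use_cache
        self.cache: Dict[str, Any] = {}

    def process(self, key: str, make_kwargs: Callable[[], Any]) -> Any:
        if self.use_cache and key in self.cache:
            return self.cache[key]
        kwargs = make_kwargs()
        if self.use_cache:
            self.cache[key] = kwargs
        return kwargs


client = ToyProfilingClient(use_cache=False)
dummy = client.process("dummy_image", lambda: {"pixel_values": [0.0] * 16})
assert client.cache == {}  # nothing cached after profiling
```
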