diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index e19680355584c..4642ac1778ed0 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -20,7 +20,7 @@
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType)
-from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
+from vllm.v1.engine.mm_input_cache import MMInputCacheServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
@@ -65,7 +65,7 @@ def __init__(
             log_stats=self.log_stats,
         )
 
-        self.mm_input_mapper_server = MMInputMapperServer(
+        self.mm_input_cache_server = MMInputCacheServer(
             vllm_config.model_config)
 
     def _initialize_kv_caches(self,
@@ -102,13 +102,13 @@ def add_request(self, request: EngineCoreRequest):
         """Add request to the scheduler."""
 
         if request.mm_hashes is not None:
-            # Here, if hash exists for an image, then it will be fetched
-            # from the cache, else it will be added to the cache.
-            # Note that the cache here is mirrored with the client side of the
-            # MM mapper, so anything that has a hash must have a HIT cache
-            # entry here as well.
+            # Here, if hash exists for a multimodal input, then it will be
+            # fetched from the cache, else it will be added to the cache.
+            # Note that the cache here is mirrored with the client cache, so
+            # anything that has a hash must have a HIT cache entry here
+            # as well.
             assert request.mm_inputs is not None
-            request.mm_inputs = self.mm_input_mapper_server.process_inputs(
+            request.mm_inputs = self.mm_input_cache_server.get_and_update(
                 request.mm_inputs, request.mm_hashes)
 
         req = Request.from_engine_core_request(request)
diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_cache.py
similarity index 82%
rename from vllm/v1/engine/mm_input_mapper.py
rename to vllm/v1/engine/mm_input_cache.py
index 83a0d9db161d2..e1b6679c284b4 100644
--- a/vllm/v1/engine/mm_input_mapper.py
+++ b/vllm/v1/engine/mm_input_cache.py
@@ -10,12 +10,18 @@
 
 logger = init_logger(__name__)
 
-# The idea of MM preprocessor caching is based on having a client and a server,
-# where the client executes in the frontend process (=P0) and the server in the
-# core process (=P1).
+# The idea of multimodal preprocessing caching is based on having a client and
+# a server, where the client executes in the frontend process (=P0) and the
+# server in the core process (=P1).
 #
-# -- Client: Executes the MM mapper and performs caching of the results.
-# -- Server: Performs caching of the results
+# -- Client:
+#  - Apply legacy input_mapper (if one exists) to generate MultiModalKwargs.
+#  - Perform caching of the generated MultiModalKwargs.
+#  - This client can be deprecated once all multimodal models migrate to use
+#    merged preprocessor with built-in caching functionality.
+#
+# -- Server:
+#  - Perform caching of the received MultiModalKwargs.
 #
 # The caching for both client and server is mirrored/similar, and this allows us
 # to avoid the serialization of "mm_inputs" (like pixel values) between
@@ -27,7 +33,9 @@
 MM_CACHE_SIZE = 256
 
 
-class MMInputMapperClient:
+# TODO(ywang96): Deprecate this class once all multimodal models migrate to use
+# merged preprocessor with built-in caching functionality.
+class MMInputCacheClient:
 
     def __init__(
         self,
@@ -54,7 +62,8 @@ def cache_hit_ratio(self, steps):
             logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
                          self.mm_cache_hits / self.mm_cache_total)
 
-    # TODO: Support modalities beyond image.
+    # NOTE: process_inputs only supports image inputs since all multimodal
+    # models with other modalities have migrated to use merged preprocessor.
     def process_inputs(
         self,
         mm_data: MultiModalDataDict,
@@ -95,7 +104,7 @@ def process_inputs(
                 # Reuse precomputed input (for merged preprocessor)
                 mm_input = precomputed_mm_inputs[input_id]
             else:
-                # Apply MM mapper
+                # Apply legacy input_mapper
                 mm_input = self.multi_modal_input_mapper(
                     {"image": [image_inputs[input_id]]},
                     mm_processor_kwargs=mm_processor_kwargs,
@@ -114,13 +123,13 @@ def process_inputs(
         return ret_inputs
 
 
-class MMInputMapperServer:
+class MMInputCacheServer:
 
     def __init__(self, model_config):
         self.use_cache = not model_config.disable_mm_preprocessor_cache
         self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
 
-    def process_inputs(
+    def get_and_update(
         self,
         mm_inputs: List[Optional[MultiModalKwargs]],
         mm_hashes: List[str],
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 70876b03a8236..b7eee5a39972b 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -17,7 +17,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
+from vllm.v1.engine.mm_input_cache import MMInputCacheClient
 
 
 class Processor:
@@ -46,7 +46,7 @@ def __init__(
             model_config)
 
         # Multi-modal (huggingface) input mapper
-        self.mm_input_mapper_client = MMInputMapperClient(model_config)
+        self.mm_input_cache_client = MMInputCacheClient(model_config)
 
         # Multi-modal hasher (for images)
         self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
@@ -106,16 +106,24 @@ def process_inputs(
         assert priority == 0, "vLLM V1 does not support priority at the moment."
         assert trace_headers is None, "vLLM V1 does not support tracing yet."
 
-        # Process inputs.
+        # Process inputs, which includes:
+        # 1. Tokenize text prompt, with LoRA request if one exists.
+        # 2. For multimodal models with a merged preprocessor, preprocess
+        #    multimodal data and expand prompt token ids accordingly.
+        # 3. Apply prompt adapter to prompt token ids if one exists.
         preprocessed_inputs = self.input_preprocessor.preprocess(
             prompt,
             request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
         )
+        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
+
+        # Process prompt and prompt token ids.
+        # Only applicable to multimodal models with legacy input processor.
         processed_inputs = self.input_processor(preprocessed_inputs)
+
         self._validate_model_inputs(processed_inputs)
-        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
 
         if is_encoder_decoder_inputs(processed_inputs):
             decoder_inputs = SingletonInputsAdapter(
@@ -200,8 +208,8 @@ def process_inputs(
                 key=lambda mm_input: modality_order_dict[list(
                     mm_input.modalities)[0]])
 
-            # Apply mm input cache update (and input mapper if necessary).
-            sorted_mm_inputs = self.mm_input_mapper_client.process_inputs(
+            # Apply mm input cache update and legacy input mapper if one exists.
+            sorted_mm_inputs = self.mm_input_cache_client.process_inputs(
                 mm_data=decoder_mm_data,
                 mm_hashes=sorted_mm_hashes,
                 mm_processor_kwargs=decoder_inputs.mm_processor_kwargs,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 5d8da7545f03f..fa4bd81a28dd8 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -27,7 +27,7 @@
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                    FlashAttentionMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
-from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
+from vllm.v1.engine.mm_input_cache import MMInputCacheClient
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
@@ -95,9 +95,10 @@ def __init__(
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
 
-        # NOTE: Initialized input mapper is only used for processing dummy
+        # NOTE: Initialized client is only used for processing dummy
         # multimodal data into multimodal kwargs for GPU memory profiling.
-        self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
+        # Only applicable to multimodal models with legacy input mapper.
+        self.mm_input_mapper_profiling = MMInputCacheClient(self.model_config)
         self.mm_input_mapper_profiling.use_cache = False
 
         encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
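Note for reviewers (not part of the diff): the sketch below is a minimal, self-contained model of the mirrored client/server cache protocol described in the comments of mm_input_cache.py, assuming both processes see the same sequence of mm hashes and use the same cache size and eviction policy. The names _LRU, ClientSide, ServerSide, and prepare are illustrative stand-ins, not the real MMInputCacheClient/MMInputCacheServer API; only get_and_update mirrors a name from the diff, and its internals here are a simplification.

# Standalone sketch of the mirrored P0/P1 cache protocol. A plain
# OrderedDict stands in for vLLM's LRUCache; "kwargs" stands in for
# MultiModalKwargs.
from collections import OrderedDict
from typing import Any, List, Optional

CACHE_SIZE = 256


class _LRU(OrderedDict):

    def get_or_put(self, key: str, value: Any) -> Any:
        # Return the cached value on a hit; otherwise insert and evict the
        # least recently used entry if over capacity.
        if key in self:
            self.move_to_end(key)
            return self[key]
        self[key] = value
        if len(self) > CACHE_SIZE:
            self.popitem(last=False)
        return value


class ClientSide:
    """Runs in the frontend process (P0)."""

    def __init__(self) -> None:
        self.cache = _LRU()

    def prepare(self, hashes: List[str],
                kwargs: List[Any]) -> List[Optional[Any]]:
        # On a cache hit, send None instead of the (large) kwargs so the
        # pixel-value payload is never serialized across the IPC boundary.
        wire: List[Optional[Any]] = []
        for h, kw in zip(hashes, kwargs):
            hit = h in self.cache
            self.cache.get_or_put(h, kw)
            wire.append(None if hit else kw)
        return wire


class ServerSide:
    """Runs in the core process (P1)."""

    def __init__(self) -> None:
        self.cache = _LRU()

    def get_and_update(self, inputs: List[Optional[Any]],
                       hashes: List[str]) -> List[Any]:
        # Because both caches apply the same get-or-put for every hash,
        # a None from the client must already be a hit here.
        full = []
        for h, inp in zip(hashes, inputs):
            if inp is None:
                assert h in self.cache, "mirrored caches out of sync"
            full.append(self.cache.get_or_put(h, inp))
        return full


if __name__ == "__main__":
    client, server = ClientSide(), ServerSide()
    wire = client.prepare(["h1", "h1"], [{"pixels": 1}, {"pixels": 1}])
    assert wire == [{"pixels": 1}, None]  # repeated input is elided
    assert server.get_and_update(wire, ["h1", "h1"]) == [{"pixels": 1}] * 2

Because both sides perform the same get-or-put on every hash, their LRU order and evictions stay in lockstep, which is what makes it safe for the client to send None for a hit and for the server's get_and_update to assert a hit for every elided input.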