[Feature][Frontend] Add KVTransferParams for disaggregated prefill feature #12957

Open · wants to merge 14 commits into base: main
50 changes: 50 additions & 0 deletions tests/kv_transfer/test_kv_transfer_params.py
@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for the KVTransferParams class."""
from vllm import KVTransferParams


def test_all_none():
    # None should be allowed
    KVTransferParams(prefix_prompt_ids=None,
                     kvcache_load_keys=None,
                     kvcache_store_keys=None)
    KVTransferParams(prefix_prompt_ids=[None],
                     kvcache_load_keys=[None],
                     kvcache_store_keys=[None])

    # Note(shangming): KVCache transfer may have different granularity,
    # such as block-level, so the length of kvcache_load_keys and
    # kvcache_store_keys has no strong correspondence with the length of
    # prefix_prompt_ids.

    # prefill node cases
    KVTransferParams(prefix_prompt_ids=[1, 2, 3],
                     kvcache_load_keys=None,
                     kvcache_store_keys=["key1", "key2", "key3"])

    KVTransferParams(prefix_prompt_ids=[None, [1, 2, 3]],
                     kvcache_load_keys=[None, None],
                     kvcache_store_keys=[None, ["key1"]])

    # decode node cases
    KVTransferParams(prefix_prompt_ids=[1, 2, 3],
                     kvcache_load_keys=["key1", "key2", "key3"],
                     kvcache_store_keys=None)
    KVTransferParams(prefix_prompt_ids=[None, [1, 2, 3]],
                     kvcache_load_keys=[None, ["key1"]],
                     kvcache_store_keys=[None, None])

    # prefix cache sharing cases
    KVTransferParams(prefix_prompt_ids=[[1, 2, 3], [1, 2]],
                     kvcache_load_keys=[["key1", "key2", "key3"],
                                        ["key1", "key2"]],
                     kvcache_store_keys=[None, None])
    KVTransferParams(prefix_prompt_ids=[[1, 2, 3], [4, 5, 6]],
                     kvcache_load_keys=[["key1", "key2", "key3"], None],
                     kvcache_store_keys=[None, ["key4", "key5", "key6"]])


if __name__ == "__main__":
    import pytest
    pytest.main([__file__])
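The new vllm/kv_transfer_params.py module itself is not included in this excerpt of the diff. As a reading aid, here is a minimal sketch of what the class could look like, inferred purely from the constructor calls in the tests above; the field types (and any validation the PR performs) are assumptions, not the actual implementation.

# Hypothetical sketch of KVTransferParams, inferred from the tests above;
# not the PR's actual implementation.
from dataclasses import dataclass
from typing import List, Optional, Union

# Each argument accepts None, a flat list (single-sequence form), or a
# nested list whose entries may themselves be None (batched form).
PrefixPromptIds = Optional[Union[List[int], List[Optional[List[int]]]]]
KVCacheKeys = Optional[Union[List[str], List[Optional[List[str]]]]]


@dataclass
class KVTransferParams:
    """Per-request KVCache transfer parameters for disaggregated prefill:
    the prefix tokens covered, the cache keys to load (decode node), and
    the cache keys to store (prefill node)."""
    prefix_prompt_ids: PrefixPromptIds = None
    kvcache_load_keys: KVCacheKeys = None
    kvcache_store_keys: KVCacheKeys = None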
2 changes: 2 additions & 0 deletions vllm/__init__.py
@@ -14,6 +14,7 @@
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.kv_transfer_params import KVTransferParams
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput,
CompletionOutput, EmbeddingOutput,
@@ -61,4 +62,5 @@
"AsyncEngineArgs",
"initialize_ray_cluster",
"PoolingParams",
"KVTransferParams",
]
1 change: 1 addition & 0 deletions vllm/core/scheduler.py
@@ -1380,6 +1380,7 @@ def schedule(
computed_block_nums=common_computed_block_nums,
encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table,
kv_transfer_params=seq_group.kv_transfer_params,
state=seq_group.state,
token_type_ids=seq_group.token_type_ids,
# `multi_modal_data` will only be present for the 1st comm
13 changes: 13 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -23,6 +23,7 @@
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.kv_transfer_params import KVTransferParams
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
@@ -437,6 +438,7 @@ async def add_request_async(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
) -> None:
...
@@ -451,6 +453,7 @@ async def add_request_async(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
) -> None:
...
@@ -468,6 +471,7 @@ async def add_request_async(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
@@ -519,6 +523,7 @@ async def add_request_async(
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
kv_transfer_params=kv_transfer_params,
priority=priority,
)

@@ -856,6 +861,7 @@ def add_request(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, PoolingRequestOutput], None]]:
@@ -871,6 +877,7 @@ def add_request(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, PoolingRequestOutput], None]]:
@@ -889,6 +896,7 @@ async def add_request(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
@@ -921,6 +929,7 @@ async def add_request(
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
kv_transfer_params=kv_transfer_params,
priority=priority,
)

@@ -934,6 +943,7 @@ async def generate(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.
@@ -951,6 +961,8 @@ async def generate(
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
kv_transfer_params: The KVCache transfer parameters to use
for generation, if any.
priority: The priority of the request.
Only applicable with priority scheduling.

@@ -1010,6 +1022,7 @@ async def generate(
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
kv_transfer_params=kv_transfer_params,
priority=priority,
):
yield LLMEngine.validate_output(output, RequestOutput)
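After these changes, a caller can hand KVCache transfer metadata to the engine alongside the prompt. Below is a hedged usage sketch assuming the class shape sketched earlier; the model name and cache keys are placeholders, and the KV connector configuration needed to actually move cache data is outside this diff.

# Usage sketch: passing KVTransferParams through AsyncLLMEngine.generate.
# Cache key names and the model are illustrative placeholders.
import asyncio

from vllm import KVTransferParams, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))

    # Decode-node style request: load previously stored KV blocks for the
    # prefix instead of recomputing them.
    kv_params = KVTransferParams(prefix_prompt_ids=[1, 2, 3],
                                 kvcache_load_keys=["key1", "key2", "key3"],
                                 kvcache_store_keys=None)

    final_output = None
    async for output in engine.generate(
            "Hello, my name is",
            SamplingParams(max_tokens=16),
            request_id="req-0",
            kv_transfer_params=kv_params):
        final_output = output  # keep the most recent (final) output
    if final_output is not None:
        print(final_output.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())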
11 changes: 11 additions & 0 deletions vllm/engine/llm_engine.py
@@ -34,6 +34,7 @@
PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.kv_transfer_params import KVTransferParams
from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors
from vllm.lora.request import LoRARequest
@@ -551,6 +552,7 @@ def _add_processed_request(
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
) -> Optional[SequenceGroup]:
"""Add a processed request to the engine's request pool.
@@ -566,6 +568,7 @@ def _add_processed_request(
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
kv_transfer_params=kv_transfer_params,
priority=priority,
)
return None
@@ -601,6 +604,7 @@ def _add_processed_request(
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
kv_transfer_params=kv_transfer_params,
priority=priority)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
@@ -639,6 +643,7 @@ def add_request(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
) -> None:
...
@@ -655,6 +660,7 @@ def add_request(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
) -> None:
...
@@ -672,6 +678,7 @@ def add_request(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
@@ -694,6 +701,7 @@ def add_request(
lora_request: The LoRA request to add.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: The prompt adapter request to add.
kv_transfer_params: The KVCache transfer parameters to add.
priority: The priority of the request.
Only applicable with priority scheduling.

@@ -764,6 +772,7 @@ def add_request(
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
kv_transfer_params=kv_transfer_params,
priority=priority,
)

@@ -798,6 +807,7 @@ def _create_sequence_group_with_sampling(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
encoder_seq: Optional[Sequence] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
@@ -829,6 +839,7 @@ def _create_sequence_group_with_sampling(
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
kv_transfer_params=kv_transfer_params,
priority=priority)

return seq_group
6 changes: 6 additions & 0 deletions vllm/engine/multiprocessing/__init__.py
@@ -9,6 +9,7 @@

from vllm import PoolingParams
from vllm.inputs import PromptType
from vllm.kv_transfer_params import KVTransferParams
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -35,6 +36,7 @@ class RPCProcessRequest:
lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
kv_transfer_params: Optional[KVTransferParams] = None
priority: int = 0

@overload
@@ -46,6 +48,7 @@ def __init__(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
) -> None:
...
@@ -61,6 +64,7 @@ def __init__(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
) -> None:
...
@@ -77,6 +81,7 @@ def __init__(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
@@ -94,6 +99,7 @@ def __init__(
self.lora_request = lora_request
self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request
self.kv_transfer_params = kv_transfer_params
self.priority = priority


11 changes: 10 additions & 1 deletion vllm/engine/multiprocessing/client.py
@@ -38,6 +38,7 @@
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.kv_transfer_params import KVTransferParams
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -441,6 +442,7 @@ def generate(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
...
@@ -456,6 +458,7 @@ def generate(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
...
@@ -472,6 +475,7 @@ def generate(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None # DEPRECATED
@@ -491,6 +495,8 @@ def generate(
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
kv_transfer_params: The KVCache transfer parameters to use
for generation, if any.
priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
@@ -502,7 +508,8 @@

return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
prompt_adapter_request, priority)
prompt_adapter_request,
kv_transfer_params, priority)

@overload
def encode(
@@ -585,6 +592,7 @@ async def _process_request(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]:
@@ -638,6 +646,7 @@ async def _process_request(
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
kv_transfer_params=kv_transfer_params,
priority=priority,
))

1 change: 1 addition & 0 deletions vllm/engine/multiprocessing/engine.py
@@ -269,6 +269,7 @@ def _handle_process_request(self, request: RPCProcessRequest):
lora_request=request.lora_request,
trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request,
kv_transfer_params=request.kv_transfer_params,
priority=request.priority)

if self.log_requests:
2 changes: 2 additions & 0 deletions vllm/engine/protocol.py
@@ -10,6 +10,7 @@
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.kv_transfer_params import KVTransferParams
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -55,6 +56,7 @@ def generate(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
kv_transfer_params: Optional[KVTransferParams] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request."""
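Taken together, the feature pairs a prefill node that stores KV blocks with a decode node that loads them. The following conceptual sketch mirrors the "prefill node" and "decode node" cases from the tests; the key names are illustrative, and the connector that physically moves cache data between nodes is configured elsewhere and is not part of this diff.

# Conceptual prefill/decode pairing, mirroring the test cases above.
# Key names are illustrative identifiers for KV blocks in a shared store.
from vllm import KVTransferParams

prompt_ids = [1, 2, 3]
cache_keys = ["key1", "key2", "key3"]

# Prefill node: compute the prefix and store its KV blocks under the keys.
prefill_params = KVTransferParams(prefix_prompt_ids=prompt_ids,
                                  kvcache_load_keys=None,
                                  kvcache_store_keys=cache_keys)

# Decode node: load the stored blocks and skip recomputing the prefix.
decode_params = KVTransferParams(prefix_prompt_ids=prompt_ids,
                                 kvcache_load_keys=cache_keys,
                                 kvcache_store_keys=None)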