Commit 3ef4aa9

Refine vllm_quickstart doc (#11199)
* refine doc
* refine

1 parent: 744042d

File tree: 3 files changed (+36, -3 lines)


docker/llm/serving/cpu/docker/benchmark_vllm_throughput.py (+29, -2)

@@ -76,9 +76,13 @@ def run_vllm(
     enable_prefix_caching: bool,
     gpu_memory_utilization: float = 0.9,
     load_in_low_bit: str = "sym_int4",
+    max_num_batched_tokens: int = 10450,
 ) -> float:
     from vllm import SamplingParams
     from ipex_llm.vllm.cpu.engine import IPEXLLMClass as LLM
+    warm_prompt = "hi " * (1024 - 1)
+    warm_requests = [(warm_prompt, 1024, 1024)
+                     for _ in range(8)]
     llm = LLM(model=model,
               tokenizer=tokenizer,
               quantization=quantization,
@@ -94,6 +98,22 @@ def run_vllm(
               enable_prefix_caching=enable_prefix_caching,
               load_in_low_bit=load_in_low_bit)

+    for prompt, _, output_len in warm_requests:
+        sampling_params = SamplingParams(
+            n=n,
+            temperature=0.0 if use_beam_search else 1.0,
+            top_p=1.0,
+            use_beam_search=use_beam_search,
+            ignore_eos=True,
+            max_tokens=output_len,
+        )
+        llm._add_request(
+            prompt=prompt,
+            prompt_token_ids=None,
+            sampling_params=sampling_params,
+        )
+    llm._run_engine(use_tqdm=True)
+
     # Add the requests to the engine.
     for prompt, _, output_len in requests:
         sampling_params = SamplingParams(
@@ -216,7 +236,9 @@ def main(args: argparse.Namespace):
             args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
             args.trust_remote_code, args.dtype, args.max_model_len,
             args.enforce_eager, args.kv_cache_dtype, args.device,
-            args.enable_prefix_caching, args.gpu_memory_utilization, args.load_in_low_bit)
+            args.enable_prefix_caching, args.gpu_memory_utilization, args.load_in_low_bit,
+            args.max_num_batched_tokens)
+
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -320,9 +342,14 @@ def main(args: argparse.Namespace):
     parser.add_argument(
         "--load-in-low-bit",
         type=str,
-        choices=["sym_int4", "fp8", "fp16"],
+        choices=["sym_int4", "fp6", "fp8", "fp16"],
         default="sym_int4",
         help="Low-bit format quantization with IPEX-LLM")
+    parser.add_argument('--max-num-batched-tokens',
+                        type=int,
+                        default=10450,
+                        help='maximum number of batched tokens per iteration')
+
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
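
The warm-up block added above pushes a fixed batch of dummy requests through the engine before the timed benchmark loop, so that one-time costs (model loading, kernel warm-up, cache allocation) are not counted in the measured throughput. The sketch below isolates that pattern; it is illustrative only and assumes ipex-llm's vLLM CPU engine plus a vLLM release that still exposes the internal `_add_request`/`_run_engine` helpers used by this script. The model name in the usage comment is a placeholder.

```python
# Minimal sketch of the warm-up pattern introduced by this commit (illustrative;
# assumes ipex-llm's vLLM CPU engine and a vLLM version exposing the internal
# _add_request/_run_engine helpers used by benchmark_vllm_throughput.py).
from vllm import SamplingParams
from ipex_llm.vllm.cpu.engine import IPEXLLMClass as LLM


def warm_up(llm: LLM, num_requests: int = 8, prompt_len: int = 1024, output_len: int = 1024) -> None:
    """Queue a fixed batch of dummy requests and drain them before timing starts."""
    warm_prompt = "hi " * (prompt_len - 1)  # same dummy prompt the benchmark uses
    for _ in range(num_requests):
        sampling_params = SamplingParams(
            n=1,
            temperature=1.0,
            top_p=1.0,
            ignore_eos=True,        # generate exactly output_len tokens per request
            max_tokens=output_len,
        )
        llm._add_request(
            prompt=warm_prompt,
            prompt_token_ids=None,
            sampling_params=sampling_params,
        )
    llm._run_engine(use_tqdm=True)  # run the queued warm-up requests to completion


# Hypothetical usage (model name and low-bit setting are placeholders):
# llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", load_in_low_bit="sym_int4")
# warm_up(llm)
# ...then submit and time the real benchmark requests...
```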

docker/llm/serving/xpu/docker/Dockerfile (+1, -1)

@@ -12,7 +12,7 @@ RUN cd /llm &&\
     # Install ipex-llm[serving] only will update ipex_llm source code without updating
     # bigdl-core-xe, which will lead to problems
     apt-get update && \
-    apt-get install -y libfabric-dev wrk && \
+    apt-get install -y libfabric-dev wrk libaio-dev && \
     pip install --pre --upgrade ipex-llm[xpu,serving] && \
     pip install transformers==4.37.0 gradio==4.19.2 && \
     # Install vLLM-v2 dependencies

docs/readthedocs/source/doc/LLM/Quickstart/vLLM_quickstart.md (+6)

@@ -134,6 +134,12 @@ You can tune the service using these four arguments:
 3. `--max-num-batched-token`: Maximum number of batched tokens per iteration.
 4. `--max-num-seq`: Maximum number of sequences per iteration. Default: 256

+For longer input prompts, we suggest using `--max-num-batched-token` to restrict the service. The reasoning is that peak GPU memory usage occurs while the first token is being generated, so `--max-num-batched-token` caps the input size processed during first-token generation.
+
+`--max-num-seqs` restricts generation for both the first token and the remaining tokens: it limits the maximum batch size to the value set by `--max-num-seqs`.
+
+When an out-of-memory error occurs, the most obvious fix is to reduce `gpu-memory-utilization`. Other ways to resolve the error are to set `--max-num-batched-token` if peak memory occurs while generating the first token, or `--max-num-seq` if peak memory occurs while generating the remaining tokens.
+
 If the service have been booted successfully, the console will display messages similar to the following:

 <a href="https://llm-assets.readthedocs.io/en/latest/_images/start-vllm-service.png" target="_blank">
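
To make the first-token reasoning in the added documentation concrete, here is a rough, hypothetical calculation. The 10450 default comes from the benchmark script change in this commit; the average prompt length is an assumed value, not something stated in the source.

```python
# Back-of-the-envelope illustration of how --max-num-batched-token bounds the
# prefill (first-token) batch, which is where peak GPU memory usage occurs.
max_num_batched_tokens = 10450  # default added to the benchmark script in this commit
avg_prompt_len = 1024           # assumed average input length (hypothetical)

# The scheduler processes at most max_num_batched_tokens tokens per iteration,
# so roughly this many prompts can be prefilled together in one step:
max_prompts_per_prefill = max_num_batched_tokens // avg_prompt_len
print(max_prompts_per_prefill)  # -> 10

# Lowering --max-num-batched-token shrinks the largest prefill batch (first-token
# peak memory); --max-num-seq instead caps how many sequences decode together.
```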
