[RFC]: Reimplement and separate beam search on top of vLLM core #8306

youkaichao opened this issue Sep 9, 2024 · 6 comments

Motivation.

A rework of #6226

After discussing further with the community, we find that the common use cases for beam search are:

  1. throughput oriented
  2. mainly offline batch inference
  3. using a single set of beam search parameters for all prompts in the batch

After discussing with many contributors, we find:

because beam search is a search algorithm, it conflicts with the rest of the sampling algorithms. As a result, many features in vLLM already assert directly that beam search is not used, e.g.:

assert len(input_seq_group_metadata.seq_data) == 1, (
    "Beam search "
    "not supported in speculative decoding")

assert len(seqs) == 1, (
    "Beam search not supported in multi-step decoding.")
seq = seqs[0]

Keeping beam search as-is in the codebase will not benefit current beam search users, since no optimizations will target better beam search performance. What's worse, very few developers understand beam search. Keeping it as-is will not only increase the number of beam search bugs as the codebase evolves, but also increase the maintenance cost for all contributors.

In search of a win-win solution, on behalf of the vLLM team, I propose to separate beam search out and reimplement it on top of the vLLM core code.

To be specific, we can:

  1. remove the beam search logic from the scheduler
  2. add an LLM.beam_search interface that calls the engine to generate one token with logprobs at each step, and maintain the beam-search logic only in the LLM.beam_search function (see the sketch after this list)
  3. add a beam search emulator on top of the commonly used OpenAI API server, which internally calls the generation endpoint to generate one step with logprobs, and maintain the beam-search logic only in the emulator
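
For illustration, here is a minimal sketch of what such an LLM.beam_search wrapper could look like, built only on the engine's existing one-token-per-step generation with logprobs. The helper name beam_search_sketch is hypothetical, the input/logprob types (TokensPrompt, CompletionOutput.logprobs) follow current vLLM versions and may differ, and EOS/length-penalty handling is omitted:

# Illustrative sketch only: beam search as a loop over one-token generation
# steps with logprobs, kept entirely outside the vLLM core/scheduler.
from dataclasses import dataclass

from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt


@dataclass
class Beam:
    token_ids: list[int]      # prompt tokens + tokens generated so far
    cum_logprob: float = 0.0  # cumulative log-probability of the beam


def beam_search_sketch(llm: LLM, prompt_token_ids: list[int],
                       beam_width: int = 4, max_tokens: int = 32) -> Beam:
    beams = [Beam(token_ids=list(prompt_token_ids))]
    # One token per step; request enough logprobs to expand every beam.
    step_params = SamplingParams(max_tokens=1, logprobs=beam_width,
                                 temperature=0.0)
    for _ in range(max_tokens):
        outputs = llm.generate(
            [TokensPrompt(prompt_token_ids=b.token_ids) for b in beams],
            step_params,
            use_tqdm=False,
        )
        candidates: list[Beam] = []
        for beam, out in zip(beams, outputs):
            # logprobs[0] maps candidate token id -> Logprob for this step.
            for token_id, lp in out.outputs[0].logprobs[0].items():
                candidates.append(Beam(beam.token_ids + [token_id],
                                       beam.cum_logprob + lp.logprob))
        # Keep only the top `beam_width` candidates (EOS handling omitted).
        beams = sorted(candidates, key=lambda b: b.cum_logprob,
                       reverse=True)[:beam_width]
    return beams[0]

Nothing in this loop touches the scheduler: from the core's perspective it only ever sees ordinary single-token generation requests.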

From the initial discussion, one concern is the efficiency of such an implementation, since from the vLLM core's perspective each request finishes and is resubmitted at every step. This should be addressable in two ways:

  1. turning on prefix caching reuses the computation from the previous step, so we don't need to recompute the KV cache of the prompt again and again (see the configuration example after this list)
  2. after separating beam search from the vLLM core, the two can be optimized individually; the simplified code will be much easier to optimize
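
For example, a minimal configuration sketch: enable_prefix_caching is the existing engine argument, and the model name is only a placeholder.

from vllm import LLM

# Reuse cached KV blocks for the shared prompt/prefix across beams and steps,
# instead of recomputing it at every beam-search step.
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)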

vLLM is a community project, and we would like not only to seek opinions from beam-search users, but also to seek their contributions. Your help is truly needed to shape the future of beam-search support in vLLM.

Proposed Change.

Summary of the change: implement beam search on top of the vLLM core and add wrappers for users; remove beam search from the vLLM core (scheduler).

Feedback Period.

1 week, from 9/9 to 9/15 (both inclusive)

CC List.

@hrsmanian @zhouyuan @lanking520 @nightflight-dk @HeegonJin @SemMulder @darabos @DhruvaBansal00 @tmostak @physicsrob @YooSungHyun @denadai2 @sjmielke @Reichenbachian @AaronFriel @hinnefe2 @mflaxman10
@WoosukKwon @zhuohan123 @simon-mo

Any Other Things.

No response

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
@youkaichao youkaichao added the RFC label Sep 9, 2024
@simon-mo simon-mo changed the title [RFC]: Reimplement and separate beam search on top of vllm core [RFC]: Reimplement and separate beam search on top of vLLM core Sep 9, 2024
@AaronFriel commented Sep 9, 2024

after separating beam search and the vllm core, they can be optimized individually. The simplified code will be much easier to optimize.

This is a good goal to work toward, as ensuring that API interfaces (OpenAI, beam search, or otherwise) can efficiently and reliably schedule new sequences benefits all consumers.

turning on prefix caching can reuse computation from the last step so that we don't need to recompute the kv cache of prompt again and again.

The flood of vLLM notifications is hard to keep up with, so I may be out of date. My understanding was that prefix caching was not precise and was block based, resulting in some amount of excess computation. Is there an issue to allow APIs to specify the "prefix length" that should be cached?

This new approach could see performance degrade when the sequence length approaches a multiple of the KV block length, if each arm of the beam search schedules a new sequence and must prefill O(kv_block_size) tokens plus decoding O(1) tokens. Ideally both would be O(1) with a hint to allow beam search to cache the entire prefix.
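
A small numeric illustration of the block-granularity effect described above (the helper is hypothetical and for illustration only):

def uncached_tail_tokens(seq_len: int, block_size: int) -> int:
    # With block-granular prefix caching, only complete blocks are reused,
    # so tokens past the last full block must be recomputed at prefill time.
    return seq_len % block_size

# e.g. a 1023-token sequence with 16-token KV blocks leaves 15 tokens
# to re-prefill for each beam arm at that step.
assert uncached_tail_tokens(1023, 16) == 15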

@youkaichao (Member, Author) commented Sep 9, 2024

My understanding was that prefix caching was not precise and was block based, resulting in some amount of excess computation

We can set the block size to 1 for the vLLM instance when using beam search; then we don't waste any computation.

@simon-mo (Collaborator) commented Sep 9, 2024

There are also alternative implementations of this that move the functionality into a special class of Worker or Executor, which can be configured when beam search is turned on for any engine that needs it.

@AaronFriel

@youkaichao How well does the KV cache handle a block size of 1, in terms of compute or memory overhead?

@youkaichao (Member, Author)

@AaronFriel I don't think setting the block size to 1 will affect performance much, but we need to test and measure the impact.

@youkaichao (Member, Author)

@simon-mo can you explain more? What special functions / interfaces would these new Worker or Executor classes need?

@youkaichao youkaichao pinned this issue Sep 12, 2024