
Question about cascade inference #789

Open
sleepwalker2017 opened this issue Feb 5, 2025 · 11 comments

@sleepwalker2017

https://flashinfer.ai/2024/02/02/cascade-inference.html

Hi, I noticed this blog was posted a year ago.

I wonder what setup the Evaluations section refers to.

Is it for the prefill stage, the decoding stage, or both?

yzh119 (Collaborator) commented Feb 5, 2025

It only refers to the decode attention kernel, not end-to-end results.

sleepwalker2017 (Author) commented Feb 6, 2025

> It only refers to the decode attention kernel, not end-to-end results.

Thank you.

Is this optimization mainly aimed at the decoding stage?

How much benefit does it bring to the prefill stage?

yzh119 (Collaborator) commented Feb 6, 2025

> Is this optimization mainly aimed at the decoding stage?

Yes, and it doesn't work for attention variants such as MLA (even for decoding), which exhibit very high operational intensity (128) in the decoding stage.

sleepwalker2017 (Author) commented Feb 10, 2025

> > Is this optimization mainly aimed at the decoding stage?
>
> Yes, and it doesn't work for attention variants such as MLA (even for decoding), which exhibit very high operational intensity (128) in the decoding stage.

Does it have any benefit for prefill MHA? It seems the effect would be small, because if the sequence is long enough the kernel is compute-bound?

sleepwalker2017 (Author) commented

> > Is this optimization mainly aimed at the decoding stage?
>
> Yes, and it doesn't work for attention variants such as MLA (even for decoding), which exhibit very high operational intensity (128) in the decoding stage.

What about the case where many requests sharing the same long prefix need to do prefill? It seems this could save a lot of computation, because the shared part only needs to be computed once.

I notice that sglang does not use this feature. In that case, sglang processes one prefill request first and then processes the rest.

yzh119 (Collaborator) commented Feb 10, 2025

> Does it have any benefit for prefill MHA? It seems the effect would be small, because if the sequence is long enough the kernel is compute-bound?

It depends on the query length: once the query length (which determines the operational intensity) reaches the ridge point of the GPU's roofline model (usually not large, ~300), the benefit is gone.
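
For intuition, a rough back-of-the-envelope version of that roofline argument (the hardware numbers below are my own illustrative assumptions, roughly H100-class, not measurements from the blog):

```python
# Rough roofline estimate (illustrative numbers, not measured).
# For fp16 MHA over a long KV, each KV position costs ~4*d FLOPs per query
# row (QK^T + PV) and ~4*d bytes of K/V reads, so the operational intensity
# is roughly the query length q, in FLOP/byte.

peak_flops = 989e12   # assumed dense fp16 tensor-core peak, FLOP/s
hbm_bw = 3.35e12      # assumed HBM bandwidth, byte/s

ridge = peak_flops / hbm_bw  # FLOP/byte at which the kernel turns compute-bound
print(f"ridge point ~ {ridge:.0f} FLOP/byte, i.e. query length ~ {ridge:.0f}")
# -> ~300: beyond this query length attention is compute-bound, so sharing the
#    KV reads across requests no longer speeds it up.
```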

yzh119 (Collaborator) commented Feb 10, 2025

> What about the case where many requests sharing the same long prefix need to do prefill? It seems this could save a lot of computation, because the shared part only needs to be computed once.
>
> I notice that sglang does not use this feature. In that case, sglang processes one prefill request first and then processes the rest.

I think prefix caching already does this optimization. Did you enable prefix caching in sglang?

sleepwalker2017 (Author) commented

> > What about the case where many requests sharing the same long prefix need to do prefill? It seems this could save a lot of computation, because the shared part only needs to be computed once.
> >
> > I notice that sglang does not use this feature. In that case, sglang processes one prefill request first and then processes the rest.
>
> I think prefix caching already does this optimization. Did you enable prefix caching in sglang?

Yes, sglang prefix caching is enabled.

I mean the prefill stage, when there are 4 sequences in the same batch that share the same long prefix.
In that case, sglang does prefill for one request and caches its KV cache, then does prefill for the other 3 requests, so the computation for the shared prefix is saved.

sleepwalker2017 (Author) commented

I have another question: the cascade API launches 5 kernels for a batch decoding step.

  • stage 1: 2 kernels for the unique parts: attention + merge
  • stage 2: 2 kernels for the shared part: attention + merge
  • stage 3: 1 kernel to merge the shared and unique results (a rough sketch of this merge is below)
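
For my own understanding, here is a rough PyTorch sketch of what I assume the merge step computes; the shapes, names, and the assumption that each attention kernel also returns a per-head log-sum-exp are mine, not the actual flashinfer code:

```python
import torch

def merge_states(o_a, lse_a, o_b, lse_b):
    """Merge two partial attention results computed over disjoint KV ranges.

    o_*:   [num_heads, head_dim]  partial attention outputs
    lse_*: [num_heads]            log-sum-exp of the attention scores over
                                  the KV range each partial result covers
    """
    m = torch.maximum(lse_a, lse_b)          # shift for numerical stability
    w_a = torch.exp(lse_a - m)               # un-normalized weight of part a
    w_b = torch.exp(lse_b - m)               # un-normalized weight of part b
    o = (w_a[:, None] * o_a + w_b[:, None] * o_b) / (w_a + w_b)[:, None]
    lse = m + torch.log(w_a + w_b)           # LSE over the union of both ranges
    return o, lse
```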

My question is: is this implementation efficient enough?

These two stages are executed sequentially. Will this lead to insufficient SM occupancy?

Also, can these two stages be fused and executed in one kernel?
Is the fusion not done because it is too complex to implement?
If it is too difficult to implement, could multiple streams be used to run the two stages at the same time?
They seem to be independent stages.

Sorry for so many questions. Looking forward to your reply. Thank you! @yzh119

yzh119 (Collaborator) commented Feb 13, 2025

> My question is: is this implementation efficient enough?

Apparently not, and we should optimize it. Actually, all of them can be fused into a single kernel. If you are interested, I can guide you to implement this (I don't have enough bandwidth at the moment).

sleepwalker2017 (Author) commented Feb 13, 2025

> > My question is: is this implementation efficient enough?
>
> Apparently not, and we should optimize it. Actually, all of them can be fused into a single kernel. If you are interested, I can guide you to implement this (I don't have enough bandwidth at the moment).

I'd be very glad to do that!

I just took a glance at the BatchPrefillWithPagedKVCacheKernel kernel; the code seems clear.
(I still need some time to fully understand it.)

The main idea seems to be as follows:

  1. calculate the required grid size and block size for both parts
  2. do kernel-level synchronization, which seems to use "cooperative groups"?
  3. merge the partial results after that

Is that the right way to do it?

You can give some instructions when you're free. Thank you.
