
[Attention] WIP MLA with chunked prefill #12639

Open
LucasWilkinson wants to merge 8 commits into main from lwilkinson/chunked-mla
Conversation

@LucasWilkinson (Contributor) commented Feb 1, 2025

Merge #12807 first

Note: this implementation uses a lot of runtime memory due to up-projecting the full context, so you may need to turn down --gpu-memory-utilization.
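
For example, a hypothetical invocation turning that knob down for an MLA model with chunked prefill enabled (the model name and the 0.85 value are placeholders, not values recommended by this PR):

# Hypothetical usage sketch, not part of this PR's changes.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # example MLA model
    enable_chunked_prefill=True,
    gpu_memory_utilization=0.85,           # turned down from the usual 0.9 default
)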

More benchmarking is needed to know if this should be on by default (due to the memory concerns I'm leaning towards no).

Shout out to @pathorn for the assistance with this PR.


github-actions bot commented Feb 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@LucasWilkinson changed the title from "[Attention] WIP MLA with chunked prefill" to "[WIP][Attention] WIP MLA with chunked prefill" on Feb 1, 2025

mergify bot commented Feb 6, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label on Feb 6, 2025
@LucasWilkinson force-pushed the lwilkinson/chunked-mla branch 2 times, most recently from 463e453 to c542cc4, on February 6, 2025 05:24
@mergify bot added the v1 label and removed the needs-rebase label on Feb 6, 2025
@LucasWilkinson changed the title from "[WIP][Attention] WIP MLA with chunked prefill" to "[Attention] WIP MLA with chunked prefill" on Feb 6, 2025
@LucasWilkinson marked this pull request as ready for review on February 6, 2025 05:49

mergify bot commented Feb 7, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label on Feb 7, 2025
LucasWilkinson and others added 7 commits February 7, 2025 16:42
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Patrick Horn <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
@mergify bot removed the needs-rebase label on Feb 7, 2025
Comment on lines +1128 to +1135
# Default to `gpu_memory_utilization` of 0.9 if not specified
gpu_memory_utilization = self.gpu_memory_utilization if \
    self.gpu_memory_utilization is not None else 0.9
# For models using MLA and chunked prefill, lower the default to 0.85
# to account for the extra memory required to up-project the MLA cache
if self.gpu_memory_utilization is None and \
        (self.enable_chunked_prefill and model_config.use_mla):
    gpu_memory_utilization = 0.85
Collaborator

My understanding of gpu_memory_utilization is that all of vLLM's memory usage, including weights, activations, KV cache, and any extra space needed for MLA, should fit within this budget.

If a user is explicitly specifying a gpu_memory_utilization, they wouldn't want an MLA model to exceed that limit. I think a better way to handle the extra memory utilization due to MLA could be in the worker's determine_num_available_blocks method.
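
A minimal sketch of that alternative, with illustrative names only (this is not vLLM's actual determine_num_available_blocks signature): subtract a reserved MLA up-projection headroom before sizing the KV cache, so everything still fits inside the user-specified budget.

def num_gpu_blocks_with_mla_headroom(total_gpu_memory: int,
                                     gpu_memory_utilization: float,
                                     profiled_peak_memory: int,
                                     cache_block_size: int,
                                     mla_upproj_headroom: int) -> int:
    # Keep weights, activations, KV cache, and the MLA up-projection buffers
    # inside the user-specified budget instead of silently lowering the default.
    budget = int(total_gpu_memory * gpu_memory_utilization)
    free_for_kv_cache = budget - profiled_peak_memory - mla_upproj_headroom
    return max(free_for_kv_cache // cache_block_size, 0)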

Do you know why the profile_run doesn't already account for this memory footprint?

Contributor Author

Do you know why the profile_run doesn't already account for this memory footprint?

Because it depends on the context length in the cache for each sequence in the request, and we don't profile with max-context-length requests.

I'm trying to cap the amount of memory used by chunking contexts longer than a certain length.
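
A minimal sketch of that chunking idea, under heavy assumptions (a compressed latent KV cache of shape [context_len, kv_lora_rank], illustrative up-projection weights w_kc / w_vc, and no rope handling; this is not this PR's actual kernel code): up-project the cached context one chunk at a time and merge the partial attention outputs via log-sum-exp, so the temporary up-projected buffers never hold more than chunk_size tokens.

import torch

def chunked_context_attention(q, kv_c_cache, w_kc, w_vc, chunk_size=4096, scale=1.0):
    """Illustrative only. q: [num_q, heads, dim]; kv_c_cache: [ctx_len, rank];
    w_kc / w_vc: [rank, heads, dim] up-projection weights."""
    out, lse = None, None
    for start in range(0, kv_c_cache.shape[0], chunk_size):
        kv_c = kv_c_cache[start:start + chunk_size]
        # Up-project only this chunk, capping temporary memory at chunk_size tokens.
        k = torch.einsum("sr,rhd->shd", kv_c, w_kc)
        v = torch.einsum("sr,rhd->shd", kv_c, w_vc)
        scores = torch.einsum("qhd,shd->qhs", q, k) * scale
        chunk_lse = torch.logsumexp(scores, dim=-1)                    # [num_q, heads]
        chunk_out = torch.einsum("qhs,shd->qhd", scores.softmax(-1), v)
        if out is None:
            out, lse = chunk_out, chunk_lse
        else:
            # Merge partial softmax results across chunks via log-sum-exp.
            new_lse = torch.logaddexp(lse, chunk_lse)
            out = (out * (lse - new_lse).exp().unsqueeze(-1)
                   + chunk_out * (chunk_lse - new_lse).exp().unsqueeze(-1))
            lse = new_lse
    return out

With this structure the peak temporary memory scales with chunk_size rather than with the full context length.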

Comment on lines +22 to 24
logger = logging.getLogger(__name__)

logger = logging.getLogger(__name__)
Collaborator

Fix up the duplicate logger = logging.getLogger(__name__)

Comment on lines +16 to +24

namespace cuda_utils {

template <typename T>
HOST_DEVICE_INLINE constexpr std::enable_if_t<std::is_integral_v<T>, T>
ceil_div(T a, T b) {
return (a + b - 1) / b;
}

Collaborator

This is already in csrc/core/math.hpp without the HOST_DEVICE_INLINE. Does it make sense for it to be one function?

Signed-off-by: Lucas Wilkinson <[email protected]>