
[Hardware][TPU] Multi-LoRA implementation for the TPU backend #12623

Open · wants to merge 40 commits into main
Conversation

@Akshat-Tripathi (Contributor) commented Jan 31, 2025:

This PR adds a Multi-LoRA implementation that works on the TPU backend, extending the work done in #11100.

Currently this uses PyTorch operations for the Punica kernels, but I am going to put up a PR with Pallas kernels soon.
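For context, multi-LoRA serving in vLLM is driven through the existing LoRA API. A minimal usage sketch follows; the model name, adapter name, and adapter path are placeholders, not part of this PR:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# enable_lora switches on the Punica/LoRA path; max_loras bounds how many
# adapters can be active in a single batch.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True, max_loras=4)

outputs = llm.generate(
    ["Write a SQL query that counts users."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # lora_int_id must be a unique positive integer per adapter.
    lora_request=LoRARequest("sql-adapter", 1, "/path/to/sql-lora"),
)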

mosalov and others added 30 commits January 24, 2025 14:56, all signed off by Oleg Mosalov and Akshat Tripathi, including:

…ter loading a LoRA adapter.
…` to be called with infinities
… the adapter and its weights are loaded.

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, a small and essential subset of CI tests that quickly catches errors. You can run other CI tests on top of it by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

Comment on lines 1081 to 1087
if current_platform.is_tpu():
    # Because nan_to_num_ doesn't work with actual -inf values on TPU
    neg_inf = torch.finfo(lora_logits.dtype).min
    pos_inf = torch.finfo(lora_logits.dtype).max
else:
    neg_inf = float("-inf")
    pos_inf = float("inf")
@liangfu (Contributor) commented Jan 31, 2025:

These if-else conditions will make vLLM hard to maintain.

Should we file an issue with torch-xla, or abstract this as part of a utility function?

Collaborator:

Abstracting this as part of a utility function sounds good.

@Akshat-Tripathi (Author):

Yeah, that sounds good; I can abstract it away. It was only a problem for that nan_to_num() function, though; -inf works properly elsewhere.
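One possible shape for that utility, sketched as a hypothetical helper (the name get_inf_values and its placement are illustrative, not code from this PR):

import torch

from vllm.platforms import current_platform


def get_inf_values(dtype: torch.dtype) -> tuple[float, float]:
    """Return (neg_inf, pos_inf) sentinels usable on the current platform.

    On TPU, nan_to_num_ cannot handle true infinities, so the finite
    extremes of the dtype stand in for them (hypothetical helper, per
    the discussion above).
    """
    if current_platform.is_tpu():
        return torch.finfo(dtype).min, torch.finfo(dtype).max
    return float("-inf"), float("inf")

The call site above would then collapse to a single line: neg_inf, pos_inf = get_inf_values(lora_logits.dtype).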

Contributor:

Abstracting it away is fine as a short-term solution.

It would be better if we could also create an issue in the torch-xla repo as a longer-term solution.

@Akshat-Tripathi (Author):

Ok, I've made the issue here: pytorch/xla#8674

@@ -0,0 +1,58 @@
import torch

from ..torch_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
Collaborator:

It seems the TPU ops are still using PyTorch operators; is it necessary to add the ops below?

@Akshat-Tripathi (Author):

The sgmv ops are slightly different here because I'm using repeat_interleave with a static size rather than a dynamic tensor. This cuts the compile time quite a bit, because torch_xla can't lower the dynamic version properly.
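For illustration only (toy tensors, not this PR's kernel code), the static/dynamic distinction looks like this:

import torch

x = torch.randn(4, 8)

# Dynamic: per-row repeat counts live in a tensor, so the output shape
# depends on runtime data, which torch_xla cannot lower to a static graph.
repeats = torch.tensor([3, 1, 2, 2])
y_dynamic = torch.repeat_interleave(x, repeats, dim=0)

# Static: a plain Python int fixes the output shape at trace time, so XLA
# compiles the graph once instead of recompiling as the data changes.
y_static = torch.repeat_interleave(x, 3, dim=0)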


# Platforms that are compatible with the PyTorch-native implementation can
# inherit from this class.
class PunicaWrapperTPU(PunicaWrapperBase):
Collaborator:

Why not directly inherit from PunicaWrapperCPU?

@Akshat-Tripathi (Author):

I thought about it, but this code is going to change very soon as I add in the Pallas kernels.

@Akshat-Tripathi (Author) commented:
It looks like the Async Engine, Inputs, Utils, Worker test is failing on multimodal inputs, which are WIP right now.
The TPU test seems to be failing on non-LoRA code. Do these tests pass on main? I'm wondering whether they're linked to this PR or something else.

@miladm (Collaborator) commented Feb 7, 2025:

cc @lsy323 to take a pass

…because xla doesn't allow partial updates

@miladm requested review from lsy323 and removed the review request for liangfu, February 7, 2025 19:02