[Hardware][TPU] Multi-LoRA implementation for the TPU backend #12623
base: main
Conversation
…ter loading a LoRA adapter. Signed-off-by: Oleg Mosalov <[email protected]>
…` to be called with infinities Signed-off-by: Akshat Tripathi <[email protected]>
… the adapter and its weights are loaded. Signed-off-by: Oleg Mosalov <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
vllm/lora/layers.py:

    if current_platform.is_tpu():
        # Because nan_to_num_ doesn't work with actual -inf values on TPU
        neg_inf = torch.finfo(lora_logits.dtype).min
        pos_inf = torch.finfo(lora_logits.dtype).max
    else:
        neg_inf = float("-inf")
        pos_inf = float("inf")
These if-else conditions will make vLLM hard to maintain. Could we file an issue with torch-xla, or abstract this as part of a utility function?
"Abstract this as part of a utility function" sounds good.
Yeah, that sounds good; I can abstract it away. It was only a problem for that nan_to_num() function, though; -inf works properly elsewhere.
Abstracting it away as a short-term solution is fine. It would be better if we could also create an issue in the torch-xla repo as a longer-term solution.
Ok, I've made the issue here: pytorch/xla#8674
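For illustration, a minimal sketch of the kind of utility helper discussed above; it assumes vLLM's current_platform API, and the function name and placement are hypothetical, not taken from the PR:

```python
import torch

from vllm.platforms import current_platform


def safe_inf_values(dtype: torch.dtype) -> tuple[float, float]:
    """Return (neg_inf, pos_inf) sentinels that are safe to pass to
    nan_to_num_ on the current platform.

    On TPU (torch_xla), nan_to_num_ does not handle literal +/-inf
    replacement values correctly, so the finite extremes of the dtype
    are used instead; other platforms keep the usual infinities.
    """
    if current_platform.is_tpu():
        finfo = torch.finfo(dtype)
        return finfo.min, finfo.max
    return float("-inf"), float("inf")
```

The call site in layers.py would then reduce to something like `neg_inf, pos_inf = safe_inf_values(lora_logits.dtype)`.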
vllm/lora/ops/xla_ops/lora_ops.py:

    import torch

    from ..torch_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
It seems the TPU ops are still using PyTorch operators; is it necessary to add the ops below?
The sgmv ops are slightly different here because I'm using repeat_interleave with a static size rather than a dynamic tensor, which reduces the compile time quite a bit, since torch_xla can't lower the dynamic version properly.
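As a standalone illustration (not code from the PR): torch.repeat_interleave accepts either a Python int or a tensor of repeats, and only the int form gives torch_xla a statically known output shape.

```python
import torch

indices = torch.tensor([0, 1, 2])

# Static repeats: a plain Python int, so the output length (3 * 4) is
# known at trace time and torch_xla can compile a fixed-shape graph.
static = torch.repeat_interleave(indices, 4)

# Dynamic repeats: a tensor, so the output length depends on runtime
# values, which torch_xla cannot lower to a static graph efficiently.
repeats = torch.tensor([4, 4, 4])
dynamic = torch.repeat_interleave(indices, repeats)

assert torch.equal(static, dynamic)  # same values, different compile behaviour
```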
    # The platforms that are compatible with the PyTorch-native implementation can
    # inherit this class
    class PunicaWrapperTPU(PunicaWrapperBase):
Why not directly inherit from PunicaWrapperCPU?
I thought about it, but this code is going to change very soon as I add in the Pallas kernels.
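For readers following along, a rough sketch of the inheritance question being discussed; the docstrings and class bodies here are illustrative assumptions, not vLLM's actual code:

```python
# Illustrative skeleton of the relationship only; real bodies elided.
class PunicaWrapperBase:
    """Shared bookkeeping that maps tokens to their LoRA adapters."""


class PunicaWrapperCPU(PunicaWrapperBase):
    """Implements the LoRA ops with PyTorch-native kernels."""


class PunicaWrapperTPU(PunicaWrapperBase):
    """Inherits the base directly (rather than PunicaWrapperCPU) so the
    interim PyTorch ops can later be swapped for Pallas kernels without
    carrying over the CPU implementation."""
```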
cc @lsy323 to take a pass
…because xla doesn't allow partial updates Signed-off-by: Akshat Tripathi <[email protected]>
This PR adds a Multi-LoRA implementation that works on the TPU backend, extending the work done in #11100.
Currently this uses PyTorch operations for the Punica kernels, but I am going to put up a PR with Pallas kernels soon.
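For context, a rough sketch of what "PyTorch operations for the Punica kernels" can look like; the tensor layouts and function name below are assumptions for illustration, not the PR's actual implementation:

```python
import torch


def bgmv_shrink_reference(
    x: torch.Tensor,        # (num_tokens, hidden_dim) input activations
    lora_a: torch.Tensor,   # (num_loras, rank, hidden_dim) stacked LoRA A matrices
    indices: torch.Tensor,  # (num_tokens,) adapter index assigned to each token
    scale: float = 1.0,
) -> torch.Tensor:
    """Batched 'shrink' step of a LoRA forward pass in plain PyTorch."""
    selected = lora_a[indices]                     # (num_tokens, rank, hidden_dim)
    out = torch.einsum("td,trd->tr", x, selected)  # (num_tokens, rank)
    return out * scale
```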