[Hardware][TPU] Multi-LoRA implementation for the TPU backend #12623

Open

wants to merge 40 commits into base: main

Commits (40, showing changes from all commits)
2fc505c
A simple test to compare named_modules for a base model before and af…
mosalov Nov 20, 2024
e351842
Added non-triton SGMV and BGMV ops (not kernels yet)
Akshat-Tripathi Nov 20, 2024
4628132
Made a copy of the layer tests for the TPU. TODO: DRY it out
Akshat-Tripathi Nov 20, 2024
2054570
Removed extra print
Akshat-Tripathi Nov 21, 2024
4ca792a
Made some minor shape-based fixes to the kernels
Akshat-Tripathi Nov 22, 2024
14a8f7d
Added basic lora execution code
Akshat-Tripathi Nov 22, 2024
cb94436
Replaced einsums with matmuls+reshaping for better xla compilation
Akshat-Tripathi Nov 25, 2024
da1fff9
Replaced inf/-inf with max/min since XLA doesn't allow `nan_to_num_()…
Akshat-Tripathi Nov 25, 2024
c668c97
Added lora config to `_dummy_run()`
Akshat-Tripathi Nov 25, 2024
e165aee
Changed torch._dynamo config
Akshat-Tripathi Nov 25, 2024
d031d89
Quick patch to allow non lora code to run
Akshat-Tripathi Nov 25, 2024
6a15233
Updated the test for loading a LoRA adapter, now it better shows when…
mosalov Nov 22, 2024
72fe7e0
Better wording.
mosalov Nov 22, 2024
a554f9e
Added arg_parser to test_load_lora_adapter.py.
mosalov Dec 16, 2024
dcbc952
Minor fixes
Akshat-Tripathi Jan 17, 2025
58d571e
Replaced einsums with matmuls to allow xla compilation
Akshat-Tripathi Jan 22, 2025
17247dd
Removed xla ops for torch ops
Akshat-Tripathi Jan 23, 2025
9341f0f
Removed old debug log points
Akshat-Tripathi Jan 23, 2025
825c965
Fixed bgmv/sgmv shape error
Akshat-Tripathi Jan 23, 2025
d7899ce
Fixed lora batching crash in warmup
Akshat-Tripathi Jan 23, 2025
3448072
Fixed shape issue in add_lora_linear()
Akshat-Tripathi Jan 23, 2025
e27e6f6
Fixed dynamic lora tensor shapes
Akshat-Tripathi Jan 23, 2025
f14fc34
Fixed lora_input preparation for actual execution
Akshat-Tripathi Jan 23, 2025
194f0c9
Fixed wrong model bug
Akshat-Tripathi Jan 24, 2025
acaab93
Moved if statements outside of for loops in PunicaWrapperTPU
Akshat-Tripathi Jan 24, 2025
a39bbff
Added early exits to PunicaWrapperTPU lora functions
Akshat-Tripathi Jan 28, 2025
13d46cf
Added torch ops for tpu (Static prefill sizes)
Akshat-Tripathi Jan 30, 2025
d44f4c1
XLA bgmv operations are now imported from the default torch_ops
Akshat-Tripathi Jan 30, 2025
5d669c2
Removed TODOs
Akshat-Tripathi Jan 31, 2025
81ed389
Merge branch 'main' into multi_lora_tpu
Akshat-Tripathi Jan 31, 2025
7f8694d
Removed old code
Akshat-Tripathi Jan 31, 2025
7d9723a
Linting
Akshat-Tripathi Jan 31, 2025
b7cadb9
Fixed import error
Akshat-Tripathi Feb 3, 2025
effc49d
lint
Akshat-Tripathi Feb 4, 2025
3452eba
Abstracted out infinity values
Akshat-Tripathi Feb 4, 2025
aafc92f
Merge branch 'main' into multi_lora_tpu
Akshat-Tripathi Feb 4, 2025
7cfcdb0
Moved and modified bgmv ops from the cpu backend to the tpu backend, …
Akshat-Tripathi Feb 7, 2025
a8df33f
Removed total_size for linting
Akshat-Tripathi Feb 7, 2025
bfb3770
Reverted changes to torch_ops
Akshat-Tripathi Feb 7, 2025
ce49855
Lint
Akshat-Tripathi Feb 7, 2025
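
Two of the commits above (cb94436 and 58d571e) replace torch.einsum calls with batched matmuls plus reshapes so the graphs lower cleanly through XLA. A minimal sketch of the equivalence, using illustrative shapes that are assumptions for this example rather than values taken from the PR:

import torch

# Illustrative shapes (assumed for this sketch): 8 tokens, LoRA rank 16,
# hidden size 4096.
batch_size, rank, hidden_size = 8, 16, 4096
inputs = torch.randn(batch_size, rank)
selected_loras = torch.randn(batch_size, hidden_size, rank)

# einsum form: contract each token's input with its selected LoRA matrix.
ref = torch.einsum("bi, boi -> bo", inputs, selected_loras)

# XLA-friendlier form used by the new ops: batched matmul plus reshapes.
out = (selected_loras @ inputs.reshape(batch_size, rank, 1)).reshape(
    batch_size, hidden_size)

assert torch.allclose(ref, out, atol=1e-5)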
tests/lora/conftest.py (2 changes: 1 addition & 1 deletion)

@@ -70,7 +70,7 @@ def dist_init():
temp_file = tempfile.mkstemp()[1]

backend = "nccl"
if current_platform.is_cpu():
if current_platform.is_cpu() or current_platform.is_tpu():
backend = "gloo"

init_distributed_environment(world_size=1,
vllm/lora/layers.py (13 changes: 9 additions & 4 deletions)

@@ -1076,15 +1076,20 @@ def _get_logits(
torch.matmul(self.embeddings_tensors,
hidden_states.T,
out=lora_logits[:-1])
lora_logits[-1] = float("-inf")

neg_inf, pos_inf = current_platform.get_infinity_values(
lora_logits.dtype)

lora_logits[-1] = neg_inf
lora_logits = lora_logits.mT
indices_padded = self.punica_wrapper.sampler_indices_padded

lora_logits = (lora_logits.reshape(
lora_logits.shape[0] * lora_logits.shape[1],
lora_logits.shape[2],
).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
posinf=float("inf"),
neginf=float("-inf")))
).index_select(0, indices_padded).nan_to_num_(nan=neg_inf,
posinf=pos_inf,
neginf=neg_inf))

# HPU needs special handling to prune out dummy samples.
if current_platform.is_hpu():
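
The _get_logits change above routes the +/-inf sentinels through current_platform.get_infinity_values(), following the commits noting that XLA does not allow nan_to_num_() with true infinities (da1fff9, 3452eba). A rough sketch of what such a hook could return per platform; this is an assumption for illustration, not necessarily the PR's actual implementation:

import torch

def get_infinity_values_generic(dtype: torch.dtype) -> tuple[float, float]:
    # Hypothetical default: non-XLA backends keep true infinities.
    return float("-inf"), float("inf")

def get_infinity_values_tpu(dtype: torch.dtype) -> tuple[float, float]:
    # Assumed TPU behaviour: substitute the dtype's finite extremes so that
    # masking and nan_to_num_() still compile under XLA.
    return torch.finfo(dtype).min, torch.finfo(dtype).max

neg_inf, pos_inf = get_infinity_values_tpu(torch.bfloat16)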
vllm/lora/ops/xla_ops/__init__.py (15 changes: 15 additions & 0 deletions)

@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0

from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand # noqa: F401
from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink,
sgmv_expand, sgmv_expand_slice,
sgmv_shrink)

__all__ = [
"bgmv_expand",
"bgmv_expand_slice",
"bgmv_shrink",
"sgmv_expand",
"sgmv_expand_slice",
"sgmv_shrink",
]
vllm/lora/ops/xla_ops/lora_ops.py (142 changes: 142 additions & 0 deletions)

@@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0

import torch


def sgmv_expand(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False):
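    # SGMV is emulated here: repeat each request's LoRA index once per input
    # row, then delegate to the batched (BGMV) op.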
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
inputs.size(0))

bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices,
add_inputs)


def bgmv_expand(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True):
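    # Gather each token's LoRA B matrix and apply it with a single batched
    # matmul; the result is zero-padded up to the width of output_tensor.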
selected_loras = lora_b_weights[lora_indices_tensor].to(
dtype=output_tensor.dtype)
if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(dim=1)
inputs = inputs.to(dtype=output_tensor.dtype)
# outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
batch_size, output_size, input_size = selected_loras.shape
outputs = (selected_loras @ inputs.reshape(
(batch_size, input_size, 1))).reshape((batch_size, output_size))

limit = output_tensor.shape[0]
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
limit = 1

outputs = torch.cat(
(outputs,
torch.zeros((batch_size, output_tensor.shape[1] - outputs.shape[1]),
device=outputs.device)),
dim=1)

if add_inputs:
output_tensor += outputs[:limit, :]
else:
        # Write back in place so the caller's buffer sees the result.
        output_tensor[:, :] = outputs[:limit, :]


def sgmv_shrink(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
):
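    # As in sgmv_expand: broadcast the LoRA indices across input rows and
    # delegate to the batched shrink op.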
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
inputs.size(0))

bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices,
scaling)


def bgmv_shrink(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0):
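    # Gather each token's LoRA down-projection matrix, apply it with a
    # batched matmul, and scale the result by `scaling`.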
selected_loras = lora_b_weights[lora_indices_tensor].to(
dtype=output_tensor.dtype)
if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(dim=1)
inputs = inputs.to(dtype=output_tensor.dtype)
# outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
batch_size, output_size, input_size = selected_loras.shape
outputs = (selected_loras @ inputs.reshape(
(batch_size, input_size, 1))).reshape((batch_size, output_size))

    # Copy the scaled projection into the caller's buffer in place.
    output_tensor[:, :outputs.shape[1]] = scaling * outputs


def sgmv_expand_slice(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False):
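    # Same emulation pattern as sgmv_expand, delegating to the batched
    # expand-slice op.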
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
inputs.size(0))

bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices,
slice_offset, slice_size, add_inputs)


def bgmv_expand_slice(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True):
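    # Like bgmv_expand, but the result lands in the column range
    # [slice_offset, slice_offset + slice_size) of output_tensor; zero-padding
    # on both sides avoids dynamic slicing of the output.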
selected_loras = lora_b_weights[lora_indices_tensor].to(
dtype=output_tensor.dtype)

inputs = inputs.to(dtype=output_tensor.dtype)

if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(dim=1)

batch_size, output_size, input_size = selected_loras.shape

outputs = (selected_loras @ inputs.reshape(
(batch_size, input_size, 1))).reshape((batch_size, output_size))

outputs = torch.cat((
torch.zeros((batch_size, slice_offset), device=outputs.device),
outputs,
torch.zeros(
(batch_size, output_tensor.shape[1] - (slice_offset + slice_size)),
device=outputs.device),
),
dim=1)

if add_inputs:
output_tensor += outputs
else:
        # Overwrite only the target slice of the caller's buffer in place.
        output_tensor[:, slice_offset:slice_offset + slice_size] = (
            outputs[:, slice_offset:slice_offset + slice_size])
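
To show how the new ops compose, here is a hedged usage sketch of a per-token LoRA update, y += B[idx] @ (A[idx] @ x) * scaling, chaining bgmv_shrink and bgmv_expand. The adapter count, rank, and hidden size are assumptions for illustration; in the PR these ops are intended to be driven through PunicaWrapperTPU (see the commits above) rather than called directly:

import torch
from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_shrink

# Assumed sizes for illustration: 4 tokens, 2 adapters, rank 8, hidden 512.
num_tokens, num_loras, rank, hidden = 4, 2, 8, 512
hidden_states = torch.randn(num_tokens, hidden)
lora_a = torch.randn(num_loras, rank, hidden)    # shrink (down-projection) weights
lora_b = torch.randn(num_loras, hidden, rank)    # expand (up-projection) weights
token_lora_indices = torch.tensor([0, 0, 1, 1])  # adapter id per token

base_output = torch.zeros(num_tokens, hidden)

# Shrink each token to the LoRA rank, then expand and accumulate into the
# base model output.
shrunk = torch.zeros(num_tokens, rank)
bgmv_shrink(hidden_states, lora_a, shrunk, token_lora_indices, scaling=0.5)
bgmv_expand(shrunk, lora_b, base_output, token_lora_indices, add_inputs=True)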