[Attention] MLA with chunked prefill #12639

Merged: 44 commits, Feb 21, 2025
Commits
4267344  chunked mla (LucasWilkinson, Feb 1, 2025)
2821aed  add gather cache kernel (LucasWilkinson, Feb 5, 2025)
dc00371  wip (LucasWilkinson, Feb 6, 2025)
f50719b  wip running (LucasWilkinson, Feb 6, 2025)
3d2e770  more cleanup (LucasWilkinson, Feb 6, 2025)
ea19198  better defaults (LucasWilkinson, Feb 6, 2025)
d116752  increase MLA gpu_memory_utilization default (LucasWilkinson, Feb 6, 2025)
c3ad988  wip (LucasWilkinson, Feb 7, 2025)
ca1b07d  wip fix tensor on wrong device (LucasWilkinson, Feb 11, 2025)
396f4db  wip (LucasWilkinson, Feb 12, 2025)
b4a900e  finally :/ (LucasWilkinson, Feb 12, 2025)
04c6042  working! (LucasWilkinson, Feb 13, 2025)
d0925bb  clean-up (LucasWilkinson, Feb 13, 2025)
0e173be  delete files (LucasWilkinson, Feb 13, 2025)
25bd9cb  cleanup (LucasWilkinson, Feb 13, 2025)
b73fb74  cleanup (LucasWilkinson, Feb 13, 2025)
829ce2b  relocate merge_attn_states (LucasWilkinson, Feb 13, 2025)
dee34f7  add comments (LucasWilkinson, Feb 13, 2025)
b28f99a  comment fixes (LucasWilkinson, Feb 13, 2025)
54ae713  minor fixes (LucasWilkinson, Feb 13, 2025)
3db8ab6  remove no-longer necessary changes (LucasWilkinson, Feb 13, 2025)
04644e3  clean-up (LucasWilkinson, Feb 13, 2025)
4398787  fix tp (LucasWilkinson, Feb 13, 2025)
d73f9ff  review comments (LucasWilkinson, Feb 14, 2025)
f4da0b6  minor fix (LucasWilkinson, Feb 14, 2025)
50a53aa  fix wrong device, increase workspace, enable cuda-graphs (LucasWilkinson, Feb 14, 2025)
e0a758e  minor changes (LucasWilkinson, Feb 14, 2025)
3c800bb  add comment (simon-mo, Feb 14, 2025)
a79ee4c  fix assert (LucasWilkinson, Feb 15, 2025)
1c59597  extra workspace allocation during profile run (LucasWilkinson, Feb 15, 2025)
1137f76  rename (LucasWilkinson, Feb 15, 2025)
920ecc6  fix illegal memory access (LucasWilkinson, Feb 17, 2025)
0547a94  Merge remote-tracking branch 'origin/main' into lwilkinson/chunked-mla (LucasWilkinson, Feb 18, 2025)
b665575  format (LucasWilkinson, Feb 18, 2025)
3a0ae51  format (LucasWilkinson, Feb 18, 2025)
28464b5  mypy pass (LucasWilkinson, Feb 18, 2025)
609267b  Merge branch 'main' into lwilkinson/chunked-mla (tlrmchlsmth, Feb 19, 2025)
dfb3ada  fix basic model test (LucasWilkinson, Feb 19, 2025)
9ca182b  attempt to fix AMD build (LucasWilkinson, Feb 19, 2025)
d325935  attempt 2 fix amd build (LucasWilkinson, Feb 19, 2025)
6394a8a  Merge remote-tracking branch 'origin/main' into lwilkinson/chunked-mla (LucasWilkinson, Feb 20, 2025)
f17599e  Merge remote-tracking branch 'origin/main' into lwilkinson/chunked-mla (LucasWilkinson, Feb 21, 2025)
c5fbdaa  Merge remote-tracking branch 'origin/main' into lwilkinson/chunked-mla (LucasWilkinson, Feb 21, 2025)
10c4e54  Merge remote-tracking branch 'origin/main' into lwilkinson/chunked-mla (LucasWilkinson, Feb 21, 2025)
7 changes: 7 additions & 0 deletions csrc/cache.h
@@ -39,3 +39,10 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);

void gather_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
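
For orientation, here is a minimal usage sketch of the op declared above, written against the Python wrapper (ops.gather_cache) that this PR adds in vllm/_custom_ops.py. The shapes and sizes are illustrative assumptions only, and the sketch presumes a CUDA build of vLLM that includes this change.

import torch
from vllm import _custom_ops as ops  # assumes a build that includes gather_cache

num_blocks, block_size, entry_size = 16, 4, 576  # e.g. 512 kv_lora_rank + 64 qk_rope_head_dim
src_cache = torch.randn(num_blocks, block_size, entry_size, device="cuda")

# Two sequences of 5 and 7 tokens -> 12 packed destination rows.
cu_seq_lens = torch.tensor([0, 5, 12], dtype=torch.int32, device="cuda")
block_table = torch.randint(0, num_blocks, (2, num_blocks),
                            dtype=torch.int32, device="cuda")
dst = torch.empty(12, entry_size, dtype=src_cache.dtype, device="cuda")

# Copies the first 5 (resp. 7) cached entries of each sequence, following the
# block table, into dst rows [0:5) and [5:12).
ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size=2)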
159 changes: 159 additions & 0 deletions csrc/cache_kernels.cu
@@ -2,6 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_utils.h"
#include "cuda_compat.h"
#include "dispatch_utils.h"

@@ -570,3 +571,161 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
}
}

namespace vllm {

// grid is launched with dimensions (batch, num_splits)
template <typename scalar_t>
__global__ void gather_cache(
const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
// ENTRIES...]
scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
const int32_t block_size, const int32_t entry_size,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride,
const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per
// batch

const int64_t bid = blockIdx.x; // Batch ID
const int32_t num_splits = gridDim.y;
const int32_t split = blockIdx.y;
const int32_t seq_start = cu_seq_lens[bid];
const int32_t seq_end = cu_seq_lens[bid + 1];
const int32_t seq_len = seq_end - seq_start;
const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size);
const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits);

const int32_t split_start = split * split_blocks;
const int32_t split_end = min((split + 1) * split_blocks, tot_blocks);

const bool is_active_split = (split_start < tot_blocks);
const bool is_last_split = (split_end == tot_blocks);

if (!is_active_split) return;

int32_t full_blocks_end = split_end;
int32_t partial_block_size = 0;

// Adjust the pointer for the block_table for this batch.
// If seq_starts is provided, compute an offset based on (seq_starts[bid] /
// page_size)
const int32_t batch_offset = bid * block_table_stride;
int32_t offset = 0;
if (seq_starts != nullptr) {
offset = seq_starts[bid] / block_size;
}
const int32_t* batch_block_table = block_table + batch_offset + offset;

// Adjust dst pointer based on the cumulative sequence lengths.
dst += seq_start * dst_entry_stride;

if (is_last_split) {
partial_block_size = seq_len % block_size;
if (partial_block_size) full_blocks_end -= 1;
}

auto copy_entry = [&](const scalar_t* __restrict__ _src,
scalar_t* __restrict__ _dst) {
for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
_dst[i] = _src[i];
};

for (int pid = split_start; pid < full_blocks_end; ++pid) {
auto block_id = batch_block_table[pid];
auto block_start_ptr = src_cache + block_id * cache_block_stride;
auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
for (int eid = 0; eid < block_size; ++eid) {
copy_entry(block_start_ptr + eid * cache_entry_stride,
block_dst_ptr + eid * dst_entry_stride);
}
}

if (partial_block_size) {
auto block_id = batch_block_table[full_blocks_end];
auto block_start_ptr = src_cache + block_id * cache_block_stride;
auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride;
for (int eid = 0; eid < partial_block_size; ++eid) {
copy_entry(block_start_ptr + eid * cache_entry_stride,
block_dst_ptr + eid * dst_entry_stride);
}
}
}

} // namespace vllm

// Macro to dispatch the kernel based on the data type.
#define CALL_GATHER_CACHE(CPY_DTYPE) \
vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()), \
reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
block_size, entry_size, block_table_stride, cache_block_stride, \
cache_entry_stride, dst_entry_stride, seq_starts_ptr);

// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
// - Optionally, seq_starts (if provided) offsets the starting block index by
// (seq_starts[bid] / page_size)
void gather_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size,
std::optional<torch::Tensor> seq_starts = std::nullopt) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

int32_t block_size = src_cache.size(1);
int32_t entry_size = src_cache.flatten(2, -1).size(2);

TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32");
TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
"cu_seq_lens must be int32");
if (seq_starts.has_value()) {
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
"seq_starts must be int32");
}

TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
TORCH_CHECK(src_cache.device() == block_table.device(),
"src_cache and block_table must be on the same device");
TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
"src_cache and cu_seq_lens must be on the same device");
if (seq_starts.has_value()) {
TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
"src_cache and seq_starts must be on the same device");
}
Review comment on lines +694 to +703, @tlrmchlsmth (Collaborator), Feb 14, 2025:

Future work: this is generally useful across all kernels and we should probably factor these checks out into a helper function. Something like this:

namespace detail {
// Get device from either Tensor or optional<Tensor>
inline std::optional<torch::Device> get_device(const torch::Tensor& tensor) {
    return tensor.device();
}

inline std::optional<torch::Device> get_device(const std::optional<torch::Tensor>& maybe_tensor) {
    return maybe_tensor.has_value() ? std::optional(maybe_tensor.value().device()) 
                                  : std::nullopt;
}
} // namespace detail

template <typename First, typename... Rest>
void check_same_device(const First& first, const Rest&... rest) {
    auto first_device = detail::get_device(first);
    if (!first_device.has_value()) return;
    
    ([&](const auto& tensor) {
        auto device = detail::get_device(tensor);
        if (device.has_value()) {
            TORCH_CHECK(*device == *first_device, "All tensors must be on the same device");
        }
    }(rest), ...);
}

LucasWilkinson (Collaborator, Author) replied:

Agreed, that would be nice. I'll work on a separate PR.


int64_t block_table_stride = block_table.stride(0);
int64_t cache_block_stride = src_cache.stride(0);
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);

// Decide on the number of splits based on the batch size.
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
dim3 grid(batch_size, num_splits);
dim3 block(1024);

TORCH_CHECK(src_cache.dtype() == dst.dtype(),
"src_cache and dst must have the same dtype");

const int dtype_bits = src_cache.element_size() * 8;
const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;

if (dtype_bits == 32) {
CALL_GATHER_CACHE(uint32_t);
} else if (dtype_bits == 16) {
CALL_GATHER_CACHE(uint16_t);
} else if (dtype_bits == 8) {
CALL_GATHER_CACHE(uint8_t);
} else {
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
}
}
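
To make the kernel's semantics concrete, here is a pure-PyTorch reference of what gather_cache produces, written for illustration only (it is not part of the PR). It mirrors the block-table walk, the optional seq_starts offset applied in whole blocks, and the partial last block; the num_splits grid dimension above only partitions a sequence's blocks across thread blocks for parallelism and does not change the result.

import torch

def gather_cache_ref(src_cache, block_table, cu_seq_lens, seq_starts=None):
    # src_cache: [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] -> returns [TOT_TOKENS, ENTRIES...]
    block_size = src_cache.size(1)
    outs = []
    for b in range(cu_seq_lens.numel() - 1):
        seq_len = int(cu_seq_lens[b + 1] - cu_seq_lens[b])
        if seq_len == 0:
            continue
        # Optional per-sequence start, applied as a whole-block offset.
        offset = int(seq_starts[b]) // block_size if seq_starts is not None else 0
        num_blocks = (seq_len + block_size - 1) // block_size  # ceil_div
        blocks = block_table[b, offset:offset + num_blocks].tolist()
        rows = [src_cache[blk] for blk in blocks]  # each [BLOCK_SIZE, ENTRIES...]
        remaining = seq_len - (num_blocks - 1) * block_size
        rows[-1] = rows[-1][:remaining]  # trim the partial last block
        outs.append(torch.cat(rows, dim=0))
    if not outs:
        return src_cache.new_empty((0, *src_cache.shape[2:]))
    return torch.cat(outs, dim=0)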
5 changes: 0 additions & 5 deletions csrc/core/math.hpp
@@ -7,8 +7,3 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

template <typename T>
inline constexpr std::enable_if_t<std::is_integral_v<T>, T> ceil_div(T a, T b) {
return (a + b - 1) / b;
}
22 changes: 18 additions & 4 deletions csrc/cuda_utils.h
@@ -2,10 +2,14 @@

#include <stdio.h>

#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
#define HOST_DEVICE_INLINE __forceinline__ __host__ __device__
#define DEVICE_INLINE __forceinline__ __device__
#define HOST_INLINE __forceinline__ __host__
#if defined(__HIPCC__)
#define HOST_DEVICE_INLINE __host__ __device__
#define DEVICE_INLINE __device__
#define HOST_INLINE __host__
#elif defined(__CUDACC__) || defined(_NVHPC_CUDA)
#define HOST_DEVICE_INLINE __host__ __device__ __forceinline__
#define DEVICE_INLINE __device__ __forceinline__
#define HOST_INLINE __host__ __forceinline__
#else
#define HOST_DEVICE_INLINE inline
#define DEVICE_INLINE inline
@@ -25,3 +29,13 @@
int64_t get_device_attribute(int64_t attribute, int64_t device_id);

int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);

namespace cuda_utils {

template <typename T>
HOST_DEVICE_INLINE constexpr std::enable_if_t<std::is_integral_v<T>, T>
ceil_div(T a, T b) {
return (a + b - 1) / b;
}

}; // namespace cuda_utils
5 changes: 3 additions & 2 deletions csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
@@ -1,7 +1,7 @@
#include <cudaTypedefs.h>
#include "c3x/scaled_mm_kernels.hpp"

#include "core/math.hpp"
#include "cuda_utils.h"

/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
@@ -33,7 +33,8 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
auto make_group_shape = [](torch::Tensor const& x,
torch::Tensor const& s) -> GroupShape {
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
return {ceil_div(x.size(0), s.size(0)), ceil_div(x.size(1), s.size(1))};
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
cuda_utils::ceil_div(x.size(1), s.size(1))};
};

GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
6 changes: 6 additions & 0 deletions csrc/torch_bindings.cpp
@@ -493,6 +493,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
"str kv_cache_dtype) -> ()");
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);

// Gather cache blocks from src_cache to dst.
cache_ops.def(
"gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache);
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
75 changes: 73 additions & 2 deletions tests/kernels/test_cache.py
@@ -682,8 +682,6 @@ def test_swap_blocks_mla(
torch.ops._C_cache_ops.swap_blocks,
(src_cache, dst_cache, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
cond=(kv_lora_rank == KV_LORA_RANKS[0]
and qk_rope_head_dim == QK_ROPE_HEAD_DIMS[0]),
)

ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)
@@ -694,3 +692,76 @@
dst_cache[dst].cpu(),
msg=f"Block {src} from src should have been swapped to block "
f"{dst} in dst_cache.")


@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype",
["auto"]) # You can also test "fp8" if needed.
@pytest.mark.parametrize("align_cache", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
num_blocks, max_seq_len, batch_size, dtype,
kv_cache_dtype, align_cache, device):
entry_size = kv_lora_rank + qk_rope_head_dim
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device, align_cache)
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

seq_len_tensor = torch.randint(0,
max_seq_len + 1, (batch_size, ),
device=device)

total_tokens = seq_len_tensor.sum()
cu_seq_lens = torch.empty((batch_size + 1),
dtype=torch.int32,
device=device)
cu_seq_lens[0] = 0
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
print("seq_len_tensor", seq_len_tensor)

tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
block_table = torch.empty((batch_size, num_blocks),
dtype=torch.int32,
device=device)

for b in range(batch_size):
perm = torch.randperm(num_blocks, device=device)
block_table[b, :] = perm

dst = torch.zeros((total_tokens, entry_size),
dtype=src_cache.dtype,
device=device)

expected_batches = []
for b in range(batch_size):
s = seq_len_tensor[b]
if s == 0:
continue
tot = tot_blocks_tensor[b]
blocks = block_table[b, :tot].tolist()

gathered_rows = []
for i in range(tot - 1):
gathered_rows.append(src_cache[blocks[i]])
remaining = s - (tot - 1) * block_size
gathered_rows.append(src_cache[blocks[-1], :remaining, :])

batch_expected = torch.cat(gathered_rows, dim=0)
expected_batches.append(batch_expected)
expected = torch.cat(expected_batches, dim=0)

opcheck(
torch.ops._C_cache_ops.gather_cache,
(src_cache, dst, block_table, cu_seq_lens, batch_size, None),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)

ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
torch.testing.assert_close(dst, expected)
10 changes: 10 additions & 0 deletions vllm/_custom_ops.py
@@ -1099,6 +1099,16 @@ def convert_fp8(output: torch.Tensor,
torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)


def gather_cache(src_cache: torch.Tensor,
dst: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
batch_size: int,
seq_starts: Optional[torch.Tensor] = None) -> None:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts)


def get_device_attribute(attribute: int, device: int) -> int:
return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)

12 changes: 4 additions & 8 deletions vllm/attention/__init__.py
@@ -4,16 +4,12 @@
AttentionMetadata,
AttentionMetadataBuilder,
AttentionState, AttentionType)
from vllm.attention.backends.utils import get_flash_attn_version
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend

__all__ = [
"Attention",
"AttentionBackend",
"AttentionMetadata",
"AttentionType",
"AttentionMetadataBuilder",
"Attention",
"AttentionState",
"get_attn_backend",
"Attention", "AttentionBackend", "AttentionMetadata", "AttentionType",
"AttentionMetadataBuilder", "Attention", "AttentionState",
"get_attn_backend", "get_flash_attn_version"
]