[Misc][Kernel]: Add GPTQAllSpark Quantization #12931

Merged 5 commits on Mar 1, 2025. Changes from all commits are shown below.
16 changes: 16 additions & 0 deletions CMakeLists.txt
100755 → 100644
@@ -317,6 +317,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures")
endif()

# Only build AllSpark kernels if we are building for at least some compatible archs.
cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
if (ALLSPARK_ARCHS)
set(ALLSPARK_SRCS
"csrc/quantization/gptq_allspark/allspark_repack.cu"
"csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
set_gencode_flags_for_srcs(
SRCS "${ALLSPARK_SRCS}"
CUDA_ARCHS "${ALLSPARK_ARCHS}")
list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}")
message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
else()
message(STATUS "Not building AllSpark kernels as no compatible archs found"
" in CUDA target architectures")
endif()

# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
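The new AllSpark block above compiles the kernels only when the CUDA target list intersects SM 8.0, 8.6, 8.7, and 8.9. The same range is checked again at runtime later in this PR; a minimal sketch of that gate, assuming only PyTorch and using the sm_version = major * 10 + minor convention from the benchmark and test code:

import torch

def allspark_arch_supported(device_index: int = 0) -> bool:
    # Mirrors the SM 8.0-8.9 (Ampere/Ada, no Hopper) gate used elsewhere in this PR.
    if not torch.cuda.is_available():
        return False
    props = torch.cuda.get_device_properties(device_index)
    sm_version = props.major * 10 + props.minor  # e.g. 80, 86, 89
    return 80 <= sm_version < 90

print("AllSpark W8A16 kernels usable:", allspark_arch_supported())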
47 changes: 45 additions & 2 deletions benchmarks/kernels/benchmark_marlin.py
@@ -10,6 +10,8 @@
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
@@ -18,12 +20,12 @@
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
marlin_24_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, gptq_quantize_weights, sort_weights)
gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
from vllm.scalar_type import ScalarType
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
@@ -81,6 +83,27 @@ def bench_run(results: List[benchmark.Measurement], model: str,
GPTQ_MARLIN_24_MAX_PARALLEL)
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)

# AllSpark W8A16 quant
as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
and group_size == -1 and not act_order and is_k_full)
if as_supported_case:
properties = torch.cuda.get_device_properties(b.device.index)
sm_count = properties.multi_processor_count
sm_version = properties.major * 10 + properties.minor

supported_arch = (sm_version >= 80 and sm_version < 90)
as_supported_case = as_supported_case and supported_arch
if supported_arch:
has_zp = False
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size,
has_zp)
qw = qw.to(torch.uint8)

qw_reorder, s_reorder, zp_reorder = \
ops.allspark_repack_weight(
qw, s, zp, has_zp)
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD

globals = {
# Gen params
"quant_type": quant_type,
@@ -109,10 +132,19 @@ def bench_run(results: List[benchmark.Measurement], model: str,
# GPTQ params
"q_w_gptq": q_w_gptq,
"repack_sort_indices": repack_sort_indices,
# AllSpark W8A16 params
"qw_reorder": qw_reorder if as_supported_case else None,
"s_reorder": s_reorder if as_supported_case else None,
"zp_reorder": zp_reorder if as_supported_case else None,
"sm_count": sm_count if as_supported_case else None,
"sm_version": sm_version if as_supported_case else None,
"CUBLAS_M_THRESHOLD":
CUBLAS_M_THRESHOLD if as_supported_case else None,
# Kernels
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
"gptq_marlin_repack": ops.gptq_marlin_repack,
"allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
}

min_run_time = 1
@@ -172,6 +204,17 @@ def bench_run(results: List[benchmark.Measurement], model: str,
description="gptq_marlin_repack",
).blocked_autorange(min_run_time=min_run_time))

if as_supported_case:
results.append(
benchmark.Timer(
stmt=
"output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="allspark_w8a16_gemm_fp32",
).blocked_autorange(min_run_time=min_run_time))


def main(args):
print("Benchmarking models:")
1,008 changes: 1,008 additions & 0 deletions csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu

Large diffs are not rendered by default.

163 changes: 163 additions & 0 deletions csrc/quantization/gptq_allspark/allspark_repack.cu
@@ -0,0 +1,163 @@
#include "allspark_utils.cuh"
#include <torch/all.h>
#include "core/registration.h"

namespace allspark {

// Rearrange B to facilitate Ampere Tensor Core data loading
// reorder B from (K, N) to (N_32align / 4, K * 4)
// K % 16 == 0, N % 16 == 0, N_32align % 32 == 0
template <typename FType>
__global__ void __launch_bounds__(128)
rearrange_kn_weight_as_n32k16_order_ldg16_kernel(
const uint8_t* B, const FType* B_scale, const FType* B_zero,
uint8_t* B_result, FType* B_scale_result, FType* B_zero_result,
const int K, const int N, const int N_32align) {
const int lane_id = threadIdx.x % 32;
const int warp_id = threadIdx.x / 32;

if (blockIdx.x != gridDim.x - 1) {
// Load B
// per block process 64(k) * 128(n) B elements
// per warp process 16(k) * 128 B elements
const int src_row_base_idx =
blockIdx.x * 64 + warp_id * 16 + ((lane_id % 8) / 2) * 2;
const int src_col_idx =
blockIdx.y * 128 + (lane_id / 8) * 32 + (lane_id % 2) * 16;
uint8_t B_frag[4][16];
#pragma unroll
for (int i = 0; i < 4; ++i) {
int src_row_idx = src_row_base_idx + (i / 2) * 8 + (i % 2);
int src_offset = src_row_idx * N + src_col_idx;
bool guard = src_row_idx < K && src_col_idx < N;
ldg128_cg_0(*reinterpret_cast<uint32_t*>(B_frag[i]),
*(reinterpret_cast<uint32_t*>(B_frag[i]) + 1),
*(reinterpret_cast<uint32_t*>(B_frag[i]) + 2),
*(reinterpret_cast<uint32_t*>(B_frag[i]) + 3), B + src_offset,
guard);
}

// reorder B
uint8_t B_reorder_frag[8][8];
#pragma unroll
for (int i = 0; i < 4; ++i) {
#pragma unroll
for (int j = 0; j < 16; ++j) {
int dst_i = j % 8;
int dst_j = i + (j / 8) * 4;
B_reorder_frag[dst_i][dst_j] = B_frag[i][j];
}
}

// Store B
const int dst_row_base_idx = blockIdx.y * (128 / 4) + (lane_id / 8) * 8;
const int dst_col_idx =
blockIdx.x * (64 * 4) + warp_id * 64 + (lane_id % 8) * 8;
for (int i = 0; i < 8; ++i) {
int dst_row_idx = dst_row_base_idx + i;
int dst_offset = dst_row_idx * K * 4 + dst_col_idx;
bool guard = (dst_row_base_idx < N_32align / 4) && (dst_col_idx < K * 4);
if (guard) {
*reinterpret_cast<int2*>(B_result + dst_offset) =
*reinterpret_cast<int2*>(B_reorder_frag[i]);
}
}
} else {
// Load B_scale and B_zero
FType b_scale_reg, b_zero_reg;
int src_offset = blockIdx.y * 128 + threadIdx.x;
ldg16_cg_0(b_scale_reg, B_scale + src_offset, src_offset < N);
if (B_zero != nullptr)
ldg16_cg_0(b_zero_reg, B_zero + src_offset, src_offset < N);
int dst_offset =
blockIdx.y * 128 + warp_id * 32 + (lane_id % 8) * 4 + lane_id / 8;
if (dst_offset < N_32align) {
B_scale_result[dst_offset] = b_scale_reg;
if (B_zero != nullptr) B_zero_result[dst_offset] = b_zero_reg;
}
}
}

template <typename FType>
void rearrange_kn_weight_as_n32k16_order_ldg16(
const uint8_t* B, const FType* B_scale, const FType* B_zero,
uint8_t* B_result, FType* B_scale_result, FType* B_zero_result,
const int64_t K, const int64_t N, const int64_t N_32align,
cudaStream_t stream) {
if (N % 16 != 0 || K % 16 != 0) {
std::cerr << "Currently only N and K that are multiples of 16 are supported" << std::endl;
}
const int BLOCK = 128;
int grid_x = (K + 64 - 1) / 64 + 1;
int grid_y = (N + 128 - 1) / 128;
dim3 grid(grid_x, grid_y);

rearrange_kn_weight_as_n32k16_order_ldg16_kernel<FType>
<<<grid, BLOCK, 0, stream>>>(B, B_scale, B_zero, B_result, B_scale_result,
B_zero_result, K, N, N_32align);
}
} // namespace allspark

void rearrange_kn_weight_as_n32k16_order(
torch::Tensor const& b_qweight, torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& b_zeros, bool has_zp,
torch::Tensor& b_qweight_reorder, torch::Tensor& b_scales_reorder,
c10::optional<torch::Tensor> const& b_zeros_reorder, const int64_t K,
const int64_t N, const int64_t N_32align) {
// Verify device and strides
TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU");
TORCH_CHECK(b_qweight.is_contiguous(), "b_qweight is not contiguous");

TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");

TORCH_CHECK(b_qweight_reorder.device().is_cuda(),
"b_qweight_reorder is not on GPU");
TORCH_CHECK(b_qweight_reorder.is_contiguous(),
"b_qweight_reorder is not contiguous");

TORCH_CHECK(b_scales_reorder.device().is_cuda(),
"b_scales_reorder is not on GPU");
TORCH_CHECK(b_scales_reorder.is_contiguous(),
"b_scales_reorder is not contiguous");

if (has_zp) {
TORCH_CHECK(b_zeros.value().device().is_cuda(), "b_zeros is not on GPU");
TORCH_CHECK(b_zeros.value().is_contiguous(), "b_zeros is not contiguous");

TORCH_CHECK(b_zeros_reorder.value().device().is_cuda(),
"b_zeros_reorder is not on GPU");
TORCH_CHECK(b_zeros_reorder.value().is_contiguous(),
"b_zeros_reorder is not contiguous");
}

const uint8_t* matB = reinterpret_cast<const uint8_t*>(b_qweight.data_ptr());
const void* b_scale = b_scales.data_ptr();
const void* b_zero = has_zp ? b_zeros.value().data_ptr() : nullptr;

uint8_t* matB_reorder =
reinterpret_cast<uint8_t*>(b_qweight_reorder.data_ptr());
void* b_scale_reorder = b_scales_reorder.data_ptr();
void* b_zero_reorder = has_zp ? b_zeros_reorder.value().data_ptr() : nullptr;

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (b_scales.dtype() == at::ScalarType::Half) {
allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__half>(
matB, reinterpret_cast<const __half*>(b_scale),
reinterpret_cast<const __half*>(b_zero), matB_reorder,
reinterpret_cast<__half*>(b_scale_reorder),
reinterpret_cast<__half*>(b_zero_reorder), K, N, N_32align, stream);
} else if (b_scales.dtype() == at::ScalarType::BFloat16) {
allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__nv_bfloat16>(
matB, reinterpret_cast<const __nv_bfloat16*>(b_scale),
reinterpret_cast<const __nv_bfloat16*>(b_zero), matB_reorder,
reinterpret_cast<__nv_bfloat16*>(b_scale_reorder),
reinterpret_cast<__nv_bfloat16*>(b_zero_reorder), K, N, N_32align,
stream);
}
}

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("rearrange_kn_weight_as_n32k16_order",
&rearrange_kn_weight_as_n32k16_order);
}
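For orientation, the repack kernel above consumes B in row-major (K, N) uint8 and writes an N32K16 layout that the Python wrapper added later in this PR allocates as an (N_32align, K) uint8 buffer, with scales and zero points padded from N to N_32align columns. A shape-only sketch of that allocation arithmetic, assuming the element permutation itself stays inside the CUDA kernel:

import torch

def allspark_repack_buffers(k: int, n: int, dtype=torch.float16):
    # The kernel requires K and N to be multiples of 16; N is padded up to a multiple of 32.
    assert k % 16 == 0 and n % 16 == 0
    n_32align = (n + 32 - 1) // 32 * 32
    qweight_reorder = torch.empty((n_32align, k), dtype=torch.uint8)  # source was (K, N)
    scale_reorder = torch.empty((1, n_32align), dtype=dtype)          # source was (1, N)
    return qweight_reorder, scale_reorder

qw_r, s_r = allspark_repack_buffers(k=4096, n=11008)
print(qw_r.shape, s_r.shape)  # torch.Size([11008, 4096]) torch.Size([1, 11008])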
408 changes: 408 additions & 0 deletions csrc/quantization/gptq_allspark/allspark_utils.cuh

Large diffs are not rendered by default.

19 changes: 19 additions & 0 deletions csrc/torch_bindings.cpp
@@ -447,6 +447,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor!? azp) -> ()");
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
&dynamic_scaled_int8_quant);

#ifndef USE_ROCM
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
ops.def(
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
"Tensor? b_zeros, "
"bool has_zp, Tensor! b_qweight_reorder, Tensor! b_scales_reorder, "
"Tensor!? b_zeros_reorder, "
"int K, int N, int N_32align) -> ()");
// conditionally compiled so impl in source file

// AllSpark quantization ops
ops.def(
"allspark_w8a16_gemm(Tensor a, Tensor b_qweight, Tensor b_scales, "
"Tensor? b_qzeros, "
"SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt "
"CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor");
// conditionally compiled so impl in source file
#endif
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
100 changes: 100 additions & 0 deletions tests/kernels/test_allspark_gemm.py
@@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch

from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
ALLSPARK_AMPERE_K_ALIGN, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
ALLSPARK_AMPERE_N_ALIGN)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types


def is_gptq_allspark_supported(min_capability: int,
max_capability: int) -> bool:
if not current_platform.is_cuda():
return False

capability = current_platform.get_device_capability()
assert capability is not None

return capability.to_int() >= min_capability \
and capability.to_int() <= max_capability


MNK_FACTORS = [
(1, 4, 8),
(13, 17, 67),
(26, 37, 13),
(48, 16, 24),
(67, 13, 88),
(257, 13, 11),
(658, 13, 11),
(1033, 9, 17),
]

DTYPES = [torch.float16, torch.bfloat16]
HAS_ZP_OPTS = [False, True]


def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))


def rand_data(shape, dtype=torch.float16):
return torch.randn(shape, dtype=dtype, device="cuda")


@pytest.mark.skipif(
not is_gptq_allspark_supported(80, 89),
reason="AllSpark Ampere kernel is not supported on this GPU type.")
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("has_zp", HAS_ZP_OPTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):
m_factor, n_factor, k_factor = mnk_factors
m = m_factor
n = n_factor * ALLSPARK_AMPERE_N_ALIGN
k = k_factor * ALLSPARK_AMPERE_K_ALIGN

input = rand_data((m, k), dtype=dtype)
weight = rand_data((k, n), dtype=dtype)

# Quantize weights (symmetric or asymmetric depending on has_zp)
w_ref, qw, s, zp = quantize_weights(weight, scalar_types.uint8b128,
group_size, has_zp)

qw = qw.to(torch.uint8)
if has_zp:
zp = zp.to(dtype)
properties = torch.cuda.get_device_properties(qw.device.index)
sm_count = properties.multi_processor_count
sm_version = properties.major * 10 + properties.minor

n_32align = (n + 32 - 1) // 32 * 32

qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
qw, s, zp, has_zp)
opcheck(torch.ops._C.rearrange_kn_weight_as_n32k16_order,
(qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n,
n_32align))

opcheck(torch.ops._C.allspark_w8a16_gemm,
(input, qw_reorder, s_reorder, zp_reorder, n, group_size, sm_count,
sm_version, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp, True),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
output = ops.allspark_w8a16_gemm(input, qw_reorder, s_reorder, zp_reorder,
n, group_size, sm_count, sm_version,
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
has_zp, True)

output_ref = torch.matmul(input, w_ref)
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)

assert max_diff < 0.04
2 changes: 0 additions & 2 deletions tests/quantization/test_compressed_tensors.py
@@ -215,8 +215,6 @@ def check_model(model):
assert qkv_proj.scheme.group_size == (-1
if group is None else group)

assert qkv_proj.weight_packed.dtype is torch.int32
assert qkv_proj.weight_scale.dtype is torch.float16
assert qkv_proj.scheme.pack_factor == pack_factor

llm.apply_model(check_model)
77 changes: 77 additions & 0 deletions vllm/_custom_ops.py
@@ -404,6 +404,22 @@ def machete_prepack_B_fake(
memory_format=torch.contiguous_format)


if hasattr(torch.ops._C, "allspark_w8a16_gemm"):

@register_fake("_C::allspark_w8a16_gemm")
def _allspark_w8a16_gemm_fake(a: torch.Tensor, b_qweight: torch.Tensor,
b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor],
n: torch.SymInt, group_size: torch.SymInt,
sm_count: torch.SymInt,
sm_version: torch.SymInt,
CUBLAS_M_THRESHOLD: torch.SymInt,
has_zp: bool,
n32k16_reorder: bool) -> torch.Tensor:
m = a.size(0)
return torch.empty((m, n), device=a.device, dtype=a.dtype)


if hasattr(torch.ops._C, "ggml_dequantize"):

@register_fake("_C::ggml_dequantize")
@@ -881,6 +897,67 @@ def scaled_fp8_quant(
return output, scale


# gptq allspark
def allspark_repack_weight(
qweight: torch.Tensor,
scale: torch.Tensor,
zero_point: Optional[torch.Tensor] = None,
has_zp: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Rearrange qweight, scale, and zero_point (if asymmetric) to n32k16 format
for the Ampere W8A16 fused GEMM kernel.
Args:
qweight: uint8 weight tensor, original k x n format.
scale: fp16/bf16 weight scale tensor, 1 x n format.
zero_point: fp16/bf16 weight zero_point tensor, 1 x n format.
Must be provided for asymmetric quantization.
has_zp: False for symmetric quantization,
True for asymmetric quantization.
Returns:
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] :
rearranged weight, scale, and optionally zero_point.
"""
K = qweight.shape[0]
N = qweight.shape[1]
N_32align = (N + 32 - 1) // 32 * 32

qweight_reorder = torch.empty((N_32align, K),
device=qweight.device,
dtype=qweight.dtype)
scale_reorder = torch.empty((1, N_32align),
device=scale.device,
dtype=scale.dtype)
zero_point_reorder = None
if has_zp:
assert zero_point is not None, (
"zero_point must be provided for asymmetric quantization.")
zero_point_reorder = torch.empty((1, N_32align),
device=zero_point.device,
dtype=zero_point.dtype)

torch.ops._C.rearrange_kn_weight_as_n32k16_order(
qweight, scale, zero_point, has_zp, qweight_reorder, scale_reorder,
zero_point_reorder, K, N, N_32align)

return qweight_reorder, scale_reorder, zero_point_reorder


def allspark_w8a16_gemm(a: torch.Tensor, b_qweight: torch.Tensor,
b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor], n: int,
group_size: int, sm_count: int, sm_version: int,
CUBLAS_M_THRESHOLD: int, has_zp: bool,
n32k16_reorder: bool) -> torch.Tensor:

return torch.ops._C.allspark_w8a16_gemm(a, b_qweight, b_scales, b_qzeros,
n, group_size, sm_count,
sm_version, CUBLAS_M_THRESHOLD,
has_zp, n32k16_reorder)


# int8
def scaled_int8_quant(
input: torch.Tensor,
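Putting the two new wrappers together, a hedged end-to-end sketch that mirrors tests/kernels/test_allspark_gemm.py above; it assumes an SM 8.0 to 8.9 GPU and a vLLM build with the AllSpark kernels compiled in:

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
    ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    quantize_weights)
from vllm.scalar_type import scalar_types

m, k, n = 16, 4096, 4096
a = torch.randn((m, k), dtype=torch.float16, device="cuda")
b = torch.randn((k, n), dtype=torch.float16, device="cuda")

# Symmetric per-channel quantization (group_size=-1), as in the test above.
w_ref, qw, s, zp = quantize_weights(b, scalar_types.uint8b128, -1, False)
qw = qw.to(torch.uint8)

# Reorder weight and scale into the N32K16 layout expected by the fused GEMM.
qw_r, s_r, zp_r = ops.allspark_repack_weight(qw, s, None, has_zp=False)

props = torch.cuda.get_device_properties(a.device.index)
out = ops.allspark_w8a16_gemm(a, qw_r, s_r, None, n, -1,
                              props.multi_processor_count,
                              props.major * 10 + props.minor,
                              ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
                              False, True)
print(out.shape)  # torch.Size([16, 4096])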
vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
@@ -3,6 +3,8 @@
from typing import List, Optional, Type

import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
AllSparkLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
ExllamaLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
@@ -16,6 +18,7 @@
# in priority/performance order (when available)
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
MacheteLinearKernel,
AllSparkLinearKernel,
MarlinLinearKernel,
ExllamaLinearKernel,
]
115 changes: 115 additions & 0 deletions vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py
@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, check_allspark_supported_dtype_shape)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class AllSparkLinearKernel(MPLinearKernel):

@classmethod
def get_min_capability(cls) -> int:
return 80

@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if c.has_g_idx:
return False, "Act reordering currently not supported by AllSpark"

if c.zero_points:
return False, "Zero points currently not supported by AllSpark"

return check_allspark_supported_dtype_shape(
c.partition_weight_shape[0], # in_features
c.partition_weight_shape[1], # out_features
c.group_size,
c.weight_type,
c.act_type)

# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = getattr(layer, self.w_q_name).device
c = self.config

# prepare the parameters required for the kernel
properties = torch.cuda.get_device_properties(device.index)
sm_count = properties.multi_processor_count
sm_version = properties.major * 10 + properties.minor
gemm_args = {}
gemm_args['sm_count'] = sm_count
gemm_args['sm_version'] = sm_version

self.gemm_args = gemm_args

# transform param weight, scale
old_weight_param = getattr(layer, self.w_q_name)
old_scale_param = getattr(layer, self.w_s_name)

assert isinstance(old_weight_param, BasevLLMParameter)
permute_param_layout_(old_weight_param,
input_dim=0,
output_dim=1,
packed_dim=0)

assert isinstance(old_scale_param, BasevLLMParameter)
permute_param_layout_(old_scale_param, input_dim=0, output_dim=1)

# unpack weight from K / 4 x N int32 to K x N uint8
new_weight_param = torch.nn.Parameter(old_weight_param.data,
requires_grad=False)
new_weight_param.data = new_weight_param.data.t().contiguous().view(
dtype=torch.uint8)
new_weight_param.data = new_weight_param.data.t().contiguous()

new_scale_param = torch.nn.Parameter(old_scale_param.data,
requires_grad=False)

# reorder K x N weight as N32K16 format for Ampere W8A16
new_weight_param.data, new_scale_param.data, _ = \
ops.allspark_repack_weight(
new_weight_param.data, new_scale_param.data, None,
c.zero_points)

replace_parameter(layer, self.w_q_name, new_weight_param.data)
replace_parameter(layer, self.w_s_name, new_scale_param.data)

def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
c = self.config
gemm_args = self.gemm_args
w_q, w_s, _, _ = self._get_weight_params(layer)

reshaped_x = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )

output = ops.allspark_w8a16_gemm(
a=reshaped_x,
b_qweight=w_q,
b_scales=w_s,
b_qzeros=None,
n=c.partition_weight_shape[1],
group_size=c.group_size,
sm_count=gemm_args['sm_count'],
sm_version=gemm_args['sm_version'],
CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
has_zp=c.zero_points,
n32k16_reorder=True)

if bias is not None:
output.add_(bias) # In-place add

return output.reshape(out_shape)
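The unpack step in process_weights_after_loading above converts a GPTQ 8-bit weight stored as (K // 4, N) int32 into (K, N) uint8 before repacking. An illustrative CPU-only sketch of just that reinterpretation, not part of the PR:

import torch

k, n = 16, 8
packed = torch.arange(k // 4 * n, dtype=torch.int32).reshape(k // 4, n)

# Same op sequence as above: transpose so the packed dim is last and contiguous,
# reinterpret each int32 as 4 uint8 values, then transpose back.
unpacked = packed.t().contiguous().view(dtype=torch.uint8).t().contiguous()
assert unpacked.shape == (k, n) and unpacked.dtype == torch.uint8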
51 changes: 51 additions & 0 deletions vllm/model_executor/layers/quantization/utils/allspark_utils.py
@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0

import torch

from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD = 1024
ALLSPARK_SUPPORTED_QUANT_TYPES = [scalar_types.uint8b128]
ALLSPARK_AMPERE_N_ALIGN = 16
ALLSPARK_AMPERE_K_ALIGN = 16


def check_allspark_supported_dtype_shape(input_size_per_partition: int,
output_size_per_partition: int,
group_size: int,
weight_dtype: ScalarType,
act_dtype: torch.dtype):
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
capability_tuple.to_int())

# For Ampere GPU
if device_capability >= 80 and device_capability < 90:
if group_size != -1:
return False, \
"For Ampere GPU, AllSpark does not support group_size "\
f"= {group_size}. Only group_size = -1 are supported."

if weight_dtype not in ALLSPARK_SUPPORTED_QUANT_TYPES:
return False, "For Ampere GPU, AllSpark does not support "\
f"quant type ({weight_dtype}). Only quant type "\
f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported."

if input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN != 0 \
or output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN != 0:
return False, \
"AllSpark needs input_size_per_partition % "\
f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and "\
f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 "\
"for Ampere GPU optimized kernels."

if act_dtype != torch.float16 and act_dtype != torch.bfloat16:
return False, \
"AllSpark only supports act_dtype = float16 or bfloat16,"\
f"for Ampere GPU, but got act_dtype = {act_dtype}."
else:
return False, "AllSpark currently does not support "\
f"device_capability = {device_capability}."

return True, None
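For reference, a small usage sketch of this helper; AllSparkLinearKernel.can_implement above forwards the (ok, reason) pair it returns. The concrete sizes below are illustrative assumptions:

import torch

from vllm.model_executor.layers.quantization.utils.allspark_utils import (
    check_allspark_supported_dtype_shape)
from vllm.scalar_type import scalar_types

ok, reason = check_allspark_supported_dtype_shape(
    input_size_per_partition=4096,   # K, must be a multiple of 16
    output_size_per_partition=4096,  # N, must be a multiple of 16
    group_size=-1,                   # only per-channel quantization is supported
    weight_dtype=scalar_types.uint8b128,
    act_dtype=torch.float16)
print(ok, reason)  # (True, None) on SM 8.0-8.9 GPUs, otherwise (False, "<reason>")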