Commit e1c9bd0

Add CUTLASS-based row-wise scaled sparse FP8 kernel

alexsamardzic committed Feb 7, 2025
1 parent 867a91f commit e1c9bd0
Showing 9 changed files with 985 additions and 1 deletion.
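
At a high level, the new op computes a linear layer over FP8 activations with one scale per row and a 2:4 semi-structured sparse FP8 weight. A minimal sketch of the intended call pattern, pieced together from the test added below; the shapes, dtypes, and scale handling here are illustrative assumptions, not the op's documented API:

    import torch
    from torchao.ops import (
        rowwise_scaled_linear_sparse_cutlass_f8f8,
        to_sparse_semi_structured_cutlass_sm9x_f8,
    )

    # FP8 activations with one scale per row.
    xq = torch.randn(32, 128, device="cuda").to(torch.float8_e5m2)
    xq_scales = torch.rand(32, dtype=torch.float16, device="cuda")

    # Build a 2:4-sparse FP8 weight: keep the 2 largest-magnitude entries in
    # every group of 4 along each row, then compress it for the CUTLASS kernel.
    w = torch.randn(64, 128, dtype=torch.float16, device="cuda")
    groups = w.view(64, -1, 4)
    keep = torch.zeros_like(groups, dtype=torch.bool).scatter_(
        -1, groups.abs().topk(2, dim=-1).indices, True
    )
    wq = (groups * keep).view(64, 128).to(torch.float8_e4m3fn)
    wq_sp, wq_sp_meta = to_sparse_semi_structured_cutlass_sm9x_f8(wq)
    wq_scales = torch.rand(64, dtype=torch.float16, device="cuda")
    bias = torch.rand(64, dtype=torch.float16, device="cuda")

    # out[m, n] ~ (xq[m] . wq[n]) * xq_scales[m] * wq_scales[n] + bias[n]
    out = rowwise_scaled_linear_sparse_cutlass_f8f8(
        xq, xq_scales, wq_sp, wq_sp_meta, wq_scales, bias
    )
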
2 changes: 2 additions & 0 deletions setup.py
@@ -256,13 +256,15 @@ def get_extensions():
     if use_cuda and not IS_WINDOWS:
         use_cutlass = True
         cutlass_dir = os.path.join(third_party_path, "cutlass")
+        cutlass_util_include_dir = os.path.join(cutlass_dir, "tools", "util", "include")
         cutlass_include_dir = os.path.join(cutlass_dir, "include")
         cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir)
     if use_cutlass:
         extra_compile_args["nvcc"].extend(
             [
                 "-DTORCHAO_USE_CUTLASS",
                 "-I" + cutlass_include_dir,
+                "-I" + cutlass_util_include_dir,
                 "-I" + cutlass_extensions_include_dir,
             ]
         )
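
This build change adds CUTLASS's tools/util/include directory to the nvcc include path, so kernel sources can #include the CUTLASS utility headers in addition to the core ones. Assuming the usual third_party/cutlass checkout, the nvcc invocation effectively gains flags like the following (paths illustrative):

    -DTORCHAO_USE_CUTLASS
    -Ithird_party/cutlass/include
    -Ithird_party/cutlass/tools/util/include
    -I<extensions_cuda_dir>
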
188 changes: 188 additions & 0 deletions test/test_rowwise_scaled_linear_sparse_cutlass.py
@@ -0,0 +1,188 @@
import itertools
import random

import pytest
import torch
from torch.testing._internal.common_cuda import SM90OrLater

from torchao.dtypes import (
    Float8Layout,
    to_affine_quantized_floatx,
)
from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)


X_W_DTYPES = [(torch.float16, torch.float16), (torch.bfloat16, torch.bfloat16)]
XQ_WQ_DTYPES = [
    (torch.float8_e5m2, torch.float8_e4m3fn),
    (torch.float8_e4m3fn, torch.float8_e4m3fn),
]
BATCH_SIZE = [1, 4, 8, 16, 32, 64]
SIZE_MNK = [
    (2, 512, 128),
    (3, 2048, 2048),
    (4, 3584, 640),
    (13, 8704, 8576),
    (26, 18944, 1664),
    (67, 6656, 1408),
]
USE_BIAS = [False, True]
BIAS_DTYPE = [torch.float16]
TEST_PARAMS = list(
    itertools.product(
        X_W_DTYPES,
        XQ_WQ_DTYPES,
        BATCH_SIZE,
        SIZE_MNK,
        USE_BIAS,
        BIAS_DTYPE,
    )
)


def rand_sparse_semi_structured(r, c, dtype, device, choice=None):
    pattern = "2by4" if dtype != torch.float32 else "1by2"
    if pattern == "1by2":
        ksparse = 2
        choices = [[0, 1], [1, 0]]
    elif pattern == "2by4":
        ksparse = 4
        choices = [
            [1, 1, 0, 0],
            [1, 0, 1, 0],
            [1, 0, 0, 1],
            [0, 1, 1, 0],
            [0, 1, 0, 1],
            [0, 0, 1, 1],
        ]
    assert c % ksparse == 0
    mask_entries = [choice or random.choice(choices) for _ in range(r * c // ksparse)]
    mask = torch.tensor(mask_entries, dtype=torch.bool).view(r, c).to(device)
    dense = torch.randn(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # Ensure non-zeros everywhere; zeros come only from the mask below.
    dense = dense.masked_fill(~mask, 0)
    return dense
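
# Every row of the returned tensor follows the 2:4 semi-structured pattern the
# CUTLASS sparse kernels expect: exactly 2 non-zeros in each contiguous group
# of 4 elements (1:2 for float32). An illustrative check, not part of the test:
#
#   t = rand_sparse_semi_structured(4, 8, torch.float16, "cuda")
#   assert bool(((t != 0).view(4, -1, 4).sum(-1) == 2).all())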


def run_test_for_op(
    op,
    x_dtype,
    w_dtype,
    xq_dtype,
    wq_dtype,
    batch_size,
    size_mnk,
    use_bias,
    bias_dtype,
):
    size_m, size_n, size_k = size_mnk

    x = torch.randn((batch_size, size_m, size_k), dtype=x_dtype, device="cuda")
    w = rand_sparse_semi_structured(size_n, size_k, dtype=w_dtype, device="cuda")
    bias = torch.rand((size_n,), dtype=bias_dtype, device="cuda") if use_bias else None

    block_size = [1] * (x.dim() - 1) + [x.shape[-1]]
    x_aqt = to_affine_quantized_floatx(
        input_float=x,
        target_dtype=xq_dtype,
        block_size=block_size,
        _layout=Float8Layout(mm_config=None),
    )
    xq, xq_scales, zero_points = x_aqt.tensor_impl.get_plain()
    assert zero_points is None

    block_size = [1] * (w.dim() - 1) + [w.shape[-1]]
    w_aqt = to_affine_quantized_floatx(
        input_float=w,
        target_dtype=wq_dtype,
        block_size=block_size,
        _layout=Float8Layout(mm_config=None),
    )
    wq, wq_scales, zero_points = w_aqt.tensor_impl.get_plain()
    assert zero_points is None
    wq_sp, wq_sp_meta = to_sparse_semi_structured_cutlass_sm9x_f8(wq)
    wq_sp_scales = wq_scales
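    # The call above packs the 2:4-sparse FP8 weight for the CUTLASS kernel:
    # wq_sp holds the kept non-zero values and wq_sp_meta records which of the
    # four positions in each group they came from. The row-wise scales are
    # unchanged by the compression, hence wq_sp_scales = wq_scales.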

    xq_2d = xq.view(-1, xq.shape[-1])
    size_m_2d = xq_2d.shape[0]
    output_ref = (
        (xq_2d.float() @ wq.float().T)
        * xq_scales.view(size_m_2d, 1)
        * wq_scales.view(1, size_n)
    )
    if bias is not None:
        output_ref += bias
    output_ref = output_ref.to(x.dtype).reshape(x.shape[:-1] + (size_n,))
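    # Equivalently: output_ref[m, n] =
    #   (sum_k xq[m, k] * wq[n, k]) * xq_scales[m] * wq_scales[n] + bias[n]
    # computed in float32 against the dense (uncompressed) wq, then cast back
    # to the activation dtype; the fused op is expected to match this within
    # the tolerances used below.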

    fn_inputs = (xq, xq_scales, wq_sp, wq_sp_meta, wq_sp_scales, bias)
    try:
        output = op(*fn_inputs)
    except NotImplementedError:
        pytest.xfail("operator not implemented")

    torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=1e-3)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not SM90OrLater, reason="FP8 is only supported on H100+ devices")
@pytest.mark.parametrize(
    "x_w_dtypes, xq_wq_dtypes, batch_size, size_mnk, use_bias, bias_dtype",
    TEST_PARAMS,
)
def test_rowwise_scaled_linear_sparse_cutlass_f8f8(
    x_w_dtypes,
    xq_wq_dtypes,
    batch_size,
    size_mnk,
    use_bias,
    bias_dtype,
):
    run_test_for_op(
        rowwise_scaled_linear_sparse_cutlass_f8f8,
        *x_w_dtypes,
        *xq_wq_dtypes,
        batch_size,
        size_mnk,
        use_bias,
        bias_dtype,
    )
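
The parametrized test above sweeps the full TEST_PARAMS grid. A minimal way to run just this file from Python, using standard pytest (nothing here is specific to this commit):

    import pytest

    # Run the parametrized cases in this file and exit with pytest's status;
    # requires an SM90+ GPU and a torchao build with CUTLASS enabled.
    raise SystemExit(pytest.main(["-q", "test/test_rowwise_scaled_linear_sparse_cutlass.py"]))
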
@@ -1,3 +1,4 @@
+#include <cutlass/cutlass.h>
 #include <torch/library.h>

 #include "rowwise_scaled_linear_cutlass.cuh"
@@ -1,3 +1,4 @@
+#include <cutlass/cutlass.h>
 #include <torch/library.h>

 #include "rowwise_scaled_linear_cutlass.cuh"