Add CUTLASS-based row-wise scaled sparse FP8 kernel #1671

Status: Draft (wants to merge 1 commit into base: main)
2 changes: 2 additions & 0 deletions setup.py
@@ -256,13 +256,15 @@ def get_extensions():
     if use_cuda and not IS_WINDOWS:
         use_cutlass = True
         cutlass_dir = os.path.join(third_party_path, "cutlass")
+        cutlass_util_include_dir = os.path.join(cutlass_dir, "tools", "util", "include")
         cutlass_include_dir = os.path.join(cutlass_dir, "include")
         cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir)
     if use_cutlass:
         extra_compile_args["nvcc"].extend(
             [
                 "-DTORCHAO_USE_CUTLASS",
                 "-I" + cutlass_include_dir,
+                "-I" + cutlass_util_include_dir,
                 "-I" + cutlass_extensions_include_dir,
             ]
         )
188 changes: 188 additions & 0 deletions test/test_rowwise_scaled_linear_sparse_cutlass.py
@@ -0,0 +1,188 @@
import itertools
import random

import pytest
import torch
from torch.testing._internal.common_cuda import SM90OrLater

from torchao.dtypes import (
Float8Layout,
to_affine_quantized_floatx,
)
from torchao.ops import (
rowwise_scaled_linear_sparse_cutlass_f8f8,
to_sparse_semi_structured_cutlass_sm9x_f8,
)


X_W_DTYPES = [(torch.float16, torch.float16), (torch.bfloat16, torch.bfloat16)]
XQ_WQ_DTYPES = [
(torch.float8_e5m2, torch.float8_e4m3fn),
(torch.float8_e4m3fn, torch.float8_e4m3fn),
]
BATCH_SIZE = [1, 4, 8, 16, 32, 64]
SIZE_MNK = [
(2, 512, 128),
(3, 2048, 2048),
(4, 3584, 640),
(13, 8704, 8576),
(26, 18944, 1664),
(67, 6656, 1408),
]
USE_BIAS = [False, True]
BIAS_DTYPE = [torch.float16]
TEST_PARAMS = list(
itertools.product(
X_W_DTYPES,
XQ_WQ_DTYPES,
BATCH_SIZE,
SIZE_MNK,
USE_BIAS,
BIAS_DTYPE,
)
)
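
# TEST_PARAMS is the full cartesian product of the settings above: one pytest
# case per (dtype pair, fp8 dtype pair, batch size, MNK shape, bias setting).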


# FIXME: remove this!
X_W_DTYPES = [(torch.float16, torch.float16)]
XQ_WQ_DTYPES = [(torch.float8_e5m2, torch.float8_e4m3fn)]
BATCH_SIZE = [1]
SIZE_MNK = [(32, 64, 128)]
USE_BIAS = [True]
BIAS_DTYPE = [torch.float16]
TEST_PARAMS = list(
itertools.product(
X_W_DTYPES,
XQ_WQ_DTYPES,
BATCH_SIZE,
SIZE_MNK,
USE_BIAS,
BIAS_DTYPE,
)
)


[Inline review comment from a Contributor on the helper below] FYI: I think we have a variant of this in here: def create_semi_structured_tensor(r, c, dtype):


def rand_sparse_semi_structured(r, c, dtype, device, choice=None):
pattern = "2by4" if dtype != torch.float32 else "1by2"
if pattern == "1by2":
ksparse = 2
choices = [[0, 1], [1, 0]]
elif pattern == "2by4":
ksparse = 4
choices = [
[1, 1, 0, 0],
[1, 0, 1, 0],
[1, 0, 0, 1],
[0, 1, 1, 0],
[0, 1, 0, 1],
[0, 0, 1, 1],
]
assert c % ksparse == 0
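    # Build the mask one ksparse-wide group at a time: `choice` pins a fixed
    # pattern for every group, otherwise each group draws a random admissible one.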
    mask_entries = [choice or random.choice(choices) for _ in range(r * c // ksparse)]
mask = torch.tensor(mask_entries, dtype=torch.bool).view(r, c).to(device)
dense = torch.randn(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # Avoid accidental zeros, so the only zeros come from the mask.
dense = dense.masked_fill(~mask, 0)
return dense
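
A quick sanity sketch (not part of the PR; the shapes are arbitrary and a CUDA device is assumed): for fp16 inputs the helper produces a 2:4 pattern, so every contiguous group of four elements in a row should contain exactly two nonzeros.

w = rand_sparse_semi_structured(8, 16, dtype=torch.float16, device="cuda")
nonzeros_per_group = (w.view(8, -1, 4) != 0).sum(dim=-1)
assert bool((nonzeros_per_group == 2).all())  # exactly 2 of every 4 survive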


def run_test_for_op(
op,
x_dtype,
w_dtype,
xq_dtype,
wq_dtype,
batch_size,
size_mnk,
use_bias,
bias_dtype,
):
size_m, size_n, size_k = size_mnk

x = torch.randn((batch_size, size_m, size_k), dtype=x_dtype, device="cuda")
w = rand_sparse_semi_structured(size_n, size_k, dtype=w_dtype, device="cuda")
bias = torch.rand((size_n,), dtype=bias_dtype, device="cuda") if use_bias else None

block_size = [1] * (x.dim() - 1) + [x.shape[-1]]
x_aqt = to_affine_quantized_floatx(
input_float=x,
target_dtype=xq_dtype,
block_size=block_size,
_layout=Float8Layout(mm_config=None),
)
xq, xq_scales, zero_points = x_aqt.tensor_impl.get_plain()
assert zero_points is None

block_size = [1] * (w.dim() - 1) + [w.shape[-1]]
w_aqt = to_affine_quantized_floatx(
input_float=w,
target_dtype=wq_dtype,
block_size=block_size,
_layout=Float8Layout(mm_config=None),
)
wq, wq_scales, zero_points = w_aqt.tensor_impl.get_plain()
assert zero_points is None
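    # Compress the 2:4-sparse quantized weight into the packed CUTLASS form:
    # the nonzero values plus a separate sparsity metadata tensor.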
wq_sp, wq_sp_meta = to_sparse_semi_structured_cutlass_sm9x_f8(wq)
wq_sp_scales = wq_scales

xq_2d = xq.view(-1, xq.shape[-1])
size_m_2d = xq_2d.shape[0]
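    # Reference result: fp32 matmul of the quantized operands, dequantized
    # row-wise by the activation scales and per-output-channel weight scales.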
output_ref = (
(xq_2d.float() @ wq.float().T)
* xq_scales.view(size_m_2d, 1)
* wq_scales.view(1, size_n)
)
if bias is not None:
output_ref += bias
output_ref = output_ref.to(x.dtype).reshape(x.shape[:-1] + (size_n,))

fn_inputs = (xq, xq_scales, wq_sp, wq_sp_meta, wq_sp_scales, bias)
try:
output = op(*fn_inputs)
except NotImplementedError:
pytest.xfail("operator not implemented")

# FIXME: remove this!
d_ref = output_ref
d = output
print(
f"Sum of relative errors d vs. d_ref : {torch.sum(torch.abs(d - d_ref) / torch.abs(d_ref)).item():8.2f}"
)
print()
d_ref = d_ref.flatten().to(torch.float32)
d = d.flatten().to(torch.float32)
topk = 10
print(f"Top {topk} relative errors d vs. d_ref :")
print(" d_ref d")
print("------------+------------")
values, indices = torch.topk(torch.abs(d - d_ref) / torch.abs(d_ref), topk)
for index in indices:
print(f"{d_ref[index].item():12.5e} {d[index].item():12.5e}")
print()

torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=1e-3)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not SM90OrLater, reason="FP8 is only supported on H100+ devices")
@pytest.mark.parametrize(
"x_w_dtypes, xq_wq_dtypes, batch_size, size_mnk, use_bias, bias_dtype",
TEST_PARAMS,
)
def test_rowwise_scaled_linear_sparse_cutlass_f8f8(
x_w_dtypes,
xq_wq_dtypes,
batch_size,
size_mnk,
use_bias,
bias_dtype,
):
run_test_for_op(
rowwise_scaled_linear_sparse_cutlass_f8f8,
*x_w_dtypes,
*xq_wq_dtypes,
batch_size,
size_mnk,
use_bias,
bias_dtype,
)
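
For reference, a minimal end-to-end sketch of calling the new op outside pytest, mirroring run_test_for_op above (an SM90+ GPU is assumed; shapes and dtypes are arbitrary picks from the test matrix):

x = torch.randn(32, 128, dtype=torch.float16, device="cuda")
w = rand_sparse_semi_structured(64, 128, dtype=torch.float16, device="cuda")

# Row-wise fp8 quantization of activations and weights, as in the test.
x_aqt = to_affine_quantized_floatx(
    input_float=x, target_dtype=torch.float8_e5m2,
    block_size=[1, x.shape[-1]], _layout=Float8Layout(mm_config=None),
)
xq, xq_scales, _ = x_aqt.tensor_impl.get_plain()

w_aqt = to_affine_quantized_floatx(
    input_float=w, target_dtype=torch.float8_e4m3fn,
    block_size=[1, w.shape[-1]], _layout=Float8Layout(mm_config=None),
)
wq, wq_scales, _ = w_aqt.tensor_impl.get_plain()

# Pack the 2:4-sparse fp8 weight into values + metadata, then run the op.
wq_sp, wq_sp_meta = to_sparse_semi_structured_cutlass_sm9x_f8(wq)
output = rowwise_scaled_linear_sparse_cutlass_f8f8(
    xq, xq_scales, wq_sp, wq_sp_meta, wq_scales, None  # no bias
)  # expected shape: (32, 64)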
@@ -1,3 +1,4 @@
+#include <cutlass/cutlass.h>
 #include <torch/library.h>
 
 #include "rowwise_scaled_linear_cutlass.cuh"
@@ -1,3 +1,4 @@
+#include <cutlass/cutlass.h>
 #include <torch/library.h>
 
 #include "rowwise_scaled_linear_cutlass.cuh"