-
Notifications
You must be signed in to change notification settings - Fork 213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add CUTLASS-based row-wise scaled sparse FP8 kernel #1671
Draft
alexsamardzic
wants to merge
1
commit into
pytorch:main
Choose a base branch
from
alexsamardzic:rowwise-scaled-sparse-fp8-cutlass
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
import itertools | ||
import random | ||
|
||
import pytest | ||
import torch | ||
from torch.testing._internal.common_cuda import SM90OrLater | ||
|
||
from torchao.dtypes import ( | ||
Float8Layout, | ||
to_affine_quantized_floatx, | ||
) | ||
from torchao.ops import ( | ||
rowwise_scaled_linear_sparse_cutlass_f8f8, | ||
to_sparse_semi_structured_cutlass_sm9x_f8, | ||
) | ||
|
||
|
||
X_W_DTYPES = [(torch.float16, torch.float16), (torch.bfloat16, torch.bfloat16)] | ||
XQ_WQ_DTYPES = [ | ||
(torch.float8_e5m2, torch.float8_e4m3fn), | ||
(torch.float8_e4m3fn, torch.float8_e4m3fn), | ||
] | ||
BATCH_SIZE = [1, 4, 8, 16, 32, 64] | ||
SIZE_MNK = [ | ||
(2, 512, 128), | ||
(3, 2048, 2048), | ||
(4, 3584, 640), | ||
(13, 8704, 8576), | ||
(26, 18944, 1664), | ||
(67, 6656, 1408), | ||
] | ||
USE_BIAS = [False, True] | ||
BIAS_DTYPE = [torch.float16] | ||
TEST_PARAMS = list( | ||
itertools.product( | ||
X_W_DTYPES, | ||
XQ_WQ_DTYPES, | ||
BATCH_SIZE, | ||
SIZE_MNK, | ||
USE_BIAS, | ||
BIAS_DTYPE, | ||
) | ||
) | ||
|
||
|
||
# FIXME: remove this! | ||
X_W_DTYPES = [(torch.float16, torch.float16)] | ||
XQ_WQ_DTYPES = [(torch.float8_e5m2, torch.float8_e4m3fn)] | ||
BATCH_SIZE = [1] | ||
SIZE_MNK = [(32, 64, 128)] | ||
USE_BIAS = [True] | ||
BIAS_DTYPE = [torch.float16] | ||
TEST_PARAMS = list( | ||
itertools.product( | ||
X_W_DTYPES, | ||
XQ_WQ_DTYPES, | ||
BATCH_SIZE, | ||
SIZE_MNK, | ||
USE_BIAS, | ||
BIAS_DTYPE, | ||
) | ||
) | ||
|
||
|
||
def rand_sparse_semi_structured(r, c, dtype, device, choice=None): | ||
pattern = "2by4" if dtype != torch.float32 else "1by2" | ||
if pattern == "1by2": | ||
ksparse = 2 | ||
choices = [[0, 1], [1, 0]] | ||
elif pattern == "2by4": | ||
ksparse = 4 | ||
choices = [ | ||
[1, 1, 0, 0], | ||
[1, 0, 1, 0], | ||
[1, 0, 0, 1], | ||
[0, 1, 1, 0], | ||
[0, 1, 0, 1], | ||
[0, 0, 1, 1], | ||
] | ||
assert c % ksparse == 0 | ||
mask_entries = [choice or random.choice(choices) for i in range(r * c // ksparse)] | ||
mask = torch.tensor(mask_entries, dtype=torch.bool).view(r, c).to(device) | ||
dense = torch.randn(r, c, dtype=dtype, device=device) | ||
dense[dense == 0] = 1 # To prevent zeros except where mask applied. | ||
dense = dense.masked_fill(~mask, 0) | ||
return dense | ||
|
||
|
||
def run_test_for_op( | ||
op, | ||
x_dtype, | ||
w_dtype, | ||
xq_dtype, | ||
wq_dtype, | ||
batch_size, | ||
size_mnk, | ||
use_bias, | ||
bias_dtype, | ||
): | ||
size_m, size_n, size_k = size_mnk | ||
|
||
x = torch.randn((batch_size, size_m, size_k), dtype=x_dtype, device="cuda") | ||
w = rand_sparse_semi_structured(size_n, size_k, dtype=w_dtype, device="cuda") | ||
bias = torch.rand((size_n,), dtype=bias_dtype, device="cuda") if use_bias else None | ||
|
||
block_size = [1] * (x.dim() - 1) + [x.shape[-1]] | ||
x_aqt = to_affine_quantized_floatx( | ||
input_float=x, | ||
target_dtype=xq_dtype, | ||
block_size=block_size, | ||
_layout=Float8Layout(mm_config=None), | ||
) | ||
xq, xq_scales, zero_points = x_aqt.tensor_impl.get_plain() | ||
assert zero_points is None | ||
|
||
block_size = [1] * (w.dim() - 1) + [w.shape[-1]] | ||
w_aqt = to_affine_quantized_floatx( | ||
input_float=w, | ||
target_dtype=wq_dtype, | ||
block_size=block_size, | ||
_layout=Float8Layout(mm_config=None), | ||
) | ||
wq, wq_scales, zero_points = w_aqt.tensor_impl.get_plain() | ||
assert zero_points is None | ||
wq_sp, wq_sp_meta = to_sparse_semi_structured_cutlass_sm9x_f8(wq) | ||
wq_sp_scales = wq_scales | ||
|
||
xq_2d = xq.view(-1, xq.shape[-1]) | ||
size_m_2d = xq_2d.shape[0] | ||
output_ref = ( | ||
(xq_2d.float() @ wq.float().T) | ||
* xq_scales.view(size_m_2d, 1) | ||
* wq_scales.view(1, size_n) | ||
) | ||
if bias is not None: | ||
output_ref += bias | ||
output_ref = output_ref.to(x.dtype).reshape(x.shape[:-1] + (size_n,)) | ||
|
||
fn_inputs = (xq, xq_scales, wq_sp, wq_sp_meta, wq_sp_scales, bias) | ||
try: | ||
output = op(*fn_inputs) | ||
except NotImplementedError: | ||
pytest.xfail("operator not implemented") | ||
|
||
# FIXME: remove this! | ||
d_ref = output_ref | ||
d = output | ||
print( | ||
f"Sum of relative errors d vs. d_ref : {torch.sum(torch.abs(d - d_ref) / torch.abs(d_ref)).item():8.2f}" | ||
) | ||
print() | ||
d_ref = d_ref.flatten().to(torch.float32) | ||
d = d.flatten().to(torch.float32) | ||
topk = 10 | ||
print(f"Top {topk} relative errors d vs. d_ref :") | ||
print(" d_ref d") | ||
print("------------+------------") | ||
values, indices = torch.topk(torch.abs(d - d_ref) / torch.abs(d_ref), topk) | ||
for index in indices: | ||
print(f"{d_ref[index].item():12.5e} {d[index].item():12.5e}") | ||
print() | ||
|
||
torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=1e-3) | ||
|
||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") | ||
@pytest.mark.skipif(not SM90OrLater, reason="FP8 is only supported on H100+ devices") | ||
@pytest.mark.parametrize( | ||
"x_w_dtypes, xq_wq_dtypes, batch_size, size_mnk, use_bias, bias_dtype", | ||
TEST_PARAMS, | ||
) | ||
def test_rowwise_scaled_liner_sparse_cutlass_f8f8( | ||
x_w_dtypes, | ||
xq_wq_dtypes, | ||
batch_size, | ||
size_mnk, | ||
use_bias, | ||
bias_dtype, | ||
): | ||
run_test_for_op( | ||
rowwise_scaled_linear_sparse_cutlass_f8f8, | ||
*x_w_dtypes, | ||
*xq_wq_dtypes, | ||
batch_size, | ||
size_mnk, | ||
use_bias, | ||
bias_dtype, | ||
) |
1 change: 1 addition & 0 deletions
1
torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
#include <cutlass/cutlass.h> | ||
#include <torch/library.h> | ||
|
||
#include "rowwise_scaled_linear_cutlass.cuh" | ||
|
1 change: 1 addition & 0 deletions
1
torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
#include <cutlass/cutlass.h> | ||
#include <torch/library.h> | ||
|
||
#include "rowwise_scaled_linear_cutlass.cuh" | ||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI: I think we have a variant of this in here:
ao/torchao/sparsity/utils.py
Line 26 in cc6244c