From e1c9bd0c1c34047e743cbeaf30b5bf0b649560fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?= Date: Tue, 4 Feb 2025 21:01:48 +0100 Subject: [PATCH] Add CUTLASS-based row-wise scaled sparse FP8 kernel --- setup.py | 2 + ...st_rowwise_scaled_linear_sparse_cutlass.py | 188 +++++++ .../rowwise_scaled_linear_cutlass_s4s4.cu | 1 + .../rowwise_scaled_linear_cutlass_s8s4.cu | 1 + .../rowwise_scaled_linear_sparse_cutlass.cuh | 481 ++++++++++++++++++ ...wwise_scaled_linear_sparse_cutlass_f8f8.cu | 42 ++ ...to_sparse_semi_structured_cutlass_sm9x.cuh | 161 ++++++ ..._sparse_semi_structured_cutlass_sm9x_f8.cu | 32 ++ torchao/ops.py | 78 ++- 9 files changed, 985 insertions(+), 1 deletion(-) create mode 100644 test/test_rowwise_scaled_linear_sparse_cutlass.py create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass.cuh create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_f8f8.cu create mode 100644 torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x.cuh create mode 100644 torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x_f8.cu diff --git a/setup.py b/setup.py index 67a8d2e576..508cba525b 100644 --- a/setup.py +++ b/setup.py @@ -256,6 +256,7 @@ def get_extensions(): if use_cuda and not IS_WINDOWS: use_cutlass = True cutlass_dir = os.path.join(third_party_path, "cutlass") + cutlass_util_include_dir = os.path.join(cutlass_dir, "tools", "util", "include") cutlass_include_dir = os.path.join(cutlass_dir, "include") cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir) if use_cutlass: @@ -263,6 +264,7 @@ def get_extensions(): [ "-DTORCHAO_USE_CUTLASS", "-I" + cutlass_include_dir, + "-I" + cutlass_util_include_dir, "-I" + cutlass_extensions_include_dir, ] ) diff --git a/test/test_rowwise_scaled_linear_sparse_cutlass.py b/test/test_rowwise_scaled_linear_sparse_cutlass.py new file mode 100644 index 0000000000..324a332eb0 --- /dev/null +++ b/test/test_rowwise_scaled_linear_sparse_cutlass.py @@ -0,0 +1,188 @@ +import itertools +import random + +import pytest +import torch +from torch.testing._internal.common_cuda import SM90OrLater + +from torchao.dtypes import ( + Float8Layout, + to_affine_quantized_floatx, +) +from torchao.ops import ( + rowwise_scaled_linear_sparse_cutlass_f8f8, + to_sparse_semi_structured_cutlass_sm9x_f8, +) + + +X_W_DTYPES = [(torch.float16, torch.float16), (torch.bfloat16, torch.bfloat16)] +XQ_WQ_DTYPES = [ + (torch.float8_e5m2, torch.float8_e4m3fn), + (torch.float8_e4m3fn, torch.float8_e4m3fn), +] +BATCH_SIZE = [1, 4, 8, 16, 32, 64] +SIZE_MNK = [ + (2, 512, 128), + (3, 2048, 2048), + (4, 3584, 640), + (13, 8704, 8576), + (26, 18944, 1664), + (67, 6656, 1408), +] +USE_BIAS = [False, True] +BIAS_DTYPE = [torch.float16] +TEST_PARAMS = list( + itertools.product( + X_W_DTYPES, + XQ_WQ_DTYPES, + BATCH_SIZE, + SIZE_MNK, + USE_BIAS, + BIAS_DTYPE, + ) +) + + +# FIXME: remove this! 
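+# The overrides below shrink TEST_PARAMS to a single configuration so a local
+# debug run exercises just one case; they shadow the full parameter grid defined
+# above and should be deleted (together with this FIXME) before merging.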
+X_W_DTYPES = [(torch.float16, torch.float16)] +XQ_WQ_DTYPES = [(torch.float8_e5m2, torch.float8_e4m3fn)] +BATCH_SIZE = [1] +SIZE_MNK = [(32, 64, 128)] +USE_BIAS = [True] +BIAS_DTYPE = [torch.float16] +TEST_PARAMS = list( + itertools.product( + X_W_DTYPES, + XQ_WQ_DTYPES, + BATCH_SIZE, + SIZE_MNK, + USE_BIAS, + BIAS_DTYPE, + ) +) + + +def rand_sparse_semi_structured(r, c, dtype, device, choice=None): + pattern = "2by4" if dtype != torch.float32 else "1by2" + if pattern == "1by2": + ksparse = 2 + choices = [[0, 1], [1, 0]] + elif pattern == "2by4": + ksparse = 4 + choices = [ + [1, 1, 0, 0], + [1, 0, 1, 0], + [1, 0, 0, 1], + [0, 1, 1, 0], + [0, 1, 0, 1], + [0, 0, 1, 1], + ] + assert c % ksparse == 0 + mask_entries = [choice or random.choice(choices) for i in range(r * c // ksparse)] + mask = torch.tensor(mask_entries, dtype=torch.bool).view(r, c).to(device) + dense = torch.randn(r, c, dtype=dtype, device=device) + dense[dense == 0] = 1 # To prevent zeros except where mask applied. + dense = dense.masked_fill(~mask, 0) + return dense + + +def run_test_for_op( + op, + x_dtype, + w_dtype, + xq_dtype, + wq_dtype, + batch_size, + size_mnk, + use_bias, + bias_dtype, +): + size_m, size_n, size_k = size_mnk + + x = torch.randn((batch_size, size_m, size_k), dtype=x_dtype, device="cuda") + w = rand_sparse_semi_structured(size_n, size_k, dtype=w_dtype, device="cuda") + bias = torch.rand((size_n,), dtype=bias_dtype, device="cuda") if use_bias else None + + block_size = [1] * (x.dim() - 1) + [x.shape[-1]] + x_aqt = to_affine_quantized_floatx( + input_float=x, + target_dtype=xq_dtype, + block_size=block_size, + _layout=Float8Layout(mm_config=None), + ) + xq, xq_scales, zero_points = x_aqt.tensor_impl.get_plain() + assert zero_points is None + + block_size = [1] * (w.dim() - 1) + [w.shape[-1]] + w_aqt = to_affine_quantized_floatx( + input_float=w, + target_dtype=wq_dtype, + block_size=block_size, + _layout=Float8Layout(mm_config=None), + ) + wq, wq_scales, zero_points = w_aqt.tensor_impl.get_plain() + assert zero_points is None + wq_sp, wq_sp_meta = to_sparse_semi_structured_cutlass_sm9x_f8(wq) + wq_sp_scales = wq_scales + + xq_2d = xq.view(-1, xq.shape[-1]) + size_m_2d = xq_2d.shape[0] + output_ref = ( + (xq_2d.float() @ wq.float().T) + * xq_scales.view(size_m_2d, 1) + * wq_scales.view(1, size_n) + ) + if bias is not None: + output_ref += bias + output_ref = output_ref.to(x.dtype).reshape(x.shape[:-1] + (size_n,)) + + fn_inputs = (xq, xq_scales, wq_sp, wq_sp_meta, wq_sp_scales, bias) + try: + output = op(*fn_inputs) + except NotImplementedError: + pytest.xfail("operator not implemented") + + # FIXME: remove this! + d_ref = output_ref + d = output + print( + f"Sum of relative errors d vs. d_ref : {torch.sum(torch.abs(d - d_ref) / torch.abs(d_ref)).item():8.2f}" + ) + print() + d_ref = d_ref.flatten().to(torch.float32) + d = d.flatten().to(torch.float32) + topk = 10 + print(f"Top {topk} relative errors d vs. 
d_ref :") + print(" d_ref d") + print("------------+------------") + values, indices = torch.topk(torch.abs(d - d_ref) / torch.abs(d_ref), topk) + for index in indices: + print(f"{d_ref[index].item():12.5e} {d[index].item():12.5e}") + print() + + torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=1e-3) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not SM90OrLater, reason="FP8 is only supported on H100+ devices") +@pytest.mark.parametrize( + "x_w_dtypes, xq_wq_dtypes, batch_size, size_mnk, use_bias, bias_dtype", + TEST_PARAMS, +) +def test_rowwise_scaled_liner_sparse_cutlass_f8f8( + x_w_dtypes, + xq_wq_dtypes, + batch_size, + size_mnk, + use_bias, + bias_dtype, +): + run_test_for_op( + rowwise_scaled_linear_sparse_cutlass_f8f8, + *x_w_dtypes, + *xq_wq_dtypes, + batch_size, + size_mnk, + use_bias, + bias_dtype, + ) diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu index e455b7bdf2..09a4a7d7fe 100644 --- a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu @@ -1,3 +1,4 @@ +#include #include #include "rowwise_scaled_linear_cutlass.cuh" diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu index 680822ca7f..fc1b2951c7 100644 --- a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu @@ -1,3 +1,4 @@ +#include #include #include "rowwise_scaled_linear_cutlass.cuh" diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass.cuh b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass.cuh new file mode 100644 index 0000000000..a8afc8f33b --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass.cuh @@ -0,0 +1,481 @@ +#pragma once + +#include +#include +#include +#include + +#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ + defined(CUDA_VERSION) && (CUDA_VERSION >= 12020) +#define BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS +#endif + +#if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass_extensions/common.h" +#endif + +#define OPERATOR_NAME "rowwise_scaled_linear_sparse_cutlass" + +namespace torchao { + +#if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) +template< + typename DtypeXq, + typename DtypeWq, + typename DtypeY, + typename UseBias, + typename DtypeBias, + typename DtypeXScale, + typename DtypeWScale, + typename TileShape, + typename ClusterShape> +void rowwise_scaled_linear_sparse_kernel_cutlass_sm9x( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, const at::Tensor& bias, + at::Tensor& Y) { + // For CUTLASS, sparsified tensor must be the first operand, thus + // the result will be calculated as: + // ((Wq @ Xq.T) * W_scale * X_scale.T + bias.T).T + + using SmArch = cutlass::arch::Sm90; + + // Use CUTLASS naming conventions for naming datatypes. 
+ using ElementA = DtypeWq; + using ElementB = DtypeXq; + using ElementD = DtypeY; + using ElementAScale = DtypeWScale; + using ElementBScale = DtypeXScale; + using ElementBias = DtypeBias; + + using LayoutTagA = cutlass::layout::RowMajor; + using LayoutTagB = cutlass::layout::ColumnMajor; + using LayoutTagD = cutlass::layout::ColumnMajor; + + constexpr auto AlignmentA = 128 / cutlass::sizeof_bits::value; + constexpr auto AlignmentB = 128 / cutlass::sizeof_bits::value; + constexpr auto AlignmentD = 128 / cutlass::sizeof_bits::value; + + // TODO: use different accumulator datatype if inputs are not float. + using ElementAccumulator = float; + using ElementCompute = float; + + using ProblemShape = cute::Shape; + + // TODO: consider KernelTmaWarpSpecializedPingpongFP8FastAccum. + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + + constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + using AScale = + cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementAScale>; + using ApplyAScale = cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>, + Accum, + AScale>; + using BScale = + cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementBScale>; + using ApplyBScale = cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>, + ApplyAScale, + BScale>; + using BiasScalar = + cutlass::epilogue::fusion::Sm90ScalarBroadcast; + using BiasTensor = + cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementBias>; + using Bias = std::conditional_t; + using ApplyBias = cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, ElementCompute, ElementCompute, RoundStyle>, + ApplyBScale, + Bias>; + using EVT = ApplyBias; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + SmArch, cutlass::arch::OpClassSparseTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementD, LayoutTagD, AlignmentD, + ElementD, LayoutTagD, AlignmentD, + EpilogueSchedule, + EVT>::CollectiveOp; + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + SmArch, cutlass::arch::OpClassSparseTensorOp, + ElementA, LayoutTagA, AlignmentA, + ElementB, LayoutTagB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + using GemmKernel = enable_3x_kernel_for_sm90_or_later< + cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue>>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideD = typename Gemm::GemmKernel::StrideD; + using StrideE = StrideA; + using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE; + using SparseConfig = + typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + + const int m = Wq.size(0); + const int n = Xq.size(0); + const int k = Xq.size(1); + + // FIXME: validate these checks. 
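+  // Each Alignment* above is expressed in elements (128 bits divided by the
+  // element width), i.e. 16 for the 8-bit A/B operands and 8 for a 16-bit
+  // output; the disabled checks below would enforce divisibility of k by
+  // AlignmentA/AlignmentB and of n by AlignmentD.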
+ /* + // Check for current CUTLASS limitations w.r.t. alignments. + TORCH_CHECK(k % AlignmentA == 0, OPERATOR_NAME, + " : Number of columns of tensor A must be divisible by ", + AlignmentA); + TORCH_CHECK(k % AlignmentB == 0, OPERATOR_NAME, + " : Number of columns of tensor B must be divisible by ", + AlignmentB); + TORCH_CHECK(n % AlignmentD == 0, OPERATOR_NAME, + " : Number of columns of tensor Y must be divisible by ", + AlignmentD); + */ + + ProblemShape problem_shape(m, n, k, 1); + const auto layout_A = SparseConfig::fill_layoutA(problem_shape); + const auto layout_E = SparseConfig::fill_layoutE(problem_shape); + const auto stride_B = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + const auto stride_D = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_shape, + { + (ElementA*)Wq.data_ptr(), layout_A, (ElementB*)Xq.data_ptr(), stride_B, + (ElementE*)W_meta.data_ptr(), layout_E + }, + { + {}, + (ElementD*)Y.data_ptr(), stride_D, (ElementD*)Y.data_ptr(), stride_D + } + }; + + const typename AScale::Arguments A_scale_arguments{ + (ElementAScale*)W_scale.data_ptr(), + ElementAScale(1), + {cute::_1{}, cute::_0{}, cute::_0{}} + }; + const typename BScale::Arguments B_scale_arguments{ + (ElementBScale*)X_scale.data_ptr(), + ElementBScale(1), + {cute::_0{}, cute::_1{}, cute::_0{}} + }; + const auto bias_arguments{ + [&]() -> typename Bias::Arguments { + if constexpr (UseBias::value) { + return { + (ElementBias*)bias.data_ptr(), + ElementBias(0), + {cute::_1{}, cute::_0{}, cute::_0{}} + }; + } else { + return {ElementBias(0)}; + } + }() + }; + arguments.epilogue.thread = { + { + { + {}, // Accum + A_scale_arguments, // AScale + {} // ApplyAScale + }, + B_scale_arguments, // TensorBScale + {}, // ApplyBScale + }, + bias_arguments, // Bias + {} // ApplyBiass + }; + + Gemm gemm_op; + + cutlass::Status status; + + // Verify that GEMM operation with given arguments can be performed + // by CUTLASS. + status = gemm_op.can_implement(arguments); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); + + // Allocate workspace for CUTLASS mixed datatypes GEMM kernel. + const auto workspace_size = Gemm::get_workspace_size(arguments); + auto workspace = Xq.new_empty({(int64_t)workspace_size}, + at::TensorOptions().dtype(at::kByte)); + + // Initialize CUTLASS mixed datatypes GEMM object. + status = gemm_op.initialize(arguments, workspace.data_ptr(), + at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); + + // Perform mixed datatypes GEMM operation. 
+ status = gemm_op.run(at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +static void select_config( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, const at::Tensor& bias, + at::Tensor& Y) { + const auto dprops = at::cuda::getCurrentDeviceProperties(); + const auto is_sm9x = dprops->major == 9; + + if (is_sm9x) { + if constexpr ((std::is_same::value && + std::is_same::value) || + (std::is_same::value && + std::is_same::value)) { + // TODO: add some tuning + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + rowwise_scaled_linear_sparse_kernel_cutlass_sm9x< + DtypeXq, DtypeWq, Types..., TileShape, ClusterShape>( + Xq, X_scale, Wq, W_meta, W_scale, bias, Y); + return; + } + } + + TORCH_CHECK(false, OPERATOR_NAME, + " : Operator not supported on SM", dprops->major, ".", + dprops->minor, " for given operands"); +} + +template +static void +dispatch_on_bias( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, const at::Tensor& bias, + at::Tensor& Y) { + if (bias.numel() == 0) { + using UseBias = std::false_type; + using DtypeBias = DtypeY; + select_config( + Xq, X_scale, Wq, W_meta, W_scale, bias, Y); + return; + } + + using UseBias = std::true_type; + if (bias.scalar_type() == at::ScalarType::Half) { + using DtypeBias = cutlass::half_t; + select_config( + Xq, X_scale, Wq, W_meta, W_scale, bias, Y); + return; + } else if (bias.scalar_type() == at::ScalarType::BFloat16) { + using DtypeBias = cutlass::bfloat16_t; + select_config( + Xq, X_scale, Wq, W_meta, W_scale, bias, Y); + return; + } + + TORCH_CHECK(false, OPERATOR_NAME, + " : Operator not supported for datatype ", bias.scalar_type(), + " for bias"); +} + +template + static void +dispatch_on_X_scale_and_W_scale( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, const at::Tensor& bias, + at::Tensor& Y) { + TORCH_CHECK(Y.scalar_type() == X_scale.scalar_type(), + OPERATOR_NAME, " : Operator not supported for Y datatype ", + Y.scalar_type(), " as it's different from the first ", + " operand scale datatype ", X_scale.scalar_type()); + + if (X_scale.scalar_type() == at::ScalarType::Half && + W_scale.scalar_type() == at::ScalarType::Half) { + using DtypeXScale = cutlass::half_t; + using DtypeWScale = cutlass::half_t; + using DtypeY = cutlass::half_t; + dispatch_on_bias(Xq, X_scale, Wq, W_meta, W_scale, bias, Y); + return; + } else if (X_scale.scalar_type() == at::ScalarType::BFloat16 && + W_scale.scalar_type() == at::ScalarType::BFloat16) { + using DtypeXScale = cutlass::bfloat16_t; + using DtypeWScale = cutlass::bfloat16_t; + using DtypeY = cutlass::bfloat16_t; + dispatch_on_bias(Xq, X_scale, Wq, W_meta, W_scale, bias, Y); + return; + } + + TORCH_CHECK(false, OPERATOR_NAME, + " : Operator not supported for combination of datatypes ", + X_scale.scalar_type(), " for first operand scale and ", + W_scale.scalar_type(), " for second operand scale"); +} + +template +void +rowwise_scaled_linear_sparse_cutlass_check_inputs( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const at::Tensor& bias) { + // Validate metadata datatype. 
+ TORCH_CHECK(W_meta.dtype() == at::kByte, OPERATOR_NAME, + " : Expected Wq meta argument to be of torch.uint8 datatype got ", + Wq.dtype()); + + // Validate layouts of arguments. + TORCH_CHECK(Xq.dim() >= 2, OPERATOR_NAME, + " : Expected Xq argument to be 2D or higher-dimensional tensor, " + " got ", Xq.dim(), " dims"); + TORCH_CHECK(Xq.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Xq argument to be strided, got layout ", + Xq.layout()); + TORCH_CHECK(X_scale.dim() == Xq.dim() - 1, OPERATOR_NAME, + " : Expected Xq scale argument to be ", Xq.dim() - 1, + "D tensor, got ", X_scale.dim(), " dims"); + TORCH_CHECK(X_scale.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Xq scale argument to be strided, got layout ", + X_scale.layout()); + TORCH_CHECK(Wq.dim() == 2, OPERATOR_NAME, + " : Expected Wq argument to be 2D tensor, got ", Wq.dim(), + " dims"); + TORCH_CHECK(Wq.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Wq argument to be strided, got layout ", + Wq.layout()); + TORCH_CHECK(W_meta.dim() == 2, OPERATOR_NAME, + " : Expected Wq meta argument to be 2D tensor, got ", + W_meta.dim(), " dims"); + TORCH_CHECK(W_meta.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Wq meta argument to be strided, got layout ", + W_meta.layout()); + TORCH_CHECK(W_scale.dim() == 1 || W_scale.dim() == 2, OPERATOR_NAME, + " : Expected Wq scale argument to be 1D or 2D tensor, ", + "got ", W_scale.dim(), " dims"); + TORCH_CHECK(W_scale.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected Wq scale argument to be strided, got layout ", + W_scale.layout()); + if (bias.numel() > 0) { + TORCH_CHECK(bias.dim() == 1, OPERATOR_NAME, + " : Expected bias argument to be 1D tensor, got ", bias.dim(), + " dims"); + TORCH_CHECK(bias.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected bias argument to be strided, got layout ", + bias.layout()); + } + + // Validate sizes of arguments. + const auto Xq_sizes = Xq.sizes().vec(); + TORCH_CHECK(Xq_sizes.back() == 2 * Wq.size(1), OPERATOR_NAME, + " : Expected Xq argument to have ", 2 * Wq.size(1), + " columns, but got ", Xq_sizes.back()); + const auto X_scale_sizes = X_scale.sizes().vec(); + for (auto i = 0; i < X_scale_sizes.size(); ++i) + TORCH_CHECK(X_scale_sizes[i] == Xq_sizes[i], OPERATOR_NAME, + " : Expected Xq scale argument size at position ", i, " to be ", + Xq_sizes[i], ", but got ", X_scale_sizes[i]); + TORCH_CHECK(Wq.size(1) % 8 == 0, OPERATOR_NAME, + " : Expected Wq argument to have number of columns divisible by ", + " 8, got ", Wq.size(1)); + TORCH_CHECK(W_meta.size(0) == Wq.size(0), OPERATOR_NAME, + " : Expected Wq meta argument to have ", Wq.size(0), + " rows, got ", W_meta.numel(), " rows"); + TORCH_CHECK(W_meta.size(1) == Wq.size(1) / 4, OPERATOR_NAME, + " : Expected Wq meta argument to hold ", Wq.size(1) / 4, + " bytes per row to encode sparsity of Wq argument, got ", + W_meta.size(1), " bytes"); + TORCH_CHECK(W_scale.numel() == Wq.size(0), OPERATOR_NAME, + " : Expected Wq scale argument to have ", Wq.size(0), + " elements, got ", W_scale.numel(), " elements"); + if (bias.numel() > 0) { + TORCH_CHECK(bias.numel() == Wq.size(0), OPERATOR_NAME, + " : Expected bias argument to have ", Wq.size(0), + " elements, got ", bias.numel(), " elements"); + } + + // Validate strides of arguments. 
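+  // The loop below walks any leading (batch) dimensions of Xq and verifies
+  // that they are packed contiguously on top of the row-major innermost two
+  // dimensions, i.e. that Xq as a whole is laid out row-major.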
+ const auto Xq_strides = Xq.strides(); + TORCH_CHECK(Xq_strides[Xq_strides.size() - 1] == 1, OPERATOR_NAME, + " : Expected Xq argument in row-major layout"); + auto Xq_stride_expected = Xq_strides[Xq_strides.size() - 2]; + for (int i = Xq_strides.size() - 3; i >= 0; --i) { + Xq_stride_expected *= Xq_sizes[i + 1]; + TORCH_CHECK(Xq_strides[i] == Xq_stride_expected, OPERATOR_NAME, + " : Expected Xq argument in row-major layout"); + } + TORCH_CHECK(X_scale.is_contiguous(), OPERATOR_NAME, + " : Expected Xq scale argument to be contiguous"); + const auto Wq_strides = Wq.strides(); + TORCH_CHECK(Wq_strides[0] >= 1 && Wq_strides[1] == 1, OPERATOR_NAME, + " : Expected Wq argument in row-major layout"); + const auto W_meta_strides = W_meta.strides(); + TORCH_CHECK(W_meta_strides[0] >= 1 && W_meta_strides[1] == 1, OPERATOR_NAME, + " : Expected Wq meta argument in row-major layout"); + TORCH_CHECK(W_scale.is_contiguous(), OPERATOR_NAME, + " : Expected Wq scale argument to be contiguous"); + if (bias.numel() > 0) { + const auto bias_strides = bias.strides(); + TORCH_CHECK(bias_strides[0] == 1, OPERATOR_NAME, + " : Expected bias argument to be contiguous"); + } +} +#endif + +template +at::Tensor +rowwise_scaled_linear_sparse_cutlass( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const at::Tensor& bias) { +#if defined(BUILD_ROWWISE_SCALED_LINEAR_SPARSE_CUTLASS) + // Check inputs. + rowwise_scaled_linear_sparse_cutlass_check_inputs( + Xq, X_scale, Wq, W_meta, W_scale, bias); + + // Squash the input tensors as appropriate. + const auto Xq_sizes = Xq.sizes().vec(); + const auto Xq_2d = Xq.reshape({-1, Xq_sizes.back()}); + const auto X_scale_1d = X_scale.reshape({-1}); + const auto W_scale_1d = W_scale.reshape({-1}); + + // Create result tensor. + at::Tensor Y = X_scale.new_empty({Xq_2d.size(0), Wq.size(0)}); + + // Dispatch to appropriate kernel template. + dispatch_on_X_scale_and_W_scale( + Xq_2d, X_scale_1d, Wq, W_meta, W_scale_1d, bias, Y); + + // Reshape and return Y tensor. + auto Y_sizes = Xq_sizes; + Y_sizes.back() = Wq.size(0); + return Y.reshape(Y_sizes); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); + return at::Tensor{}; +#endif +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_f8f8.cu b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_f8f8.cu new file mode 100644 index 0000000000..00cfa9a48f --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_f8f8.cu @@ -0,0 +1,42 @@ +#include +#include + +#include "rowwise_scaled_linear_sparse_cutlass.cuh" + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_sparse_cutlass_f8f8( + const at::Tensor& Xq, const at::Tensor& X_scale, const at::Tensor& Wq, + const at::Tensor& W_meta, const at::Tensor& W_scale, + const at::Tensor& bias) { + // Validate input datatypes. + TORCH_CHECK( + (Xq.dtype() == at::kFloat8_e5m2 && Wq.dtype() == at::kFloat8_e4m3fn) || + (Xq.dtype() == at::kFloat8_e4m3fn && Wq.dtype() == at::kFloat8_e4m3fn), + __func__, " : The input datatypes combination ", Xq.dtype(), + " for Xq and ", Wq.dtype(), " for Wq is not supported"); + + // Dispatch to appropriate kernel template. 
+ if (Xq.dtype() == at::kFloat8_e5m2 && Wq.dtype() == at::kFloat8_e4m3fn) { + using DtypeXq = cutlass::float_e5m2_t; + using DtypeWq = cutlass::float_e4m3_t; + return rowwise_scaled_linear_sparse_cutlass( + Xq, X_scale, Wq, W_meta, W_scale, bias); + } else if (Xq.dtype() == at::kFloat8_e4m3fn && + Wq.dtype() == at::kFloat8_e4m3fn) { + using DtypeXq = cutlass::float_e4m3_t; + using DtypeWq = cutlass::float_e4m3_t; + return rowwise_scaled_linear_sparse_cutlass( + Xq, X_scale, Wq, W_meta, W_scale, bias); + } + + return at::Tensor{}; +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::rowwise_scaled_linear_sparse_cutlass_f8f8", + &rowwise_scaled_linear_sparse_cutlass_f8f8); +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x.cuh b/torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x.cuh new file mode 100644 index 0000000000..b38dd871e2 --- /dev/null +++ b/torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x.cuh @@ -0,0 +1,161 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ + defined(CUDA_VERSION) && (CUDA_VERSION >= 12020) +#define BUILD_TO_SPARSE_SEMI_STRUCTURED_CUTLASS_SM9X +#endif + +#if defined(BUILD_TO_SPARSE_SEMI_STRUCTURED_CUTLASS_SM9X) +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass_extensions/common.h" +#endif + +#define OPERATOR_NAME "to_sparse_semi_structured_cutlass_sm9x" + +namespace torchao { + +#if defined(BUILD_TO_SPARSE_SEMI_STRUCTURED_CUTLASS_SM9X) +template +std::tuple +to_sparse_semi_structured_kernel_cutlass_sm9x(const at::Tensor& W) { + // The kernel doesn't check, but assumes instead, that the input + // tensor is a structured sparse tensor. + + static_assert(std::is_same_v || + std::is_same_v); + + using SmArch = cutlass::arch::Sm90; + + using ProblemShape = cute::Shape; + + using LayoutTagW = cutlass::layout::RowMajor; + using StrideW = cutlass::gemm::TagToStrideA_t; + + using DtypeMeta = unsigned char; + using SparseConfig = cutlass::Sm90GemmSparseConfig< + cute::sparse_elem<2, DtypeW>, + cute::GMMA::Major::K, + cute::sparse_elem<8, unsigned char>, + cute::_128>; + + using CompressorUtility = + cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, DtypeW, LayoutTagW, SparseConfig>; + using CompressorKernel = + cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, DtypeW, LayoutTagW, SparseConfig, SmArch>; + using Compressor = + cutlass::transform::device::TransformUniversalAdapter; + + const int m = W.size(0); + const int k = W.size(1); + + // FIXME: check the 64 number! + ProblemShape problem_shape(m, 1, k, 1); + + StrideW stride_W = + cutlass::make_cute_packed_stride(StrideW{}, cute::make_shape(m, k, 1)); + CompressorUtility compressor_utility(problem_shape, stride_W); + int k_compressed = compressor_utility.get_tensorA_k_physical(); + int m_meta = compressor_utility.get_metadata_m_physical(); + int k_meta = compressor_utility.get_metadata_k_physical(); + + // Create result tensors. 
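+  // W_compressed keeps W's dtype and holds the surviving values per row
+  // (nominally k/2 for the 2:4 pattern configured via sparse_elem<2, DtypeW>),
+  // while W_meta is a uint8 tensor whose physical shape (m_meta, k_meta) is
+  // reported by the CompressorUtility above.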
+ at::Tensor W_compressed = W.new_empty({m, k_compressed}); + at::Tensor W_meta = + W.new_empty({m_meta, k_meta}, at::TensorOptions().dtype(at::kByte)); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + typename Compressor::Arguments arguments{ + problem_shape, + { + (DtypeW*)W.data_ptr(), stride_W, (DtypeW*)W_compressed.data_ptr(), + (DtypeMeta*)W_meta.data_ptr() + }, + {hw_info}}; + + Compressor compressor_op; + + cutlass::Status status; + + // Verify that compression operation with given arguments can be + // performed by CUTLASS. + status = compressor_op.can_implement(arguments); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); + + // Allocate workspace for the compressor. + const auto workspace_size = Compressor::get_workspace_size(arguments); + auto workspace = W.new_empty({(int64_t)workspace_size}, + at::TensorOptions().dtype(at::kByte)); + + // Initialize compressor. + status = compressor_op.initialize(arguments, workspace.data_ptr(), + at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); + + // Perform compression. + status = compressor_op.run(at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return std::make_tuple(W_compressed, W_meta); +} + +template +void +to_sparse_semi_structured_cutlass_sm9x_check_inputs(const at::Tensor& W) { + // Validate the input tensor layout. + TORCH_CHECK(W.dim() == 2, OPERATOR_NAME, + " : Expected W argument to be 2D tensor, got ", W.dim(), + " dims"); + TORCH_CHECK(W.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected W argument to be strided, got layout ",W.layout()); + + // Validate the input tensor shape. + const auto W_sizes = W.sizes().vec(); + TORCH_CHECK(W_sizes[1] % 8 == 0, OPERATOR_NAME, + " : Expected number of columns of the W argument to be divisible", + "by 8, got ", W_sizes[1], " columns"); + + // Validate the input tensor strides. + const auto W_strides = W.strides(); + TORCH_CHECK(W_strides[1] == 1, OPERATOR_NAME, + " : Expected W argument in row-major layout"); +} +#endif + +template +std::tuple +to_sparse_semi_structured_cutlass_sm9x(const at::Tensor& W) { +#if defined(BUILD_TO_SPARSE_SEMI_STRUCTURED_CUTLASS_SM9X) + // Check inputs. + to_sparse_semi_structured_cutlass_sm9x_check_inputs(W); + + // Call the kernel. + return to_sparse_semi_structured_kernel_cutlass_sm9x(W); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); + return std::make_tuple(at::Tensor{}, at::Tensor{}); +#endif +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x_f8.cu b/torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x_f8.cu new file mode 100644 index 0000000000..6154f5d9cf --- /dev/null +++ b/torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x/to_sparse_semi_structured_cutlass_sm9x_f8.cu @@ -0,0 +1,32 @@ +#include +#include + +#include "to_sparse_semi_structured_cutlass_sm9x.cuh" + +namespace torchao { + +std::tuple +to_sparse_semi_structured_cutlass_sm9x_f8(const at::Tensor& W) { + // Validate input datatypes. + TORCH_CHECK(W.dtype() == at::kFloat8_e5m2 || W.dtype() == at::kFloat8_e4m3fn, + __func__, " : The input datatype ", W.dtype(), + " is not supported"); + + // Dispatch to appropriate kernel template. 
+ if (W.dtype() == at::kFloat8_e5m2) { + using DtypeW = cutlass::float_e5m2_t; + return to_sparse_semi_structured_cutlass_sm9x(W); + } else if (W.dtype() == at::kFloat8_e4m3fn) { + using DtypeW = cutlass::float_e4m3_t; + return to_sparse_semi_structured_cutlass_sm9x(W); + } + + return std::tuple(at::Tensor{}, at::Tensor{}); +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::to_sparse_semi_structured_cutlass_sm9x_f8", + &to_sparse_semi_structured_cutlass_sm9x_f8); +} + +} // namespace torchao diff --git a/torchao/ops.py b/torchao/ops.py index 8b573876f2..97692a5810 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -25,7 +25,12 @@ lib.define( "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" ) - +lib.define( + "rowwise_scaled_linear_sparse_cutlass_f8f8(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_meta, Tensor weight_scale, Tensor bias) -> Tensor" +) +lib.define( + "to_sparse_semi_structured_cutlass_sm9x_f8(Tensor weight) -> (Tensor, Tensor)" +) def register_custom_op(name): def decorator(func): @@ -590,5 +595,76 @@ def _( weight: Tensor, weight_scale: Tensor, bias: Tensor, +) -> Tensor: + # No checks here, as detailed checks are performed by the + # operator itself. + + return input_scale.new_empty(*input.shape[:-1], weight.shape[0]) + + +def rowwise_scaled_linear_sparse_cutlass_f8f8( + input: Tensor, + input_scale: Tensor, + weight: Tensor, + weight_meta: Tensor, + weight_scale: Tensor, + bias: Tensor, +) -> Tensor: + """ + CUTLASS-based row-wise scaled F8F8 linear operator, for sparsified weight case. + Args: + input: quantized input tensor, in row-major layout. + input_scale: scale factors for input tensor, has to be tensor of the same shape as the input tensor, minus the last dimension. + weight: sparsified quantized weight matrix, in row-major layout. + weight_meta: sparsify metadata for weight tensor. + weight_scale: scale factors for weight tensor, one value per row of weight matrix (thus also tensor of the same shape as the weight tensor, minus the last dimension). + bias: a vector of size equal to number of rows of weight tensor, or None. + Returns: + output: result tensor, in row-major layout. + """ + + return torch.ops.torchao.rowwise_scaled_linear_sparse_cutlass_f8f8.default( + input, input_scale, weight, weight_meta, weight_scale, bias + ) + + +@register_custom_op("torchao::rowwise_scaled_linear_sparse_cutlass_f8f8") +def _( + input: Tensor, + input_scale: Tensor, + weight: Tensor, + weight_meta: Tensor, + weight_scale: Tensor, + bias: Tensor, ) -> Tensor: return input_scale.new_empty(*input.shape[:-1], weight.shape[0]) + + +def to_sparse_semi_structured_cutlass_sm9x_f8( + weight: Tensor, +) -> (Tensor, Tensor): + """ + CUTLASS-based converted from sparsified input tensor to corresponding compressed tensor, along with corresponding metadata tensor. + Args: + weight: input tensor, in row-major layout. + Returns: + weight_compressed: compressed weight tensor, with sparsity eliminated, in row-major layout. + weight_meta: metadata tensor, describing the sparsity structure of the input tensor, also in row-major layout. + """ + + return torch.ops.torchao.to_sparse_semi_structured_cutlass_sm9x_f8.default( + weight + ) + + +@register_custom_op("torchao::to_sparse_semi_structured_cutlass_sm9x_f8") +def _( + weight: Tensor, +) -> (Tensor, Tensor): + # No checks here, as detailed checks are performed by the + # operator itself. 
+
+    return (
+        weight.new_empty(weight.shape[0], weight.shape[1] // 2),
+        # The kernel emits the metadata tensor as uint8 (at::kByte).
+        weight.new_empty(weight.shape[0], weight.shape[1] // 8, dtype=torch.uint8),
+    )
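Usage sketch (not part of the patch): the two new operators compose as in the
test above -- quantize the activation and the 2:4-sparse weight to FP8 with
row-wise scales, compress the weight once, then run the sparse linear op. The
helper name and the tensors `x`, `w` and `bias` below are illustrative; `w` is
assumed to already be in 2:4 semi-structured sparse form on a CUDA device.

    import torch
    from torchao.dtypes import Float8Layout, to_affine_quantized_floatx
    from torchao.ops import (
        rowwise_scaled_linear_sparse_cutlass_f8f8,
        to_sparse_semi_structured_cutlass_sm9x_f8,
    )

    def quantize_rowwise_fp8(t, dtype):
        # Row-wise scales: one scale per row (the block spans the last dim).
        block_size = [1] * (t.dim() - 1) + [t.shape[-1]]
        aqt = to_affine_quantized_floatx(
            input_float=t,
            target_dtype=dtype,
            block_size=block_size,
            _layout=Float8Layout(mm_config=None),
        )
        q, scales, _ = aqt.tensor_impl.get_plain()
        return q, scales

    xq, x_scales = quantize_rowwise_fp8(x, torch.float8_e4m3fn)
    wq, w_scales = quantize_rowwise_fp8(w, torch.float8_e4m3fn)
    # Compress the sparse weight once, reuse (wq_sp, wq_sp_meta) for all calls.
    wq_sp, wq_sp_meta = to_sparse_semi_structured_cutlass_sm9x_f8(wq)
    y = rowwise_scaled_linear_sparse_cutlass_f8f8(
        xq, x_scales, wq_sp, wq_sp_meta, w_scales, bias
    )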