Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update bf16i4 gemm with new cutlass version #3630

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,8 +972,8 @@ def _int4_row_quantize(

# Cutlass expects column major layout for scale and zero point,
# so we transpose here and make them contiguous.
scales = scales.view(x.shape[0], -1).t().contiguous()
zeros = zeros.view(x.shape[0], -1).t().contiguous()
scales = scales.view(x.shape[0], -1)
zeros = zeros.view(x.shape[0], -1)

return out, scales, zeros

Expand Down Expand Up @@ -1030,7 +1030,12 @@ def quantize(self, x, w):
wq, w_scale, w_zp = self._int4_row_quantize(w)
# Pack int4 values together.
wq = self._pack_int4(wq)
return x.to(torch.bfloat16), wq, w_scale, w_zp
return (
x.to(torch.bfloat16),
wq,
w_scale.to(torch.bfloat16),
w_zp.to(torch.bfloat16),
)

def compute(self, x, wq, w_scale, w_zp):
return torch.ops.fbgemm.bf16i4bf16_rowwise(x, wq, w_scale, w_zp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ template <
int TBS_M,
int TBS_N,
int TBS_K,
bool PONG,
typename WEIGHT_SCALE_DTYPE>
bool PONG>
at::Tensor bf16i4bf16_rowwise_impl(
at::Tensor X, // BF16
at::Tensor WQ, // INT4
Expand All @@ -42,42 +41,54 @@ at::Tensor bf16i4bf16_rowwise_impl(
int M = X.size(0);
int N = WQ.size(0);
int K = X.size(1);

int num_groups = w_scale.size(0);
int scale_k = w_scale.size(1);

TORCH_CHECK(X.is_cuda() && X.is_contiguous());
TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());
TORCH_CHECK(w_scale.is_cuda() && w_scale.is_contiguous());
TORCH_CHECK(w_zp.is_cuda() && w_zp.is_contiguous());
TORCH_CHECK(K >= num_groups && K % num_groups == 0);
TORCH_CHECK(K >= scale_k && K % scale_k == 0);

int group_size = K / num_groups;
int group_size = K / scale_k;

auto Y = at::empty({M, N}, X.options().dtype(at::kBFloat16));

using ElementInputA = cutlass::bfloat16_t;
using LayoutInputA = cutlass::layout::ColumnMajor;
constexpr int AlignmentInputA =
using MmaType = cutlass::bfloat16_t;
using QuantType = cutlass::int4b_t;
// TODO: Is this really needed? TileShapeK replaces the TB_K template
// parameter as the K extent of TileShape.
constexpr int TileShapeK = 128 * 8 / cutlass::sizeof_bits<MmaType>::value;

using ElementA = MmaType;
using LayoutA = cutlass::layout::RowMajor;
constexpr int AlignmentA =
128 /
cutlass::sizeof_bits<
ElementInputA>::value; // Memory access granularity/alignment of A
// matrix in units of elements (up to 16 bytes)
ElementA>::value; // Memory access granularity/alignment of A
// matrix in units of elements (up to 16 bytes)

using ElementInputB = cutlass::int4b_t;
using LayoutInputB = cutlass::layout::RowMajor;
constexpr int AlignmentInputB =
using ElementB = QuantType;
using LayoutB = cutlass::layout::ColumnMajor;
constexpr int AlignmentB =
128 /
cutlass::sizeof_bits<
ElementInputB>::value; // Memory access granularity/alignment of B
// matrix in units of elements (up to 16 bytes)
ElementB>::value; // Memory access granularity/alignment of B
// matrix in units of elements (up to 16 bytes)

using ElementScale = WEIGHT_SCALE_DTYPE;
using ElementZeroPoint = WEIGHT_SCALE_DTYPE;
using ElementComputeEpilogue = float;
// We transpose and swap inputs.
using LayoutA_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutB>::type;

using LayoutScale = cutlass::layout::RowMajor;

using ElementScale = MmaType;
using ElementZero = MmaType;
using ElementCompute = float;
using ElementAccumulator = float;

using ElementOutput = cutlass::bfloat16_t;
using LayoutOutput = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
constexpr int AlignmentOutput =
128 /
cutlass::sizeof_bits<
Expand All @@ -90,21 +101,25 @@ at::Tensor bf16i4bf16_rowwise_impl(
using TileShape = cute::Shape<
cute::Int<TB_M>,
cute::Int<TB_N>,
cute::Int<TB_K>>; // Threadblock-level
// tile size
cute::Int<TileShapeK>>; // Threadblock-level
// tile size
using ClusterShape = cute::Shape<
cute::Int<TBS_M>,
cute::Int<TBS_N>,
cute::Int<TBS_K>>; // Shape of the
// threadblocks in a
// cluster
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput;
using PongSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using DefaultEpiSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using PongEpiSchedule = cutlass::epilogue::TmaWarpSpecialized;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using MainLoopSchedule =
using KernelSchedule =
cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
// TODO: It is possible that only the cooperative epilogue schedule works.
using EpilogueSchedule = DefaultEpiSchedule;
// using EpilogueSchedule =
// cute::conditional_t<PONG, PongEpiSchedule, DefaultEpiSchedule>;

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
Expand All @@ -115,67 +130,73 @@ at::Tensor bf16i4bf16_rowwise_impl(
EpilogueTileType,
ElementAccumulator,
ElementAccumulator,
// Transpose the layout of D here since we use an explicit swap + transpose.
// The void type for C tells the builder to allocate 0 smem for the C
// matrix. We can enable this if beta == 0 by changing ElementC to
// void below.
ElementOutput,
LayoutOutput,
typename cutlass::layout::LayoutTranspose<LayoutOutput>::type,
AlignmentOutput,
ElementOutput,
LayoutOutput,
typename cutlass::layout::LayoutTranspose<LayoutOutput>::type,
AlignmentOutput,
EpilogueSchedule>::CollectiveOp;
EpilogueSchedule // This is the only epi supporting the required swap
// + transpose.
>::CollectiveOp;

using CollectiveMainloop =
using CollectiveMainloopScaleWithZeroPoint =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
cute::tuple<ElementInputB, ElementScale, ElementZeroPoint>,
LayoutInputB,
AlignmentInputB,
ElementInputA,
LayoutInputA,
AlignmentInputA,
cute::tuple<ElementB, ElementScale, ElementZero>,
LayoutB_Transpose,
AlignmentB,
ElementA,
LayoutA_Transpose,
AlignmentA,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainLoopSchedule>::CollectiveOp;
KernelSchedule>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int>,
CollectiveMainloop,
cute::Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloopScaleWithZeroPoint,
CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using StrideInputA = typename Gemm::GemmKernel::StrideA;
using StrideInputB = typename Gemm::GemmKernel::StrideB;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideOutput = typename Gemm::GemmKernel::StrideC;
using StrideS = typename CollectiveMainloop::StrideScale;
using StrideS = typename CollectiveMainloopScaleWithZeroPoint::StrideScale;

StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA{}, cute::make_shape(M, K, 1));
StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB{}, cute::make_shape(N, K, 1));
StrideA stride_A =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1));
StrideB stride_B =
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, 1));
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput{}, cute::make_shape(N, M, 1));
StrideS stride_S = cutlass::make_cute_packed_stride(
StrideS{}, cute::make_shape(N, num_groups, 1));
StrideS{}, cute::make_shape(N, scale_k, 1));

typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{N, M, K},
{reinterpret_cast<ElementInputB*>(WQ.data_ptr()),
stride_b,
reinterpret_cast<ElementInputA*>(X.data_ptr()),
stride_a,
{N, M, K, 1},
{reinterpret_cast<ElementB*>(WQ.data_ptr()),
stride_B,
reinterpret_cast<ElementA*>(X.data_ptr()),
stride_A,
reinterpret_cast<ElementScale*>(w_scale.data_ptr()),
stride_S,
group_size,
reinterpret_cast<ElementZeroPoint*>(w_zp.data_ptr())},
{{1.0, 0.0},
(ElementOutput*)Y.data_ptr<at::BFloat16>(),
reinterpret_cast<ElementZero*>(w_zp.data_ptr())},
{{},
reinterpret_cast<ElementOutput*>(Y.data_ptr()),
stride_output,
(ElementOutput*)Y.data_ptr<at::BFloat16>(),
reinterpret_cast<ElementOutput*>(Y.data_ptr()),
stride_output}};

Gemm gemm;
Expand Down Expand Up @@ -211,43 +232,21 @@ at::Tensor bf16i4bf16_rowwise_impl(
return Y;
}

template <typename WEIGHT_SCALE_DTYPE>
at::Tensor dispatch_bf16i4bf16_rowwise_kernel(
at::Tensor X, // BF16
at::Tensor WQ, // INT4
at::Tensor w_scale,
at::Tensor w_zp) {
KernelMode kernel = get_kernel_mode(X, WQ);
if (kernel == KernelMode::Small) {
return bf16i4bf16_rowwise_impl<
64,
128,
128,
2,
1,
1,
true,
WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
return bf16i4bf16_rowwise_impl<64, 128, 128, 2, 1, 1, true>(
X, WQ, w_scale, w_zp);
} else if (kernel == KernelMode::Large) {
return bf16i4bf16_rowwise_impl<
128,
128,
128,
2,
1,
1,
true,
WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
return bf16i4bf16_rowwise_impl<128, 128, 128, 2, 1, 1, true>(
X, WQ, w_scale, w_zp);
} else {
return bf16i4bf16_rowwise_impl<
128,
128,
128,
2,
1,
1,
false,
WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
return bf16i4bf16_rowwise_impl<128, 128, 128, 2, 1, 1, false>(
X, WQ, w_scale, w_zp);
}
}

Expand All @@ -258,23 +257,10 @@ at::Tensor bf16i4bf16_rowwise(
at::Tensor w_zp) {
// Check datatypes.
TORCH_CHECK(
(w_scale.dtype() == at::kFloat && w_zp.dtype() == at::kFloat) ||
(w_scale.dtype() == at::kHalf && w_zp.dtype() == at::kHalf) ||
(w_scale.dtype() == at::kBFloat16 && w_zp.dtype() == at::kBFloat16),
"Weight scale and zero point tensors must be float32, bfloat16, or float16, and dtype of weight scale and zero point tensors must be the same .");

if (w_scale.dtype() == at::kFloat) {
return dispatch_bf16i4bf16_rowwise_kernel<float>(X, WQ, w_scale, w_zp);
} else if (w_scale.dtype() == at::kHalf) {
return dispatch_bf16i4bf16_rowwise_kernel<cutlass::half_t>(
X, WQ, w_scale, w_zp);
} else if (w_scale.dtype() == at::kBFloat16) {
return dispatch_bf16i4bf16_rowwise_kernel<cutlass::bfloat16_t>(
X, WQ, w_scale, w_zp);
} else {
throw std::runtime_error(
"Weight scale and zero point data type not supported in bf16i4bf16_rowwise");
}
(w_scale.dtype() == at::kBFloat16 && w_zp.dtype() == at::kBFloat16),
"Weight scale and zero point tensors must be bfloat16 and dtype of weight scale and zero point tensors must be the same.");

return dispatch_bf16i4bf16_rowwise_kernel(X, WQ, w_scale, w_zp);
}

#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,8 @@ at::Tensor bf16i4bf16_rowwise_batched_impl(
cute::Int<TBS_K>>; // Shape of the
// threadblocks in a
// cluster
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput;
using PongSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using MainLoopSchedule =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,8 @@ at::Tensor f8i4bf16_rowwise_impl(
cute::Int<TBS_K>>; // Shape of the
// threadblocks in a
// cluster
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput;
using PongSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using MainLoopSchedule =
Expand Down
Loading