Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement SM-Constrained GEMM API #744

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion csrc/flashinfer_gemm_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::T
void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at::Tensor x_ptr,
at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_ld, at::Tensor w_ld,
at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major,
int64_t cuda_stream);
std::vector<int64_t> plan_info_vec, int64_t cuda_stream);

std::vector<int64_t> CutlassSegmentGEMMPlan(unsigned int num_ctas);

// Python bindings for this extension module: the segment GEMM run and plan entry points
// plus the FP8 batched matmul.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM");
// Planning entry point; per its declaration it returns the plan as a std::vector<int64_t>.
m.def("cutlass_segment_gemm_plan", &CutlassSegmentGEMMPlan, "Cutlass Segment GEMM Plan");
m.def("bmm_fp8", &bmm_fp8, "BMM FP8");
}
27 changes: 27 additions & 0 deletions csrc/flashinfer_gemm_sm90_ops.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pytorch_extension_utils.h"

// Forward declaration of the SM90 segment GEMM entry point that is bound to Python below.
// The definition is not in this translation unit.
// NOTE(review): presumably provided by csrc/group_gemm_sm90.cu — confirm against the build setup.
// `plan_info_vec` carries the plan as a vector of int64; `cuda_stream` is passed as an integer.
void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr,
at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride,
at::Tensor y_stride, at::Tensor empty_x_data, bool weight_column_major,
std::vector<int64_t> plan_info_vec, int64_t cuda_stream);

// Registers `cutlass_segment_gemm_sm90` as the only Python-visible function of this
// extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90,
"Cutlass Segment GEMM operator for SM90");
}
4 changes: 3 additions & 1 deletion csrc/flashinfer_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::T
void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at::Tensor x_ptr,
at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_ld, at::Tensor w_ld,
at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major,
int64_t cuda_stream);
std::vector<int64_t> plan_info_vec, int64_t cuda_stream);
std::vector<int64_t> CutlassSegmentGEMMPlan(unsigned int num_ctas);

//========== norm ==========

Expand Down Expand Up @@ -219,6 +220,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

// gemm
m.def("bmm_fp8", &bmm_fp8, "BMM FP8");
m.def("cutlass_segment_gemm_plan", &CutlassSegmentGEMMPlan, "Cutlass Segment GEMM plan");
m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM operator");

// norm
Expand Down
2 changes: 1 addition & 1 deletion csrc/flashinfer_ops_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_wo
at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr,
at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride,
at::Tensor y_stride, at::Tensor empty_x_data, bool weight_column_major,
int64_t cuda_stream);
std::vector<int64_t> plan_info_vec, int64_t cuda_stream);

void single_prefill_with_kv_cache_sm90(
at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor o, std::optional<at::Tensor> maybe_lse,
Expand Down
15 changes: 13 additions & 2 deletions csrc/group_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/
#include <flashinfer/gemm/group_gemm.cuh>
#include <flashinfer/gemm/scheduler.cuh>

#include "pytorch_extension_utils.h"

Expand All @@ -23,7 +24,9 @@ using namespace flashinfer::group_gemm;
void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at::Tensor x_ptr,
at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_ld, at::Tensor w_ld,
at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major,
int64_t cuda_stream) {
std::vector<int64_t> plan_info_vec, int64_t cuda_stream) {
GemmPlanInfo plan_info;
plan_info.FromVector(plan_info_vec);
unsigned int batch_size = x_ptr.size(0);

cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
Expand All @@ -32,9 +35,17 @@ void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at
auto status = CutlassSegmentGEMMRun<cutlass_t>(
workspace_buffer.data_ptr(), workspace_buffer.element_size() * workspace_buffer.size(0),
all_problems.data_ptr(), batch_size, x_ptr.data_ptr(), w_ptr.data_ptr(), y_ptr.data_ptr(),
x_ld.data_ptr(), w_ld.data_ptr(), y_ld.data_ptr(), weight_column_major, stream);
x_ld.data_ptr(), w_ld.data_ptr(), y_ld.data_ptr(), weight_column_major, plan_info.num_ctas,
stream);
TORCH_CHECK(status == cudaSuccess,
"Failed to run CutlassSegmentGEMM: ", cudaGetErrorString(status));
return true;
});
}

std::vector<int64_t> CutlassSegmentGEMMPlan(unsigned int num_ctas) {
GemmPlanInfo plan_info;
cudaError_t status = GemmPlan(num_ctas, plan_info);
TORCH_CHECK(status == cudaSuccess, "GemmPlan failed with error: ", cudaGetErrorString(status));
return plan_info.ToVector();
}
8 changes: 6 additions & 2 deletions csrc/group_gemm_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/
#include <flashinfer/gemm/group_gemm_sm90.cuh>
#include <flashinfer/gemm/scheduler.cuh>

#include "pytorch_extension_utils.h"

Expand All @@ -24,7 +25,9 @@ void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_wo
at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr,
at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride,
at::Tensor y_stride, at::Tensor empty_x_data, bool weight_column_major,
int64_t cuda_stream) {
std::vector<int64_t> plan_info_vec, int64_t cuda_stream) {
GemmPlanInfo plan_info;
plan_info.FromVector(plan_info_vec);
unsigned int batch_size = x_ptr.size(0);
auto device = float_workspace_buffer.device();

Expand All @@ -37,7 +40,8 @@ void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_wo
int_workspace_buffer.data_ptr(),
int_workspace_buffer.element_size() * int_workspace_buffer.size(0), all_problems.data_ptr(),
batch_size, x_ptr.data_ptr(), w_ptr.data_ptr(), y_ptr.data_ptr(), x_stride.data_ptr(),
weight_stride.data_ptr(), y_stride.data_ptr(), weight_column_major, stream);
weight_stride.data_ptr(), y_stride.data_ptr(), weight_column_major, plan_info.num_ctas,
stream);
TORCH_CHECK(status == cudaSuccess,
"Failed to run CutlassSegmentGEMM: ", cudaGetErrorString(status));
return true;
Expand Down
35 changes: 34 additions & 1 deletion flashinfer/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""

from types import SimpleNamespace
from typing import Optional
from typing import List, Optional

import torch
import triton
Expand Down Expand Up @@ -104,6 +104,7 @@ def cutlass_segment_gemm(
y: torch.Tensor,
empty_x_data: torch.Tensor,
weight_column_major: bool,
plan_info_vec: List[int],
) -> None:
with x_data.device as device:
module.cutlass_segment_gemm(
Expand All @@ -117,6 +118,7 @@ def cutlass_segment_gemm(
y_ld,
empty_x_data,
weight_column_major,
plan_info_vec,
get_cuda_stream(device),
)

Expand All @@ -139,6 +141,7 @@ def _fake_cutlass_segment_gemm(
# Register the module
_gemm_module = SimpleNamespace(
bmm_fp8=bmm_fp8,
plan=module.cutlass_segment_gemm_plan,
cutlass_segment_gemm=cutlass_segment_gemm,
)

Expand Down Expand Up @@ -181,6 +184,7 @@ def cutlass_segment_gemm_sm90(
y: torch.Tensor,
empty_x_data: torch.Tensor,
weight_column_major: bool,
plan_info_vec: List[int],
) -> None:
with x_data.device as device:
module.cutlass_segment_gemm_sm90(
Expand All @@ -195,6 +199,7 @@ def cutlass_segment_gemm_sm90(
y_stride,
empty_x_data,
weight_column_major,
plan_info_vec,
get_cuda_stream(device),
)

Expand All @@ -212,6 +217,7 @@ def _fake_cutlass_segment_gemm_sm90(
y: torch.Tensor,
empty_x_data: torch.Tensor,
weight_column_major: bool,
plan_info_vec: List[int],
) -> None:
pass

Expand Down Expand Up @@ -444,6 +450,8 @@ class SegmentGEMMWrapper:
>>> x = torch.randn(10, 128, device="cuda", dtype=torch.float16)
>>> # create weight tensor with 4 weights, each with 128 input and 256 output channels, column major
>>> weights = torch.randn(4, 256, 128, device="cuda", dtype=torch.float16)
>>> # set the number of CTAs to 64
>>> segment_gemm.plan(64)
>>> # compute the segment GEMM
>>> y = segment_gemm.run(x, weights, 4, True, seg_lens=seq_lens)
>>> y.shape
Expand Down Expand Up @@ -512,6 +520,29 @@ def reset_workspace_buffer(
self._float_workspace_buffer = float_workspace_buffer
self._int_workspace_buffer = int_workspace_buffer

def plan(self, num_ctas: int = 0) -> None:
    r"""Choose how many CTAs subsequent :meth:`run` calls pass to the GEMM kernel.

    Parameters
    ----------
    num_ctas: int
        Requested number of CTAs; must be non-negative. ``0`` (the default), or
        a value that is not smaller than the number of SMs on the current device,
        selects the number of SMs on the device.

    Raises
    ------
    ValueError
        If ``num_ctas`` is negative.

    Note
    ----
    The :meth:`plan` method should be called before any :meth:`run`.

    The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``.
    """
    if num_ctas < 0:
        raise ValueError("Num_ctas should be greater than or equal to 0.")

    # The plan is kept as returned by the extension and forwarded unchanged by run().
    gemm_module = get_gemm_module()
    self._plan_info = gemm_module.plan(num_ctas)

def run(
self,
x: torch.Tensor,
Expand Down Expand Up @@ -629,6 +660,7 @@ def run(
y, # for torch compile mutates_args
empty_x_data, # for kernel type dispatch
weight_column_major,
self._plan_info,
)
case "sm80":
(
Expand Down Expand Up @@ -660,6 +692,7 @@ def run(
y,
empty_x_data,
weight_column_major,
self._plan_info,
)
case _:
raise ValueError(f"Unsupported gemm backend: {backend}")
Expand Down
10 changes: 5 additions & 5 deletions include/flashinfer/gemm/group_gemm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ template <typename DType>
cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffer_size_in_bytes,
void* all_problems, unsigned int batch_size, void* x, void* w,
void* y, void* x_ld, void* w_ld, void* y_ld,
bool weight_column_major, cudaStream_t stream) {
bool weight_column_major, int num_ctas, cudaStream_t stream) {
using cutlass::epilogue::thread::LinearCombination;
using cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle;
DISPATCH_WEIGHT_LAYOUT(weight_column_major, WEIGHT_LAYOUT, {
Expand Down Expand Up @@ -69,10 +69,10 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe
using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
typename GemmGrouped::Arguments args(
reinterpret_cast<cutlass::gemm::GemmCoord*>(all_problems), (int)batch_size,
/*threadblock_count=*/4, epilogue_op, static_cast<DType**>(x), static_cast<DType**>(w),
static_cast<DType**>(y), static_cast<DType**>(y), reinterpret_cast<int64_t*>(x_ld),
reinterpret_cast<int64_t*>(w_ld), reinterpret_cast<int64_t*>(y_ld),
reinterpret_cast<int64_t*>(y_ld));
/*threadblock_count=*/num_ctas, epilogue_op, static_cast<DType**>(x),
static_cast<DType**>(w), static_cast<DType**>(y), static_cast<DType**>(y),
reinterpret_cast<int64_t*>(x_ld), reinterpret_cast<int64_t*>(w_ld),
reinterpret_cast<int64_t*>(y_ld), reinterpret_cast<int64_t*>(y_ld));

GemmGrouped gemm;
auto status = gemm.initialize(args, nullptr, stream);
Expand Down
5 changes: 2 additions & 3 deletions include/flashinfer/gemm/group_gemm_sm90.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_si
void* int_buffer, size_t int_buffer_size_in_bytes,
void* all_problems, unsigned int batch_size, void* x, void* w,
void* y, void* x_stride, void* w_stride, void* y_stride,
bool weight_column_major, cudaStream_t stream) {
bool weight_column_major, int num_ctas, cudaStream_t stream) {
auto compute_capacity = GetCudaComputeCapability();
if (compute_capacity.first < 9) {
std::cerr << "CutlassSegmentGEMMSM90Run requires compute capability of at least 9.0"
Expand Down Expand Up @@ -121,8 +121,7 @@ cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_si

cutlass::KernelHardwareInfo hw_info;
cudaGetDevice(&hw_info.device_id);
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
hw_info.sm_count = num_ctas;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried the nsys profiler, and it turns out this value can't control the number of SMs this kernel uses.
A more fundamental approach might be to use a green context.


typename Gemm::EpilogueOutputOp::Params params;
params =
Expand Down
64 changes: 64 additions & 0 deletions include/flashinfer/gemm/scheduler.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_GEMM_SCHEDULER_CUH_
#define FLASHINFER_GEMM_SCHEDULER_CUH_

#include <cuda_runtime_api.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <vector>

#include "../utils.cuh"

namespace flashinfer {

// Plan produced by GemmPlan and consumed by the segment GEMM run functions.
// It round-trips through a std::vector<int64_t> holding exactly one element.
struct GemmPlanInfo {
  // Number of CTAs selected for the GEMM kernel; 0 until a plan has been loaded or computed.
  int64_t num_ctas;

  GemmPlanInfo() : num_ctas(0) {}

  // Serializes the plan into a one-element vector: {num_ctas}.
  std::vector<int64_t> ToVector() const {
    std::vector<int64_t> packed;
    packed.push_back(num_ctas);
    return packed;
  }

  // Loads the plan from a vector produced by ToVector().
  // Any size other than 1 is reported through FLASHINFER_ERROR.
  // NOTE(review): FLASHINFER_ERROR is presumably a throwing macro from utils.cuh — confirm.
  void FromVector(const std::vector<int64_t>& vec) {
    const bool has_expected_size = (vec.size() == 1);
    if (!has_expected_size) {
      std::ostringstream err_msg;
      err_msg << "GemmPlanInfo::FromVector: vec.size() should be 1, but got " << vec.size();
      FLASHINFER_ERROR(err_msg.str());
    }
    num_ctas = vec.front();
  }
};

// Fills `plan_info` with the number of CTAs the segment GEMM kernels should use.
//
// num_ctas:  requested CTA count. A value strictly between 0 and the number of
//            multiprocessors of the current device is used as-is; 0 or any value
//            not smaller than that count selects the multiprocessor count.
// plan_info: output; only `num_ctas` is written.
//
// Returns cudaSuccess once the plan has been filled in.
// NOTE(review): FLASHINFER_CUDA_CALL presumably returns the failing cudaError_t early when
// a runtime call fails — confirm against utils.cuh.
inline cudaError_t GemmPlan(uint32_t num_ctas, GemmPlanInfo& plan_info) {
  int dev_id = 0;
  int num_sms = 0;
  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
  // Compare in int64_t (the type of GemmPlanInfo::num_ctas) so that neither the unsigned
  // request nor the signed device attribute goes through an implicit signed/unsigned conversion.
  const int64_t requested_ctas = static_cast<int64_t>(num_ctas);
  const int64_t available_sms = static_cast<int64_t>(num_sms);
  const bool request_in_range = requested_ctas > 0 && requested_ctas < available_sms;
  plan_info.num_ctas = request_in_range ? requested_ctas : available_sms;
  return cudaSuccess;
}

} // namespace flashinfer
#endif // FLASHINFER_GEMM_SCHEDULER_CUH_
3 changes: 3 additions & 0 deletions tests/test_group_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("backend", ["auto", "sm90", "sm80"])
@pytest.mark.parametrize("num_ctas", [0, 4, 16, 64])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the expected behavior of num_ctas=0?

def test_segment_gemm(
batch_size,
num_rows_per_batch,
Expand All @@ -43,6 +44,7 @@ def test_segment_gemm(
dtype,
device,
backend,
num_ctas,
):
if batch_size * num_rows_per_batch > 8192:
pytest.skip("batch_size * num_rows_per_batch too large for test.")
Expand All @@ -64,6 +66,7 @@ def test_segment_gemm(
weight = torch.randn(batch_size, d_out, d_in, dtype=dtype).to(device)
else:
weight = torch.randn(batch_size, d_in, d_out, dtype=dtype).to(device)
segment_gemm.plan(num_ctas)
y = segment_gemm.run(
x,
weight,
Expand Down