feat: implement SM-Constrained GEMM API
As requested in #591, this PR implements a `plan` function for the GEMM API,
taking `num_ctas` as an argument to specify the kernel grid size.
lanchongyizu committed Jan 21, 2025
1 parent a0e99a3 commit 841b423
Showing 12 changed files with 163 additions and 16 deletions.
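For orientation before the per-file diffs, here is a minimal sketch of the new user-facing flow, following the updated `SegmentGEMMWrapper` docstring in this commit. The import path, workspace size, and tensor shapes are illustrative assumptions, not requirements:

```python
import torch
from flashinfer import SegmentGEMMWrapper

# Workspace buffer for the wrapper; 128 MB is an illustrative size.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
segment_gemm = SegmentGEMMWrapper(workspace)

seq_lens = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device="cuda")
x = torch.randn(10, 128, device="cuda", dtype=torch.float16)
# 4 weights, each with 128 input and 256 output channels, column major
weights = torch.randn(4, 256, 128, device="cuda", dtype=torch.float16)

segment_gemm.plan(64)  # new in this PR: constrain the kernel grid to 64 CTAs
y = segment_gemm.run(x, weights, 4, True, seg_lens=seq_lens)
assert y.shape == (10, 256)
```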
5 changes: 4 additions & 1 deletion csrc/flashinfer_gemm_ops.cu
@@ -21,9 +21,12 @@ void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::T
void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at::Tensor x_ptr,
at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_ld, at::Tensor w_ld,
at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major,
int64_t cuda_stream);
std::vector<int64_t> plan_info_vec, int64_t cuda_stream);

std::vector<int64_t> CutlassSegmentGEMMPlan(unsigned int num_ctas);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM");
m.def("cutlass_segment_gemm_plan", &CutlassSegmentGEMMPlan, "Cutlass Segment GEMM Plan");
m.def("bmm_fp8", &bmm_fp8, "BMM FP8");
}
27 changes: 27 additions & 0 deletions csrc/flashinfer_gemm_sm90_ops.cu
@@ -0,0 +1,27 @@
/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pytorch_extension_utils.h"

void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr,
at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride,
at::Tensor y_stride, at::Tensor empty_x_data, bool weight_column_major,
std::vector<int64_t> plan_info_vec, int64_t cuda_stream);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90,
"Cutlass Segment GEMM operator for SM90");
}
4 changes: 3 additions & 1 deletion csrc/flashinfer_ops.cu
@@ -63,7 +63,8 @@ void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::T
void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at::Tensor x_ptr,
at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_ld, at::Tensor w_ld,
at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major,
int64_t cuda_stream);
std::vector<int64_t> plan_info_vec, int64_t cuda_stream);
std::vector<int64_t> CutlassSegmentGEMMPlan(unsigned int num_ctas);

//========== norm ==========

@@ -223,6 +224,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

// gemm
m.def("bmm_fp8", &bmm_fp8, "BMM FP8");
m.def("cutlass_segment_gemm_plan", &CutlassSegmentGEMMPlan, "Cutlass Segment GEMM plan");
m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM operator");

// norm
2 changes: 1 addition & 1 deletion csrc/flashinfer_ops_sm90.cu
@@ -19,7 +19,7 @@ void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_wo
at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr,
at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride,
at::Tensor y_stride, at::Tensor empty_x_data, bool weight_column_major,
int64_t cuda_stream);
std::vector<int64_t> plan_info_vec, int64_t cuda_stream);

void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q, at::Tensor k,
at::Tensor v,
15 changes: 13 additions & 2 deletions csrc/group_gemm.cu
@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include <flashinfer/gemm/group_gemm.cuh>
#include <flashinfer/gemm/scheduler.cuh>

#include "pytorch_extension_utils.h"

@@ -23,7 +24,9 @@ using namespace flashinfer::group_gemm;
void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at::Tensor x_ptr,
at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_ld, at::Tensor w_ld,
at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major,
int64_t cuda_stream) {
std::vector<int64_t> plan_info_vec, int64_t cuda_stream) {
GemmPlanInfo plan_info;
plan_info.FromVector(plan_info_vec);
unsigned int batch_size = x_ptr.size(0);

cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
@@ -32,9 +35,17 @@ void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at
auto status = CutlassSegmentGEMMRun<cutlass_t>(
workspace_buffer.data_ptr(), workspace_buffer.element_size() * workspace_buffer.size(0),
all_problems.data_ptr(), batch_size, x_ptr.data_ptr(), w_ptr.data_ptr(), y_ptr.data_ptr(),
x_ld.data_ptr(), w_ld.data_ptr(), y_ld.data_ptr(), weight_column_major, stream);
x_ld.data_ptr(), w_ld.data_ptr(), y_ld.data_ptr(), weight_column_major, plan_info.num_ctas,
stream);
TORCH_CHECK(status == cudaSuccess,
"Failed to run CutlassSegmentGEMM: ", cudaGetErrorString(status));
return true;
});
}

std::vector<int64_t> CutlassSegmentGEMMPlan(unsigned int num_ctas) {
GemmPlanInfo plan_info;
cudaError_t status = GemmPlan(num_ctas, plan_info);
TORCH_CHECK(status == cudaSuccess, "GemmPlan failed with error: ", cudaGetErrorString(status));
return plan_info.ToVector();
}
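The `CutlassSegmentGEMMPlan`/`CutlassSegmentGEMM` pair above establishes a simple contract: plan results cross the binding boundary as a flat `std::vector<int64_t>`, which the run path length-checks and decodes via `GemmPlanInfo::FromVector` before launching. A toy Python model of that round trip, with stand-in functions rather than the real pybind11 bindings:

```python
from typing import List

def plan_model(num_ctas: int) -> List[int]:
    # Models CutlassSegmentGEMMPlan: GemmPlan(...), then GemmPlanInfo::ToVector().
    return [num_ctas]

def run_model(plan_info_vec: List[int]) -> None:
    # Models CutlassSegmentGEMM: GemmPlanInfo::FromVector(), then kernel launch.
    if len(plan_info_vec) != 1:
        raise ValueError(f"vec.size() should be 1, but got {len(plan_info_vec)}")
    (num_ctas,) = plan_info_vec
    print(f"CutlassSegmentGEMMRun with threadblock_count={num_ctas}")

run_model(plan_model(64))  # -> CutlassSegmentGEMMRun with threadblock_count=64
```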
8 changes: 6 additions & 2 deletions csrc/group_gemm_sm90.cu
@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include <flashinfer/gemm/group_gemm_sm90.cuh>
#include <flashinfer/gemm/scheduler.cuh>

#include "pytorch_extension_utils.h"

@@ -24,7 +25,9 @@ void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_wo
at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr,
at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride,
at::Tensor y_stride, at::Tensor empty_x_data, bool weight_column_major,
int64_t cuda_stream) {
std::vector<int64_t> plan_info_vec, int64_t cuda_stream) {
GemmPlanInfo plan_info;
plan_info.FromVector(plan_info_vec);
unsigned int batch_size = x_ptr.size(0);
auto device = float_workspace_buffer.device();

@@ -37,7 +40,8 @@ void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_wo
int_workspace_buffer.data_ptr(),
int_workspace_buffer.element_size() * int_workspace_buffer.size(0), all_problems.data_ptr(),
batch_size, x_ptr.data_ptr(), w_ptr.data_ptr(), y_ptr.data_ptr(), x_stride.data_ptr(),
weight_stride.data_ptr(), y_stride.data_ptr(), weight_column_major, stream);
weight_stride.data_ptr(), y_stride.data_ptr(), weight_column_major, plan_info.num_ctas,
stream);
TORCH_CHECK(status == cudaSuccess,
"Failed to run CutlassSegmentGEMM: ", cudaGetErrorString(status));
return true;
35 changes: 34 additions & 1 deletion flashinfer/gemm.py
@@ -15,7 +15,7 @@
"""

from types import SimpleNamespace
from typing import Optional
from typing import List, Optional

import torch
import triton
@@ -104,6 +104,7 @@ def cutlass_segment_gemm(
y: torch.Tensor,
empty_x_data: torch.Tensor,
weight_column_major: bool,
plan_info_vec: List[int],
) -> None:
with x_data.device as device:
module.cutlass_segment_gemm(
@@ -117,6 +118,7 @@ def cutlass_segment_gemm(
y_ld,
empty_x_data,
weight_column_major,
plan_info_vec,
get_cuda_stream(device),
)

@@ -139,6 +141,7 @@ def _fake_cutlass_segment_gemm(
# Register the module
_gemm_module = SimpleNamespace(
bmm_fp8=bmm_fp8,
plan=module.cutlass_segment_gemm_plan,
cutlass_segment_gemm=cutlass_segment_gemm,
)

@@ -181,6 +184,7 @@ def cutlass_segment_gemm_sm90(
y: torch.Tensor,
empty_x_data: torch.Tensor,
weight_column_major: bool,
plan_info_vec: List[int],
) -> None:
with x_data.device as device:
module.cutlass_segment_gemm_sm90(
@@ -195,6 +199,7 @@ def cutlass_segment_gemm_sm90(
y_stride,
empty_x_data,
weight_column_major,
plan_info_vec,
get_cuda_stream(device),
)

@@ -212,6 +217,7 @@ def _fake_cutlass_segment_gemm_sm90(
y: torch.Tensor,
empty_x_data: torch.Tensor,
weight_column_major: bool,
plan_info_vec: List[int],
) -> None:
pass

@@ -444,6 +450,8 @@ class SegmentGEMMWrapper:
>>> x = torch.randn(10, 128, device="cuda", dtype=torch.float16)
>>> # create weight tensor with 4 weights, each with 128 input and 256 output channels, column major
>>> weights = torch.randn(4, 256, 128, device="cuda", dtype=torch.float16)
>>> # set the number of CTAs to 64
>>> segment_gemm.plan(64)
>>> # compute the segment GEMM
>>> y = segment_gemm.run(x, weights, 4, True, seg_lens=seq_lens)
>>> y.shape
@@ -512,6 +520,29 @@ def reset_workspace_buffer(
self._float_workspace_buffer = float_workspace_buffer
self._int_workspace_buffer = int_workspace_buffer

    def plan(self, num_ctas: int = 0) -> None:
        r"""Plan the GEMM for a given ``num_ctas``.

        Parameters
        ----------
        num_ctas: int
            The number of CTAs used to run the GEMM kernel. If zero, or greater
            than the number of SMs on the device, it is clamped to the device's
            SM count.

        Note
        ----
        The :meth:`plan` method should be called before any :meth:`run`.
        The :meth:`plan` method cannot be used in CUDA Graphs or in
        ``torch.compile``.
        """
        if num_ctas < 0:
            raise ValueError("num_ctas should be greater than or equal to 0.")

        self._plan_info = get_gemm_module().plan(
            num_ctas,
        )

def run(
self,
x: torch.Tensor,
@@ -629,6 +660,7 @@ def run(
y, # for torch compile mutates_args
empty_x_data, # for kernel type dispatch
weight_column_major,
self._plan_info,
)
case "sm80":
(
@@ -660,6 +692,7 @@
y,
empty_x_data,
weight_column_major,
self._plan_info,
)
case _:
raise ValueError(f"Unsupported gemm backend: {backend}")
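A short wrapper-level sketch of the `plan` semantics added above; it assumes the usual byte-tensor workspace from the class docstring, and re-planning between runs is fine since `plan` only stores the encoded schedule:

```python
import torch
from flashinfer import SegmentGEMMWrapper

wrapper = SegmentGEMMWrapper(
    torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
)
wrapper.plan()    # num_ctas=0 (default): planner falls back to the device SM count
wrapper.plan(16)  # cap subsequent runs at 16 CTAs
try:
    wrapper.plan(-1)
except ValueError as e:
    print(e)      # negative requests are rejected before reaching the binding
```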
1 change: 1 addition & 0 deletions flashinfer/jit/core.py
@@ -93,6 +93,7 @@ def load_cuda_ops(
"--threads",
"4",
"-use_fast_math",
"-DFLASHINFER_ENABLE_F16",
"-DFLASHINFER_ENABLE_BF16",
"-DFLASHINFER_ENABLE_FP8_E4M3",
"-DFLASHINFER_ENABLE_FP8_E5M2",
10 changes: 5 additions & 5 deletions include/flashinfer/gemm/group_gemm.cuh
@@ -38,7 +38,7 @@ template <typename DType>
cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffer_size_in_bytes,
void* all_problems, unsigned int batch_size, void* x, void* w,
void* y, void* x_ld, void* w_ld, void* y_ld,
bool weight_column_major, cudaStream_t stream) {
bool weight_column_major, int num_ctas, cudaStream_t stream) {
using cutlass::epilogue::thread::LinearCombination;
using cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle;
DISPATCH_WEIGHT_LAYOUT(weight_column_major, WEIGHT_LAYOUT, {
@@ -69,10 +69,10 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffe
using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
typename GemmGrouped::Arguments args(
reinterpret_cast<cutlass::gemm::GemmCoord*>(all_problems), (int)batch_size,
/*threadblock_count=*/4, epilogue_op, static_cast<DType**>(x), static_cast<DType**>(w),
static_cast<DType**>(y), static_cast<DType**>(y), reinterpret_cast<int64_t*>(x_ld),
reinterpret_cast<int64_t*>(w_ld), reinterpret_cast<int64_t*>(y_ld),
reinterpret_cast<int64_t*>(y_ld));
/*threadblock_count=*/num_ctas, epilogue_op, static_cast<DType**>(x),
static_cast<DType**>(w), static_cast<DType**>(y), static_cast<DType**>(y),
reinterpret_cast<int64_t*>(x_ld), reinterpret_cast<int64_t*>(w_ld),
reinterpret_cast<int64_t*>(y_ld), reinterpret_cast<int64_t*>(y_ld));

GemmGrouped gemm;
auto status = gemm.initialize(args, nullptr, stream);
5 changes: 2 additions & 3 deletions include/flashinfer/gemm/group_gemm_sm90.cuh
@@ -53,7 +53,7 @@ cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_si
void* int_buffer, size_t int_buffer_size_in_bytes,
void* all_problems, unsigned int batch_size, void* x, void* w,
void* y, void* x_stride, void* w_stride, void* y_stride,
bool weight_column_major, cudaStream_t stream) {
bool weight_column_major, int num_ctas, cudaStream_t stream) {
auto compute_capacity = GetCudaComputeCapability();
if (compute_capacity.first < 9) {
std::cerr << "CutlassSegmentGEMMSM90Run requires compute capability of at least 9.0"
@@ -121,8 +121,7 @@ cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_si

cutlass::KernelHardwareInfo hw_info;
cudaGetDevice(&hw_info.device_id);
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
hw_info.sm_count = num_ctas;

typename Gemm::EpilogueOutputOp::Params params;
params =
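Note that the SM90 path previously sized its persistent grid with `query_device_multiprocessor_count`; after this change, `hw_info.sm_count` comes from the planner. To inspect the value the planner falls back to when `num_ctas == 0`, the same device attribute is visible from PyTorch:

```python
import torch

# Equivalent to cudaDeviceGetAttribute(cudaDevAttrMultiProcessorCount, dev_id),
# which GemmPlan uses as the fallback and upper bound for num_ctas.
num_sms = torch.cuda.get_device_properties(0).multi_processor_count
print(f"GemmPlan would clamp num_ctas to at most {num_sms}")
```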
64 changes: 64 additions & 0 deletions include/flashinfer/gemm/scheduler.cuh
@@ -0,0 +1,64 @@
/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_GEMM_SCHEDULER_CUH_
#define FLASHINFER_GEMM_SCHEDULER_CUH_

#include <cuda_runtime_api.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <vector>

#include "../utils.cuh"

namespace flashinfer {

struct GemmPlanInfo {
int64_t num_ctas;

GemmPlanInfo() : num_ctas(0) {}

// convert GemmPlanInfo to std::vector<int64_t>
std::vector<int64_t> ToVector() const { return {num_ctas}; }

// From std::vector<int64_t> to GemmPlanInfo
void FromVector(const std::vector<int64_t>& vec) {
if (vec.size() != 1) {
std::ostringstream err_msg;
err_msg << "GemmPlanInfo::FromVector: vec.size() should be 1, but got " << vec.size();
FLASHINFER_ERROR(err_msg.str());
}
num_ctas = vec[0];
}
};

inline cudaError_t GemmPlan(uint32_t num_ctas, GemmPlanInfo& plan_info) {
int dev_id = 0;
int num_sms = 0;
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
if (num_ctas > 0 && num_ctas < num_sms) {
plan_info.num_ctas = num_ctas;
} else {
plan_info.num_ctas = num_sms;
}
return cudaSuccess;
}

} // namespace flashinfer
#endif // FLASHINFER_GEMM_SCHEDULER_CUH_
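The clamp in `GemmPlan` is the entire scheduling policy: a request of zero, or anything at or above the device SM count, falls back to one CTA per SM. A quick model of the edge cases, using 132 SMs as an assumed example figure:

```python
def planned_ctas(num_ctas: int, num_sms: int) -> int:
    # Same condition as GemmPlan: num_ctas > 0 && num_ctas < num_sms.
    return num_ctas if 0 < num_ctas < num_sms else num_sms

assert planned_ctas(0, 132) == 132    # default: use every SM
assert planned_ctas(64, 132) == 64    # constrained grid
assert planned_ctas(132, 132) == 132  # at the SM count: whole device
assert planned_ctas(500, 132) == 132  # oversubscription is clamped
```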
3 changes: 3 additions & 0 deletions tests/test_group_gemm.py
@@ -33,6 +33,7 @@
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("backend", ["auto", "sm90", "sm80"])
@pytest.mark.parametrize("num_ctas", [0, 4, 16, 64])
def test_segment_gemm(
batch_size,
num_rows_per_batch,
@@ -43,6 +44,7 @@ def test_segment_gemm(
dtype,
device,
backend,
num_ctas,
):
if batch_size * num_rows_per_batch > 8192:
pytest.skip("batch_size * num_rows_per_batch too large for test.")
@@ -64,6 +66,7 @@ def test_segment_gemm(
weight = torch.randn(batch_size, d_out, d_in, dtype=dtype).to(device)
else:
weight = torch.randn(batch_size, d_in, d_out, dtype=dtype).to(device)
segment_gemm.plan(num_ctas)
y = segment_gemm.run(
x,
weight,
