
Commit 51a0aae

jwfromm authored and facebook-github-bot committed
F8I4 Grouped Gemm Optimization for Sparse M (pytorch#3854)
Summary:

X-link: facebookresearch/FBGEMM#945

In cases where there are many groups but only a few have a non-zero number of routed tokens, it turns out we pay a high overhead. For example, if a single token is routed to one of 128 experts, the runtime is much higher even though the compute is the same as routing one token to a single expert. Presumably there are kernel inefficiencies involved in looping over the empty groups.

This diff changes how kernel arguments are set up so that the grouped gemm runs over min(total_M, groups) problems. This allows us to ignore the many groups where no compute is required and improves performance in those cases considerably. As an example of the effect of this diff, when total_M is 1 and there are 128 groups, latency is 3X lower thanks to this change.

Reviewed By: jiawenliu64

Differential Revision: D71510967
1 parent 8f1ee28 commit 51a0aae
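
To make the headline change concrete: every routed token lands in exactly one group, so at most total_M groups can have a non-zero M, and min(total_M, groups) is therefore a safe upper bound on the number of problems the grouped gemm has to visit. A minimal Python sketch of that bound (editor's illustration with a hypothetical count_kernel_groups helper; not code from this commit):

import torch

def count_kernel_groups(M_sizes: torch.Tensor) -> int:
    # Total number of routed tokens across all groups.
    total_M = int(M_sizes.sum().item())
    G = M_sizes.numel()
    # Each token belongs to exactly one group, so at most total_M groups
    # can be non-empty; min(total_M, G) problems always suffice.
    return min(total_M, G)

# Sparse-M example: a single token routed to one of 128 experts.
M_sizes = torch.zeros(128, dtype=torch.int64)
M_sizes[42] = 1
print(count_kernel_groups(M_sizes))  # 1 problem instead of 128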

File tree

2 files changed, +108 -46 lines


fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py (+58 -17)

@@ -32,6 +32,29 @@ def __init__(self, *args, **kwargs):
 from .quantize_ops import get_quantize_ops, QuantizeOpBase
 
 
+def generate_group_tensor(G, M):
+    """
+    Generate a tensor with G elements whose integer elements sum to M.
+
+    Args:
+        G (int): Number of elements in the tensor.
+        M (int): Sum of the elements in the tensor.
+
+    Returns:
+        List[int]: A list with G integer elements that sum to M.
+    """
+
+    # First, we generate a random tensor with G elements
+    random_tensor = torch.rand(G)
+    # Then, we normalize this tensor so it sums up to 1
+    normalized_tensor = random_tensor / random_tensor.sum()
+    # Finally, we multiply this tensor by M and round to the nearest integer
+    output_tensor = torch.round(normalized_tensor * M).to(torch.int64)
+    # Adjust the last element to ensure the sum is exactly M
+    output_tensor[-1] += M - output_tensor.sum()
+    return output_tensor.tolist()
+
+
 def set_amd_env_vars() -> None:
     print("Setting environment variables for AMD GPU performance")
     os.environ["DISABLE_ADDMM_HIP_LT"] = "0"
@@ -177,9 +200,10 @@ def benchmark_grouped(
         output = [o[: m[i]] for i, o in enumerate(output)]
     # Compare the quantize op output to reference as a sanity check.
     for i in range(num_groups):
-        metrics.sim += float(
-            torch.mean(torch.pow(output[i] - out_ref[i], 2)).item()
-        )
+        if m[i] > 0:
+            metrics.sim += float(
+                torch.mean(torch.pow(output[i] - out_ref[i], 2)).item()
+            )
     for _ in range(num_iters):
         # Now perform benchmark.
         if bench_quantize:
@@ -205,17 +229,16 @@ def benchmark_grouped(
             metrics.tflops += (
                 2 * b[i] * m[i] * n[i] * k[i] / (ms_runtime / 1e3) / 1e12
             )
-            metrics.gbps += (
-                (
-                    quantized_vals[0][i][: m[i]].numel()
-                    * quantized_vals[0][i][: m[i]].element_size()
-                    + quantized_vals[1][i].numel()
-                    * quantized_vals[1][i].element_size()
-                    + output[i].numel() * output[i].element_size()
+            if m[i] > 0:
+                metrics.gbps += (
+                    (
+                        b[i] * m[i] * k[i] * quantized_vals[0][0].element_size()
+                        + b[i] * n[i] * k[i] * quantized_vals[1][0].element_size()
+                        + b[i] * m[i] * n[i] * output[0].element_size()
+                    )
+                    / (ms_runtime / 1e3)
+                    / 1e9
                 )
-                / (ms_runtime / 1e3)
-                / 1e9
-            )
             metrics.ms += ms_runtime
     metrics.ms /= num_iters
     metrics.tflops /= num_iters
@@ -411,10 +434,22 @@ def main(args: Any):
     # When groups is provided transform shapes into grouped format.
     if args.groups:
         groups = int(args.groups)
-        MNK = [
-            [[b] * groups, [m] * groups, [n] * groups, [k] * groups]
-            for b, m, n, k in MNK
-        ]
+        if args.total_M:
+            M = generate_group_tensor(groups, int(args.total_M))
+            MNK = [
+                [
+                    [b] * groups,
+                    generate_group_tensor(groups, int(args.total_M)),
+                    [n] * groups,
+                    [k] * groups,
+                ]
+                for b, _, n, k in MNK
+            ]
+        else:
+            MNK = [
+                [[b] * groups, [m] * groups, [n] * groups, [k] * groups]
+                for b, m, n, k in MNK
+            ]
 
     # Iterate over shapes and benchmark.
     benchmark_results = []
@@ -512,6 +547,12 @@ def invoke_main() -> None:
         default=None,
         help="If set with grouped mode, repeat input shapes this many times.",
     )
+    parser.add_argument(
+        "--total_M",
+        default=None,
+        help="If set, adjusts the M values to sum to this number. "
+        "This can help simulate real grouped workloads.",
+    )
     parser.add_argument(
         "--no_cuda_graph",
         default=False,
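
As a quick illustration of the new benchmark helper above (editor's note; the split is random by construction, so the exact values vary), generate_group_tensor distributes total_M tokens across the groups and can naturally leave some groups empty, which is exactly the sparse-M pattern this commit targets:

# Split 8 tokens across 4 groups; the exact split varies from run to run.
sizes = generate_group_tensor(4, 8)
print(sizes)       # e.g. [3, 1, 2, 2]
print(sum(sizes))  # always 8, enforced by the last-element adjustment
# With many groups and a small total_M, e.g. generate_group_tensor(128, 1),
# most entries are 0, reproducing the sparse routing case benchmarked here.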

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_shuffled_grouped.cu (+50 -29)

@@ -73,36 +73,55 @@ __global__ void set_kernel_args(
   auto group_index = blockIdx.x * blockDim.x + threadIdx.x;
   // If this is a valid group, write kernel args to device.
   if (group_index < G) {
-    // First get the M value for this group.
+    // Since we are only writing a subset of the groups to kernel args,
+    // we need to start by initializing a counter and setting other groups
+    // to empty problems.
+    __shared__ int non_zero_counter;
+    // Initialize counter and problem memory for this group.
+    if (group_index == 0) {
+      non_zero_counter = 0;
+    }
+    // We set the problem shapes to empty by default to skip over
+    // these groups.
+    problem_shape_ptr[group_index] = ProblemShape(0, 0, 0);
+    // Sync threads to make sure state is shared across the block.
+    __syncthreads();
+
+    // Now check if this is a non-zero group.
     int M = M_sizes[group_index];
-    // Compute offset into tensor where this group begins.
-    int offset_M = 0;
-    // Compute cumulative sum of prior groups to find offset.
-    for (int i = 0; i < group_index; i++) {
-      offset_M += M_sizes[i];
+    // Only proceed if so.
+    if (M > 0) {
+      // Get the non-zero index for this group atomically.
+      int non_zero_idx = atomicAdd(&non_zero_counter, 1);
+      // Compute offset into tensor where this group begins.
+      int offset_M = 0;
+      // Compute cumulative sum of prior groups to find offset.
+      for (int i = 0; i < group_index; i++) {
+        offset_M += M_sizes[i];
+      }
+      // Set the problem shape for this group.
+      problem_shape_ptr[non_zero_idx] = ProblemShape(N, M, K);
+      // Set pointer to xq.
+      xq_ptr[non_zero_idx] = xq + (offset_M * K);
+      // Set pointer to wq, dividing by two as wq is packed into bytes.
+      wq_ptr[non_zero_idx] = wq + (group_index * N * K / 2);
+      // Set scale pointers.
+      x_scale_ptr[non_zero_idx] = x_scale + offset_M;
+      w_scale_ptr[non_zero_idx] = w_scale + (group_index * N);
+      w_scale_group_ptr[non_zero_idx] =
+          w_scale_group + (group_index * N * num_scale_groups);
+      // Set output pointer.
+      output_ptr[non_zero_idx] = output + (offset_M * N);
+      // Set stride pointers.
+      stride_a_ptr[non_zero_idx] = cutlass::make_cute_packed_stride(
+          StrideA{}, cute::make_shape(M, K, 1));
+      stride_b_ptr[non_zero_idx] = cute::tile_to_shape(
+          LayoutAtomQuant{}, cute::make_shape(N, K, cute::Int<1>{}));
+      stride_c_ptr[non_zero_idx] = cutlass::make_cute_packed_stride(
+          StrideC{}, cute::make_shape(N, M, 1));
+      stride_s_ptr[non_zero_idx] = cutlass::make_cute_packed_stride(
+          StrideS{}, cute::make_shape(N, num_scale_groups, 1));
     }
-    // Set the problem shape for this group.
-    problem_shape_ptr[group_index] = ProblemShape(N, M, K);
-    // Set pointer to xq.
-    xq_ptr[group_index] = xq + (offset_M * K);
-    // Set pointer to wq, dividing by two as wq is packed into bytes.
-    wq_ptr[group_index] = wq + (group_index * N * K / 2);
-    // Set scale pointers.
-    x_scale_ptr[group_index] = x_scale + offset_M;
-    w_scale_ptr[group_index] = w_scale + (group_index * N);
-    w_scale_group_ptr[group_index] =
-        w_scale_group + (group_index * N * num_scale_groups);
-    // Set output pointer.
-    output_ptr[group_index] = output + (offset_M * N);
-    // Set stride pointers.
-    stride_a_ptr[group_index] =
-        cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1));
-    stride_b_ptr[group_index] = cute::tile_to_shape(
-        LayoutAtomQuant{}, cute::make_shape(N, K, cute::Int<1>{}));
-    stride_c_ptr[group_index] =
-        cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(N, M, 1));
-    stride_s_ptr[group_index] = cutlass::make_cute_packed_stride(
-        StrideS{}, cute::make_shape(N, num_scale_groups, 1));
   }
 }
 
@@ -118,6 +137,8 @@ void _f8i4bf16_shuffled_grouped(
   // Get basic shape information.
   int G = M_sizes.size(0);
   // XQ is shape [total_M, K]
+  int total_M = XQ.size(0);
+  int kernel_groups = std::min(G, total_M);
   int K = XQ.size(-1);
   // WQ is shape [G, N, K/2]
   int N = WQ.size(1);
@@ -394,7 +415,7 @@ void _f8i4bf16_shuffled_grouped(
   // Define GEMM arguments.
   typename GemmShuffled::Arguments arguments{
       cutlass::gemm::GemmUniversalMode::kGrouped,
-      {G, problem_shape_ptr, nullptr},
+      {kernel_groups, problem_shape_ptr, nullptr},
       {wq_ptr,
        stride_b_ptr,
        xq_ptr,
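
For the host-side intuition of the kernel change above, here is a small Python sketch (editor's illustration with hypothetical names; the real logic lives in the CUDA set_kernel_args kernel): non-empty groups are compacted into the leading argument slots, empty groups keep a (0, 0, 0) problem shape, and the GEMM is launched over only kernel_groups = min(G, total_M) problems.

def compact_problem_shapes(M_sizes, N, K):
    # Mimics set_kernel_args on the host: non-empty groups are written to
    # compact slots at the front; the rest stay as empty (0, 0, 0) problems.
    G = len(M_sizes)
    problem_shapes = [(0, 0, 0)] * G
    xq_row_offsets = []
    non_zero_idx = 0  # stands in for the kernel's atomic counter
    offset_M = 0      # cumulative sum of the previous groups' M values
    for g in range(G):
        M = M_sizes[g]
        if M > 0:
            problem_shapes[non_zero_idx] = (N, M, K)
            xq_row_offsets.append(offset_M)  # where this group's rows start in XQ
            non_zero_idx += 1
        offset_M += M
    # The GEMM only needs to be launched over min(G, total_M) problems.
    kernel_groups = min(G, sum(M_sizes))
    return problem_shapes[:kernel_groups], xq_row_offsets

# Sparse-M example: one token routed to the second of four experts.
shapes, offsets = compact_problem_shapes([0, 1, 0, 0], N=256, K=512)
print(shapes)   # [(256, 1, 512)] -- one problem instead of four
print(offsets)  # [0]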
