F8I4 Grouped Gemm Optimization for Sparse M #3854

Closed · wants to merge 1 commit
75 changes: 58 additions & 17 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
@@ -32,6 +32,29 @@ def __init__(self, *args, **kwargs):
 from .quantize_ops import get_quantize_ops, QuantizeOpBase
 
 
+def generate_group_tensor(G, M):
+    """
+    Generate a list with G integer elements that sum to M.
+
+    Args:
+        G (int): Number of elements in the list.
+        M (int): Sum of the elements in the list.
+
+    Returns:
+        List[int]: A list with G integer elements that sum to M.
+    """
+
+    # First, we generate a random tensor with G elements
+    random_tensor = torch.rand(G)
+    # Then, we normalize this tensor so it sums up to 1
+    normalized_tensor = random_tensor / random_tensor.sum()
+    # Finally, we multiply this tensor by M and round to the nearest integer
+    output_tensor = torch.round(normalized_tensor * M).to(torch.int64)
+    # Adjust the last element to ensure the sum is exactly M
+    output_tensor[-1] += M - output_tensor.sum()
+    return output_tensor.tolist()
+
+
 def set_amd_env_vars() -> None:
     print("Setting environment variables for AMD GPU performance")
     os.environ["DISABLE_ADDMM_HIP_LT"] = "0"
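
As a quick illustration of the helper above (a hypothetical call; the returned values are random, and only the length and the exact sum are guaranteed, since the last element absorbs the rounding error):

# Hypothetical check, assuming generate_group_tensor from quantize_bench.py is in scope.
sizes = generate_group_tensor(G=4, M=128)  # e.g. [37, 30, 25, 36]
assert len(sizes) == 4
assert sum(sizes) == 128  # the last element is adjusted so the total is exact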
@@ -177,9 +200,10 @@ def benchmark_grouped(
     output = [o[: m[i]] for i, o in enumerate(output)]
     # Compare the quantize op output to reference as a sanity check.
     for i in range(num_groups):
-        metrics.sim += float(
-            torch.mean(torch.pow(output[i] - out_ref[i], 2)).item()
-        )
+        if m[i] > 0:
+            metrics.sim += float(
+                torch.mean(torch.pow(output[i] - out_ref[i], 2)).item()
+            )
     for _ in range(num_iters):
         # Now perform benchmark.
         if bench_quantize:
@@ -205,17 +229,16 @@ def benchmark_grouped(
             metrics.tflops += (
                 2 * b[i] * m[i] * n[i] * k[i] / (ms_runtime / 1e3) / 1e12
             )
-            metrics.gbps += (
-                (
-                    quantized_vals[0][i][: m[i]].numel()
-                    * quantized_vals[0][i][: m[i]].element_size()
-                    + quantized_vals[1][i].numel()
-                    * quantized_vals[1][i].element_size()
-                    + output[i].numel() * output[i].element_size()
-                )
-                / (ms_runtime / 1e3)
-                / 1e9
-            )
+            if m[i] > 0:
+                metrics.gbps += (
+                    (
+                        b[i] * m[i] * k[i] * quantized_vals[0][0].element_size()
+                        + b[i] * n[i] * k[i] * quantized_vals[1][0].element_size()
+                        + b[i] * m[i] * n[i] * output[0].element_size()
+                    )
+                    / (ms_runtime / 1e3)
+                    / 1e9
+                )
         metrics.ms += ms_runtime
     metrics.ms /= num_iters
     metrics.tflops /= num_iters
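
For reference, the revised bandwidth accounting counts bytes read for the activations and weights plus bytes written for the output, divided by the runtime. A small worked example with made-up sizes and assumed element widths (fp8 activations, byte-packed int4 weights, bf16 output):

# Worked example of the GB/s formula above; all numbers are hypothetical.
b, m, n, k = 1, 4096, 8192, 5120
xq_bytes, wq_bytes, out_bytes = 1, 1, 2  # assumed fp8 / int8-packed / bf16 widths
ms_runtime = 0.5  # hypothetical kernel time in milliseconds

moved_bytes = (
    b * m * k * xq_bytes     # activations read
    + b * n * k * wq_bytes   # weights read
    + b * m * n * out_bytes  # output written
)
gbps = moved_bytes / (ms_runtime / 1e3) / 1e9
print(f"{gbps:.0f} GB/s")  # ~260 GB/s for these made-up numbers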
@@ -411,10 +434,22 @@ def main(args: Any):
     # When groups is provided transform shapes into grouped format.
     if args.groups:
         groups = int(args.groups)
-        MNK = [
-            [[b] * groups, [m] * groups, [n] * groups, [k] * groups]
-            for b, m, n, k in MNK
-        ]
+        if args.total_M:
+            M = generate_group_tensor(groups, int(args.total_M))
+            MNK = [
+                [
+                    [b] * groups,
+                    generate_group_tensor(groups, int(args.total_M)),
+                    [n] * groups,
+                    [k] * groups,
+                ]
+                for b, _, n, k in MNK
+            ]
+        else:
+            MNK = [
+                [[b] * groups, [m] * groups, [n] * groups, [k] * groups]
+                for b, m, n, k in MNK
+            ]
 
     # Iterate over shapes and benchmark.
     benchmark_results = []
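
To make the --total_M path above concrete, this is the shape of the data it produces for a single hypothetical base shape (the per-group M values are random):

# Illustrative only; assumes generate_group_tensor is in scope and uses a made-up shape.
groups, total_M = 4, 512
base_MNK = [[1, 512, 7168, 2048]]  # hypothetical [B, M, N, K]
grouped_MNK = [
    [
        [b] * groups,
        generate_group_tensor(groups, total_M),  # e.g. [150, 97, 140, 125]
        [n] * groups,
        [k] * groups,
    ]
    for b, _, n, k in base_MNK
]
# grouped_MNK[0] -> [[1, 1, 1, 1], [150, 97, 140, 125], [7168] * 4, [2048] * 4]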
@@ -512,6 +547,12 @@ def invoke_main() -> None:
         default=None,
         help="If set with grouped mode, repeat input shapes this many times.",
     )
+    parser.add_argument(
+        "--total_M",
+        default=None,
+        help="If set, adjusts the M values to sum to this number. "
+        "This can help simulate real grouped workloads.",
+    )
     parser.add_argument(
         "--no_cuda_graph",
         default=False,
@@ -73,36 +73,55 @@ __global__ void set_kernel_args(
   auto group_index = blockIdx.x * blockDim.x + threadIdx.x;
   // If this is a valid group, write kernel args to device.
   if (group_index < G) {
-    // First get the M value for this group.
-    int M = M_sizes[group_index];
-    // Compute offset into tensor where this group begins.
-    int offset_M = 0;
-    // Compute cumulative sum of prior groups to find offset.
-    for (int i = 0; i < group_index; i++) {
-      offset_M += M_sizes[i];
-    }
-    // Set the problem shape for this group.
-    problem_shape_ptr[group_index] = ProblemShape(N, M, K);
-    // Set pointer to xq.
-    xq_ptr[group_index] = xq + (offset_M * K);
-    // Set pointer to wq, dividing by two as wq is packed into bytes.
-    wq_ptr[group_index] = wq + (group_index * N * K / 2);
-    // Set scale pointers.
-    x_scale_ptr[group_index] = x_scale + offset_M;
-    w_scale_ptr[group_index] = w_scale + (group_index * N);
-    w_scale_group_ptr[group_index] =
-        w_scale_group + (group_index * N * num_scale_groups);
-    // Set output pointer.
-    output_ptr[group_index] = output + (offset_M * N);
-    // Set stride pointers.
-    stride_a_ptr[group_index] =
-        cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1));
-    stride_b_ptr[group_index] = cute::tile_to_shape(
-        LayoutAtomQuant{}, cute::make_shape(N, K, cute::Int<1>{}));
-    stride_c_ptr[group_index] =
-        cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(N, M, 1));
-    stride_s_ptr[group_index] = cutlass::make_cute_packed_stride(
-        StrideS{}, cute::make_shape(N, num_scale_groups, 1));
+    // Since we are only writing a subset of the groups to kernel args,
+    // we need to start by initializing a counter and setting other groups
+    // to empty problems.
+    __shared__ int non_zero_counter;
+    // Initialize counter and problem memory for this group.
+    if (group_index == 0) {
+      non_zero_counter = 0;
+    }
+    // We set the problem shapes to empty by default to skip over
+    // these groups.
+    problem_shape_ptr[group_index] = ProblemShape(0, 0, 0);
+    // Sync threads to make sure state is shared across the block.
+    __syncthreads();
+
+    // Now check if this is a non-zero group.
+    int M = M_sizes[group_index];
+    // Only proceed if so.
+    if (M > 0) {
+      // Get the non-zero index for this group atomically.
+      int non_zero_idx = atomicAdd(&non_zero_counter, 1);
+      // Compute offset into tensor where this group begins.
+      int offset_M = 0;
+      // Compute cumulative sum of prior groups to find offset.
+      for (int i = 0; i < group_index; i++) {
+        offset_M += M_sizes[i];
+      }
+      // Set the problem shape for this group.
+      problem_shape_ptr[non_zero_idx] = ProblemShape(N, M, K);
+      // Set pointer to xq.
+      xq_ptr[non_zero_idx] = xq + (offset_M * K);
+      // Set pointer to wq, dividing by two as wq is packed into bytes.
+      wq_ptr[non_zero_idx] = wq + (group_index * N * K / 2);
+      // Set scale pointers.
+      x_scale_ptr[non_zero_idx] = x_scale + offset_M;
+      w_scale_ptr[non_zero_idx] = w_scale + (group_index * N);
+      w_scale_group_ptr[non_zero_idx] =
+          w_scale_group + (group_index * N * num_scale_groups);
+      // Set output pointer.
+      output_ptr[non_zero_idx] = output + (offset_M * N);
+      // Set stride pointers.
+      stride_a_ptr[non_zero_idx] = cutlass::make_cute_packed_stride(
+          StrideA{}, cute::make_shape(M, K, 1));
+      stride_b_ptr[non_zero_idx] = cute::tile_to_shape(
+          LayoutAtomQuant{}, cute::make_shape(N, K, cute::Int<1>{}));
+      stride_c_ptr[non_zero_idx] = cutlass::make_cute_packed_stride(
+          StrideC{}, cute::make_shape(N, M, 1));
+      stride_s_ptr[non_zero_idx] = cutlass::make_cute_packed_stride(
+          StrideS{}, cute::make_shape(N, num_scale_groups, 1));
+    }
   }
 }
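
The shared counter plus atomicAdd above compacts the non-empty groups to the front of the kernel-argument arrays, so the grouped GEMM only visits groups that have work. A host-side Python sketch of the same idea (illustrative only; the packed order in the kernel is whatever order the atomics resolve in, which may differ from group order):

# Python sketch of the compaction done by set_kernel_args: groups with M == 0
# are skipped and the remaining problems are packed densely at the front.
def compact_problem_shapes(M_sizes, N, K):
    problems = []  # plays the role of problem_shape_ptr[non_zero_idx]
    offsets = []   # row offset of each non-empty group within XQ
    offset_M = 0
    for M in M_sizes:
        if M > 0:
            problems.append((N, M, K))
            offsets.append(offset_M)
        offset_M += M
    return problems, offsets

problems, offsets = compact_problem_shapes([0, 3, 0, 5], N=8192, K=5120)
# problems -> [(8192, 3, 5120), (8192, 5, 5120)]; offsets -> [0, 3]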

@@ -118,6 +137,8 @@ void _f8i4bf16_shuffled_grouped(
   // Get basic shape information.
   int G = M_sizes.size(0);
   // XQ is shape [total_M, K]
+  int total_M = XQ.size(0);
+  int kernel_groups = std::min(G, total_M);
   int K = XQ.size(-1);
   // WQ is shape [G, N, K/2]
   int N = WQ.size(1);
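
The min(G, total_M) bound works because every non-empty group contributes at least one row to XQ, so at most total_M of the G problem slots can ever be populated. A trivial sketch of that bound (not FBGEMM code):

# Upper bound on the number of populated problem slots; purely illustrative.
def max_non_empty_groups(G: int, total_M: int) -> int:
    return min(G, total_M)

assert max_non_empty_groups(G=16, total_M=3) == 3    # very sparse M: only 3 slots used
assert max_non_empty_groups(G=16, total_M=4096) == 16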
@@ -394,7 +415,7 @@ void _f8i4bf16_shuffled_grouped(
   // Define GEMM arguments.
   typename GemmShuffled::Arguments arguments{
       cutlass::gemm::GemmUniversalMode::kGrouped,
-      {G, problem_shape_ptr, nullptr},
+      {kernel_groups, problem_shape_ptr, nullptr},
       {wq_ptr,
        stride_b_ptr,
        xq_ptr,