feat: implement SM-Constrained GEMM API #744

Open
wants to merge 1 commit into
base: main
lanchongyizu

As requested in #591, this PR implements the `plan` function for GEMM,
which takes `num_ctas` as an argument to specify the grid size.

@yzh119
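
To make the intent of a fixed grid size concrete, here is a minimal Python sketch of how a grid limited to `num_ctas` CTAs could still cover every output tile, with each CTA striding over the tile list. The helper names `tiles_for_cta` and `schedule` are hypothetical and are not part of this PR or of CUTLASS; the actual tile scheduler may assign work differently.

```python
def tiles_for_cta(cta_id: int, num_ctas: int, num_tiles: int) -> list[int]:
    """Return the tile indices one CTA would process (hypothetical helper).

    Assumes a persistent-style kernel: the grid holds exactly num_ctas CTAs,
    and each CTA walks the tile list with a stride of num_ctas.
    """
    if num_ctas <= 0:
        # This sketch does not define a meaning for num_ctas <= 0.
        raise ValueError("num_ctas must be positive in this sketch")
    return list(range(cta_id, num_tiles, num_ctas))


def schedule(num_ctas: int, num_tiles: int) -> dict[int, list[int]]:
    """Map every CTA in the fixed-size grid to its list of tiles."""
    return {cta: tiles_for_cta(cta, num_ctas, num_tiles) for cta in range(num_ctas)}


if __name__ == "__main__":
    # 10 output tiles shared among a grid of 4 CTAs.
    for cta, tiles in schedule(4, 10).items():
        print(cta, tiles)
```

With fewer CTAs than tiles, each CTA simply processes more tiles, so the result does not change; only the degree of parallelism does.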

Collaborator

yzh119 left a comment


Hi @lanchongyizu, great work! Can you also support this for the SM80 template?

Reference: https://github.com/efeslab/Nanoflow/blob/22f0b48739d3a9ad1d8c82f956906b3bc58d519b/pipeline/include/cutlassGemmWrapperImpl.cuh#L92

For the SM80 API, we might support setting a 3D tuple of `num_ctas` (one value each for the m, n, and k dimensions).
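
As a rough illustration of the 3D idea, the sketch below computes how many tiles a problem needs along m, n, and k, then caps each grid dimension by the matching entry of `num_ctas`. The function `sm80_grid`, its argument layout, and the tile sizes are assumptions made for this example; they do not come from the FlashInfer or Nanoflow code.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division."""
    return -(-a // b)


def sm80_grid(problem: tuple[int, int, int],
              tile: tuple[int, int, int],
              num_ctas: tuple[int, int, int]) -> tuple[int, int, int]:
    """Cap the launch grid per dimension (hypothetical helper).

    problem  -- (m, n, k) GEMM sizes
    tile     -- (tile_m, tile_n, tile_k) threadblock tile sizes (assumed)
    num_ctas -- per-dimension upper bound on the number of CTAs
    """
    tiles = tuple(ceil_div(p, t) for p, t in zip(problem, tile))
    # Never launch more CTAs along a dimension than there are tiles.
    return tuple(min(t, c) for t, c in zip(tiles, num_ctas))


if __name__ == "__main__":
    print(sm80_grid((4096, 4096, 1024), (128, 128, 32), (8, 4, 1)))
```

A kernel launched with a capped grid would then need each CTA to loop over the tiles it owns along every capped dimension.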

@@ -33,6 +33,7 @@
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("backend", ["auto", "sm90", "sm80"])
+@pytest.mark.parametrize("num_ctas", [0, 4, 16, 64])
Collaborator


What's the expected behavior of num_ctas=0?

yzh119 force-pushed the sm_constrained_gemm branch from 841b423 to eac553b on January 24, 2025 02:28
@@ -121,8 +121,7 @@ cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_si

   cutlass::KernelHardwareInfo hw_info;
   cudaGetDevice(&hw_info.device_id);
-  hw_info.sm_count =
-      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+  hw_info.sm_count = num_ctas;
Collaborator


I profiled this with nsys, and it turns out that this value doesn't control the number of SMs the kernel uses.
A more fundamental approach might be to use CUDA green contexts.
