-
Notifications
You must be signed in to change notification settings - Fork 207
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: implement SM-Constrained GEMM API #744
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @lanchongyizu , great work! Can you also support this for SM80 template?
For SM80 API, we might support setting a 3D tuple of num_ctas (on m, n, k dimension, correspondingly).
@@ -33,6 +33,7 @@ | |||
@pytest.mark.parametrize("dtype", DTYPES) | |||
@pytest.mark.parametrize("device", CUDA_DEVICES) | |||
@pytest.mark.parametrize("backend", ["auto", "sm90", "sm80"]) | |||
@pytest.mark.parametrize("num_ctas", [0, 4, 16, 64]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the expected behavior of num_ctas=0
?
As requested in flashinfer-ai#591, this PR implements the `plan` function of GEMM with `num_ctas` as an argument to specify the grid size.
841b423
to
eac553b
Compare
@@ -121,8 +121,7 @@ cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_si | |||
|
|||
cutlass::KernelHardwareInfo hw_info; | |||
cudaGetDevice(&hw_info.device_id); | |||
hw_info.sm_count = | |||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); | |||
hw_info.sm_count = num_ctas; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried nsys profiler and it turns out this value can't control the number of SMs this kernel used.
A more fundamental approach might be using green context.
As requested in #591, this PR implements the
plan
function of GEMMwith
num_ctas
as an argument to specify the grid size.@yzh119