[BUG] Mixed Precision Gemm Correctness Regression in Cutlass 3.7/3.8 #2070
Comments
@jwfromm Thanks for submitting this bug. Could you first try reverting every change in the PR, except for the cutlass 3.7 and mixed precision API update (i.e., remove [...])? I will try to clone and build the repo and reproduce the issue. If there are any tips to speed up the build, I'd greatly appreciate it. Thanks.
@jwfromm I am able to reproduce the issue. It looks like when the data type of scale/zero is not the same as the activation (fp16 and bf16 in your original code base), the result will always be incorrect. I tried checking out cutlass 3.7, removing the [...]
Please see this commit for more info. The scale/zero data types in your PR are also the same as the activation, but the results are still incorrect, so I guess some other bug was introduced by the other changes. Thanks again for submitting the bug. I will fix this edge case ASAP.
@jwfromm I have located the source of the bug. To fix the issue, in the [...] The fix will be included in the cutlass 3.8 tag.
Describe the bug
Since Cutlass 3.7, mixed input dtype GEMMs are producing less accurate outputs than they did in Cutlass 3.6. The loss of accuracy is substantial and makes mixed input impractical for real use cases.
Specifically, we have a collection of mixed input GEMMs in FBGEMM that work well on Cutlass 3.6. While these kernels compile fine with newer versions of cutlass (after small API updates), they produce garbage outputs.
Directly copying example 55's BF16 x I4 GEMM produces slightly better results, but the outputs are still much less accurate than the 3.6 equivalents.
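For readers unfamiliar with mixed input GEMMs, here is a minimal pure-Python sketch of the math such a kernel is expected to reproduce: dequantize the low-precision weights with per-group scales and zeros, then do an ordinary matrix multiply. It is not the FBGEMM or CUTLASS implementation. The function names, the `q * scale + zero` convention, and the grouping along K are assumptions made for illustration.

```python
def dequantize(q_weights, scales, zeros, group_size):
    """Dequantize integer weights of shape (N x K).

    Assumed convention (not taken from the issue): w = q * scale + zero,
    with one scale/zero pair per group of `group_size` elements along K.
    """
    out = []
    for row, row_scales, row_zeros in zip(q_weights, scales, zeros):
        deq = []
        for k, q in enumerate(row):
            g = k // group_size  # index of the group this K element belongs to
            deq.append(q * row_scales[g] + row_zeros[g])
        out.append(deq)
    return out


def reference_gemm(activations, q_weights, scales, zeros, group_size):
    """Reference result: activations (M x K) times dequantized weights (N x K), transposed."""
    w = dequantize(q_weights, scales, zeros, group_size)
    return [
        [sum(a * b for a, b in zip(a_row, w_row)) for w_row in w]
        for a_row in activations
    ]


# Tiny example: one activation row, one weight row, two groups of size 2.
acts = [[1.0, 2.0, 3.0, 4.0]]
q_w = [[1, -2, 3, 7]]          # values within the signed int4 range [-8, 7]
scales = [[0.5, 2.0]]
zeros = [[0.0, 1.0]]
result = reference_gemm(acts, q_w, scales, zeros, group_size=2)
```

A high-precision reference like this is a useful baseline when checking whether a kernel's output has drifted after a library update.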
Steps/Code to reproduce bug
We use this benchmarking script to measure the performance and accuracy of kernels. The script can be run with these sample arguments:
This will produce an output like this:
The sim metric is an L1 distance from the BF16 output. After updating to cutlass 3.7, copying example 55, and running the same script, we get:
This output is clearly less accurate. The updated version of the kernel can be found at this PR.
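As a rough illustration of the sim metric mentioned above, here is a minimal sketch of an L1 distance between a kernel's output and a reference output. The benchmarking script's exact reduction (sum or mean) and any normalization are not shown in this report, so both function names and the choice of reductions here are assumptions.

```python
def l1_distance(output, reference):
    """Sum of absolute element-wise differences between two flat sequences."""
    if len(output) != len(reference):
        raise ValueError("output and reference must have the same length")
    return sum(abs(o - r) for o, r in zip(output, reference))


def mean_l1_distance(output, reference):
    """L1 distance averaged over the number of elements (assumed variant)."""
    return l1_distance(output, reference) / len(output)


# Hypothetical values: a reference output and a slightly different kernel output.
reference_out = [1.0, 2.5, 2.0]
kernel_out = [1.0, 2.0, 3.0]
total = l1_distance(kernel_out, reference_out)
average = mean_l1_distance(kernel_out, reference_out)
```

With either variant, a value near zero means the kernel tracks the reference closely, and a larger value indicates a loss of accuracy.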
Expected behavior
The accuracy of mixed input kernels should not have changed due to updates.
Environment details (please complete the following information):
CUDA 12.4, driver version 535.154.05, on a Linux system with 8x H100 GPUs.