
Numerical stability issue in recent commits since 0.2.0 #805

Open
rchardx opened this issue Feb 11, 2025 · 5 comments
Labels: bug (Something isn't working), priority: high

Comments

rchardx commented Feb 11, 2025

Environment: CUDA 12.6, Hopper architecture.

Recent commits have significantly affected the numerical stability of the attention kernels. This can be observed in the logs below, where different commits show considerable deviations from the float32 reference implementation.

One concern is that these diffs show an increasing trend, which might indicate an underlying issue.
Another issue is that the FA3 template produces NaNs in the results after prefilling.

We kindly ask the developers to keep an eye on numerical stability in future updates.
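For reference, the mismatch counting shown in the logs below can be sketched roughly as follows (a hypothetical NumPy sketch, not the actual test code in test_flashinfer_prefill.cu; the tolerance and log format are assumptions):

```python
import numpy as np

def count_mismatches(o_host: np.ndarray, o_ref: np.ndarray, atol: float = 1e-3) -> int:
    """Count elements whose absolute difference from the float reference exceeds atol."""
    diff = np.abs(o_host.astype(np.float32) - o_ref.astype(np.float32))
    mismatch_idx = np.flatnonzero(diff > atol)
    # Print the worst offenders first, mirroring the log lines below.
    order = mismatch_idx[np.argsort(-diff[mismatch_idx])]
    for idx in order[:5]:
        print(f"diff: {diff[idx]:e}, idx: {idx}, "
              f"o_host: {o_host[idx]:e}, o_ref: {o_ref[idx]:e}")
    return int(mismatch_idx.size)
```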

main commit: 054778

W20250211 04:30:59.612421 33250 test_flashinfer_prefill.cu:338] batch_size: 6, num_qo_heads: 8, num_kv_heads: 1, head_dim: 128, num_mismatches: 80
W20250211 04:30:59.612534 33250 test_flashinfer_prefill.cu:343] diff: 1.281738e-03, idx: 4052, o_host: -6.591797e-03, o_ref: -7.873535e-03
W20250211 04:30:59.612591 33250 test_flashinfer_prefill.cu:343] diff: 1.220703e-03, idx: 4064, o_host: -1.226807e-02, o_ref: -1.348877e-02
W20250211 04:30:59.612596 33250 test_flashinfer_prefill.cu:343] diff: 1.220703e-03, idx: 1636, o_host: 4.223633e-02, o_ref: 4.345703e-02
W20250211 04:30:59.612600 33250 test_flashinfer_prefill.cu:343] diff: 1.220703e-03, idx: 1606, o_host: 2.429199e-02, o_ref: 2.551270e-02
W20250211 04:30:59.612604 33250 test_flashinfer_prefill.cu:343] diff: 1.220703e-03, idx: 1553, o_host: 3.198242e-02, o_ref: 3.320312e-02
....../tests/kernel/test_flashinfer_prefill.cu:348: Failure
Expected equality of these values:
  num_mismatches
    Which is: 80
  0

main commit: 956910

W20250211 04:20:27.498323 31762 test_flashinfer_prefill.cu:338] batch_size: 6, num_qo_heads: 8, num_kv_heads: 1, head_dim: 128, num_mismatches: 297
W20250211 04:20:27.498430 31762 test_flashinfer_prefill.cu:343] diff: 2.929688e-03, idx: 5820, o_host: -1.513672e-01, o_ref: -1.484375e-01
W20250211 04:20:27.498482 31762 test_flashinfer_prefill.cu:343] diff: 2.441406e-03, idx: 5840, o_host: 9.277344e-02, o_ref: 9.033203e-02
W20250211 04:20:27.498489 31762 test_flashinfer_prefill.cu:343] diff: 2.258301e-03, idx: 5766, o_host: 1.594543e-03, o_ref: -6.637573e-04
W20250211 04:20:27.498495 31762 test_flashinfer_prefill.cu:343] diff: 1.953125e-03, idx: 5847, o_host: 7.080078e-02, o_ref: 6.884766e-02
W20250211 04:20:27.498500 31762 test_flashinfer_prefill.cu:343] diff: 1.953125e-03, idx: 2983, o_host: -8.007812e-02, o_ref: -8.203125e-02
....../tests/kernel/test_flashinfer_prefill.cu:348: Failure
Expected equality of these values:
  num_mismatches
    Which is: 297
  0

main commit: 9f5fbe

W20250211 05:20:10.827390 38669 test_flashinfer_prefill.cu:460] batch_size: 4, num_qo_heads: 28, num_kv_heads: 4, head_dim: 128, num_mismatches: 25
W20250211 05:20:10.827448 38669 test_flashinfer_prefill.cu:465] diff: 0.000793457, idx: 256058
W20250211 05:20:10.827464 38669 test_flashinfer_prefill.cu:465] diff: 0.000793457, idx: 253146
W20250211 05:20:10.827469 38669 test_flashinfer_prefill.cu:465] diff: 0.000793457, idx: 252594
W20250211 05:20:10.827476 38669 test_flashinfer_prefill.cu:465] diff: 0.000747681, idx: 219744
W20250211 05:20:10.827483 38669 test_flashinfer_prefill.cu:465] diff: 0.000717163, idx: 217179
....../tests/kernel/test_flashinfer_prefill.cu:468: Failure
Expected equality of these values:
  num_mismatches
    Which is: 25
  0
yzh119 added the priority: high and bug labels on Feb 11, 2025
yzh119 (Collaborator) commented Feb 11, 2025

Hi @rchardx, thanks for raising this.

From 956910 to 054778, I believe #801 changed the numerics (and it actually improved numerical stability).
From 9f5fbe to 956910, it seems something happened in this period that degrades numerical stability; we should look into it.

Another issue is that the FA3 template produces NaNs in the results after prefilling.

I remember you mentioned that filling the kv-cache with all zeros resolves the issue, which suggests we might be loading V (through the SM80 TiledCopy) without zero-filling out-of-bounds values. I'll take a look.
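To illustrate why unzeroed out-of-bounds V values would explain the NaNs: even when the softmax weight on an out-of-bounds position is exactly zero, 0 × NaN is NaN, so any garbage loaded into the V tile poisons the whole output row. A minimal NumPy sketch (the shapes and values are made up for illustration):

```python
import numpy as np

# Softmax weights over 4 valid kv positions plus 2 out-of-bounds tail slots.
p = np.array([0.25, 0.25, 0.25, 0.25, 0.0, 0.0], dtype=np.float32)

v_zeroed = np.ones((6, 4), dtype=np.float32)
v_zeroed[4:] = 0.0              # oob rows explicitly zero-filled

v_garbage = np.ones((6, 4), dtype=np.float32)
v_garbage[4:] = np.nan          # oob rows left uninitialized (garbage/NaN)

o_good = p @ v_zeroed           # oob rows contribute 0.0 * 0.0 = 0.0
o_bad = p @ v_garbage           # 0.0 * NaN = NaN poisons every output element
```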

yzh119 (Collaborator) commented Feb 11, 2025

The long-term solution is to create regression tests for kernel correctness.
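As a starting point, such a regression test could assert an accuracy threshold against a float reference, similar to the result_accuracy > 0.99 check in the C++ tests below. A hypothetical Python sketch (the metric definition and tolerances are assumptions, not FlashInfer's actual implementation):

```python
import numpy as np

def result_accuracy(o: np.ndarray, o_ref: np.ndarray,
                    rtol: float = 1e-2, atol: float = 1e-3) -> float:
    """Fraction of elements within tolerance of the reference output."""
    ok = np.isclose(o, o_ref, rtol=rtol, atol=atol)
    return float(ok.mean()) if ok.size else 1.0

def check_kernel_regression(o: np.ndarray, o_ref: np.ndarray,
                            threshold: float = 0.99) -> float:
    """Fail loudly when accuracy against the reference drops below the threshold."""
    acc = result_accuracy(o, o_ref)
    assert acc > threshold, f"accuracy {acc:.6f} below threshold {threshold}"
    return acc
```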

rchardx (Author) commented Feb 12, 2025

Hi @rchardx, thanks for raising this.

From 956910 to 054778, I believe #801 changed the numerics (and it actually improved numerical stability). From 9f5fbe to 956910, it seems something happened in this period that degrades numerical stability; we should look into it.

Another issue is that the FA3 template produces NaNs in the results after prefilling.

I remember you mentioned that filling the kv-cache with all zeros resolves the issue, which suggests we might be loading V (through the SM80 TiledCopy) without zero-filling out-of-bounds values. I'll take a look.

First, we thank the FlashInfer team for their great work all along.

Filling the kv-cache with all zeros somewhat mitigates the issue for the first few requests, but the results still contain NaNs for later requests.
This indeed requires a full investigation.

yzh119 (Collaborator) commented Feb 13, 2025

Regarding kernel correctness, would you mind sharing the test case test_flashinfer_prefill.cu with me? I can't reproduce this with the existing Python/C++ unit tests.

For the FA3 NaN issue, can we schedule a meeting?

rchardx (Author) commented Feb 13, 2025

Regarding kernel correctness, would you mind sharing the test case test_flashinfer_prefill.cu with me? I can't reproduce this with the existing Python/C++ unit tests.

Sure.
rchardx@af31796

❯ ./test_batch_prefill --gtest_filter=\*BatchPagedPrefillKernelCorrectnessTestOneHotBF16\*
Running main() from /cpfs/2926428ee2463e44/user/huli/flashinfer/3rdparty/googletest/googletest/src/gtest_main.cc
Note: Google Test filter = *BatchPagedPrefillKernelCorrectnessTestOneHotBF16*
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from FlashInferCorrectnessTest
[ RUN      ] FlashInferCorrectnessTest.BatchPagedPrefillKernelCorrectnessTestOneHotBF16
request_idx=0, page_size=16, num_qo_heads=8, num_kv_heads=1, q_len=21, kv_len=21, head_dim=128, causal=1, pos_encoding_mode=None, result_accuracy=0.852446
/cpfs/2926428ee2463e44/user/huli/flashinfer/src/test_batch_prefill.cu:148: Failure
Expected: (result_accuracy) > (0.99), actual: 0.852446079 vs 0.99
Result correctness test failed.

request_idx=1, page_size=16, num_qo_heads=8, num_kv_heads=1, q_len=20, kv_len=1024, head_dim=128, causal=1, pos_encoding_mode=None, result_accuracy=0.998096
request_idx=2, page_size=16, num_qo_heads=8, num_kv_heads=1, q_len=40, kv_len=8072, head_dim=128, causal=1, pos_encoding_mode=None, result_accuracy=1
request_idx=3, page_size=16, num_qo_heads=8, num_kv_heads=1, q_len=4, kv_len=30, head_dim=128, causal=1, pos_encoding_mode=None, result_accuracy=0.86499
/cpfs/2926428ee2463e44/user/huli/flashinfer/src/test_batch_prefill.cu:148: Failure
Expected: (result_accuracy) > (0.99), actual: 0.864990234 vs 0.99
Result correctness test failed.

request_idx=4, page_size=16, num_qo_heads=8, num_kv_heads=1, q_len=8, kv_len=27, head_dim=128, causal=1, pos_encoding_mode=None, result_accuracy=0.851318
/cpfs/2926428ee2463e44/user/huli/flashinfer/src/test_batch_prefill.cu:148: Failure
Expected: (result_accuracy) > (0.99), actual: 0.851318359 vs 0.99
Result correctness test failed.

request_idx=5, page_size=16, num_qo_heads=8, num_kv_heads=1, q_len=99, kv_len=999, head_dim=128, causal=1, pos_encoding_mode=None, result_accuracy=0.997781
request_idx=6, page_size=16, num_qo_heads=8, num_kv_heads=1, q_len=0, kv_len=0, head_dim=128, causal=1, pos_encoding_mode=None, result_accuracy=1
request_idx=7, page_size=16, num_qo_heads=8, num_kv_heads=1, q_len=0, kv_len=0, head_dim=128, causal=1, pos_encoding_mode=None, result_accuracy=1
request_idx=8, page_size=16, num_qo_heads=8, num_kv_heads=1, q_len=0, kv_len=0, head_dim=128, causal=1, pos_encoding_mode=None, result_accuracy=1
[  FAILED  ] FlashInferCorrectnessTest.BatchPagedPrefillKernelCorrectnessTestOneHotBF16 (4716 ms)
[----------] 1 test from FlashInferCorrectnessTest (4716 ms total)

[----------] Global test environment tear-down
[==========] 1 test from 1 test suite ran. (4716 ms total)
[  PASSED  ] 0 tests.
[  FAILED  ] 1 test, listed below:
[  FAILED  ] FlashInferCorrectnessTest.BatchPagedPrefillKernelCorrectnessTestOneHotBF16

 1 FAILED TEST

For the FA3 NaN issue, can we schedule a meeting?

Yes. I'm available on weekdays from 2:00 to 13:00 UTC.
