📚[WIP] FFPA: Yet another Faster Flash Prefill Attention with O(1)⚡️GPU SRAM complexity for headdim > 256, 1.8x~3x↑🎉faster vs SDPA EA.

DefTruth/ffpa-attn-mma

🤖FFPA: Yet another Faster Flash Prefill Attention with O(1)⚡️GPU SRAM complexity for large headdim🐑

📚FFPA L1~L3 Design | 📈L20 ~1.9x↑🎉 | 📈A30 ~1.8x↑🎉 | 📈3080 ~2.9x↑🎉 | 📈4090 ~2.1x↑🎉

🤖FFPA: 1.8x~3x🎉faster vs SDPA EA with or without MMA Acc F32

🤖[WIP] FFPA: Yet another Faster Flash Prefill Attention with O(1) SRAM complexity & O(d/4) or O(1) register complexity for large headdim (D > 256), almost 1.8x~3x 🎉 faster than SDPA EA with or without MMA Acc F32 on many devices: 📈L20 ~1.9x↑🎉, 📈A30 ~1.8x↑🎉, 📈3080 ~2.9x↑🎉, 📈4090 ~2.1x↑🎉. This repo implements not only the FFPA Attention Algo (fine-grained tiling for large headdim) but also the FA-2 Attention Algo (coarse-grained tiling for small headdim).

💡NOTE: This project is still in its early dev stages and now provides some kernels and benchmarks for reference. More features will be added in the future. (Welcome to 🌟👆🏻star this repo to support me ~)

©️Citations🎉🎉

@misc{ffpa-attn-mma@2025,
  title={FFPA: Yet another Faster Flash Prefill Attention for large headdim.},
  url={https://github.com/DefTruth/ffpa-attn-mma.git},
  note={Open-source software available at https://github.com/DefTruth/ffpa-attn-mma.git},
  author={DefTruth etc},
  year={2025}
}

📖 FFPA L1~L3: FlashAttention + QKV Fine-grained Tiling at MMA level💡

We have extended FlashAttention for large headdim (D > 256) by implementing Fine-grained Tiling at the MMA level (GEMM style) for the Q@K^T and P@V matmul. This approach results in a constant SRAM usage of Br * 16 or Bc * 16 (Br = Bc) for Q, K, and V, leading to an overall SRAM complexity of O(2 * Br * 16) ≈ O(1) and a register complexity of O(d/4) or O(1). Consequently, this method allows us to extend headdim beyond 256 and achieve faster performance compared to SDPA with or without MMA Accumulation F32 (1.8x~3x 🎉 faster than SDPA EA).

We have named this new attention tiling technique FFPA: Faster Flash Prefill Attention. We have designed three levels (L1~L3) of FFPA based on SRAM and register complexity considerations. None of the levels introduces additional VRAM requirements, so the HBM memory complexity remains the same as FlashAttention. 👇

  • 📚L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, ≈O(d/4) register complexity.
  • 📚L2: level 2, O(2xBrx16)≈O(1) SRAM complexity, ≈O(1) register complexity + Q@K^T recomputation.
  • 📚L3: level 3, O(2xBrx16)≈O(1) SRAM complexity, ≈O(1) register complexity + scaling O via HBM offloading.
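To make the data flow of fine-grained tiling concrete, here is a minimal NumPy sketch. It is not the repo's kernel: the function name `fine_grained_attention_sketch`, the tile sizes, and the float32 math are illustrative assumptions. It only shows that Q@K^T and P@V can each be accumulated from Br x 16 and Bc x 16 slices along headdim, combined with the usual online-softmax rescaling, and still match plain attention.

```python
import numpy as np

def fine_grained_attention_sketch(Q, K, V, Br=64, Bc=64, k_tile=16):
    """Single-head attention computed from [Br, 16] / [Bc, 16] tiles (assumed sizes)."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d), dtype=np.float32)
    for i in range(0, N, Br):                      # loop over Q blocks (Split-Q)
        rows = Q[i:i + Br].shape[0]
        m = np.full(rows, -np.inf, dtype=np.float32)   # running row max
        l = np.zeros(rows, dtype=np.float32)           # running row sum
        acc = np.zeros((rows, d), dtype=np.float32)    # unnormalized output
        for j in range(0, N, Bc):                  # loop over K/V blocks
            cols = K[j:j + Bc].shape[0]
            S = np.zeros((rows, cols), dtype=np.float32)
            for k in range(0, d, k_tile):          # Q@K^T: only two small tiles live at once
                S += Q[i:i + Br, k:k + k_tile] @ K[j:j + Bc, k:k + k_tile].T
            S *= scale
            m_new = np.maximum(m, S.max(axis=1))
            P = np.exp(S - m_new[:, None])
            alpha = np.exp(m - m_new)              # rescale factor for earlier blocks
            l = alpha * l + P.sum(axis=1)
            acc *= alpha[:, None]
            for k in range(0, d, k_tile):          # P@V: one [Bc, 16] slice of V at a time
                acc[:, k:k + k_tile] += P @ V[j:j + Bc, k:k + k_tile]
            m = m_new
        O[i:i + Br] = acc / l[:, None]
    return O
```

In this sketch the only per-block buffers that grow with headdim are the output accumulator and the softmax statistics, which is roughly the part L1 keeps in registers and L2/L3 try to shrink further.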

By leveraging this approach, we can achieve better performance for large headdim (D > 256) through a balanced utilization of FlashAttention (which is not designed to support D > 256) and SDPA EA. Approximate SRAM and register complexity analysis for L1~L3 is as follows: (d=headdim, C,Br,Bc=Constant, Br=Bc) 👇

| 📚Complexity | 📚FFPA L1 | 📚FFPA L2 | 📚FFPA L3 | 📚FA-2 |
|:---:|:---:|:---:|:---:|:---:|
| SRAM | O(2xBrx16)≈O(1) | O(2xBrx16)≈O(1) | O(2xBrx16)≈O(1) | ≈O(3xBrxd), d↑ |
| Register | ≈O(d/4), d↑ | O((Bc/16)x4+2C)≈O(1) | O((Bc/16)x4+2C)≈O(1) | ≈O(d/2), d↑ |
| HBM | ≈FA2≈O(Nd), O | ≈FA2≈O(Nd), O | ≈FA2≈O(Nd), O | ≈O(Nd), O |
| Extra HBM | ≈FA2≈O(N), m,l | ≈FA2≈O(N), m,l | ≈FA2≈O(N), m,l | ≈O(N), m,l |
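As a rough worked example of the SRAM row, the sketch below turns the two expressions into bytes. It assumes Br = Bc = 64, fp16 storage (2 bytes per element), a single stage, and no padding; the helper names are hypothetical and real kernels will differ once multi-stage buffers and padding are enabled.

```python
BYTES_PER_HALF = 2  # fp16 element size

def ffpa_smem_bytes(Br=64, Bc=64, k_tile=16):
    # One [Br, 16] tile plus one [Bc, 16] tile; V reuses the Q/K buffer,
    # so the footprint does not depend on headdim.
    return (Br * k_tile + Bc * k_tile) * BYTES_PER_HALF

def fa2_smem_bytes(d, Br=64, Bc=64):
    # Full-headdim tiles for Q [Br, d], K [Bc, d] and V [Bc, d].
    return (Br * d + 2 * Bc * d) * BYTES_PER_HALF

for d in (256, 512, 1024):
    print(f"d={d}: FFPA {ffpa_smem_bytes() // 1024} KiB, "
          f"FA-2 style {fa2_smem_bytes(d) // 1024} KiB")
```

Under these assumptions the fine-grained footprint stays at 4 KiB per stage, while the coarse-grained one grows linearly with d (96 KiB at d=256, 192 KiB at d=512), which is why the coarse-grained layout stops being practical for large headdim.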

📚👇Core Features🎉🎉: I have implemented FFPA L1~L3 using pure MMA PTX instructions, with support for many features such as Split-Q, SMEM Swizzle/Padding, QKV Multi-Stages(1~4), Tile MMAs/Warps, Mixed MMA F32/F16 Acc (Q@K^T MMA Acc F32 + P@V MMA Acc F16), Fully Shared QKV SMEM, Prefetch QKV g2s, Persist Q s2r/g2s, Fully QKV Fine-grained Tiling(GEMM style), Collective Store, etc.

| 📚Feature | 📚Feature | 📚Feature | 📚Feature |
|:---:|:---:|:---:|:---:|
| ✔️Tensor Cores | ✔️Loop over N/D | ✔️Tile Block(Br, Bc) | ✔️MMA(m16n8k16) |
| ✔️Split Q(FA-2) | ✔️Pack LDST(128 bits) | ✔️SMEM Swizzle/Pad | ✔️Copy Async |
| ✔️Tile MMA/Warp | ✔️QKV Multi-Stages(1~4) | ✔️Collective Store(Shfl) | ✔️Prefetch QKV g2s |
| ✔️QKV Fine-grained Tiling | ✔️Shared QKV SMEM | ✔️Mixed MMA Acc | ✔️Persist Q s2r/g2s |

template<
  const int kHeadDim,              // Headdim, 32~1024     
  const int kMmaAtomM,             // MMA Atom M, 16
  const int kMmaAtomN,             // MMA Atom N, 8
  const int kMmaAtomK,             // MMA Atom K, 16
  const int kMmaTileSeqLenQ,       // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K),  Bc(N)]  
  const int kMmaTileSeqLenK,       // 1, more MMA(warp), N=8*1 =8,  Q@K^T=[Br(M), d(K)]@[d(K),  Bc(N)]    
  const int kMmaTileSeqLenP,       // 4, more MMA(warp), M=16*4=64, P@V  =[Br(M),Bc(K)]@[Bc(K), d(N) ]
  const int kMmaTileHeadDimV,      // 1, more MMA(warp), N=8*1 =8,  P@V  =[Br(M),Bc(K)]@[Bc(K), d(N) ]       
  const int kWarpTileSeqLenQ,      // 1, more values, M, Br=64*1=64, matmul M 
  const int kWarpTileSeqLenK,      // 8, more values, N, Bc=8*8 =64, matmul N
  const int kWarpTileSeqLenP,      // 1, more values, M, Br=64*1=64, matmul M
  const int kWarpTileHeadDimV,     // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|...
  const int kMmaAccFloat32QK,      // 0/1, Q@K^T, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
  const int kMmaAccFloat32PV,      // 0/1, P@V, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
  const int kOStorageAccFloat32,   // 0/1, MMA Acc is always f32/f16, but O storage can be fp32 or half.
  const int kPrefetchQK,           // Prefetch QK at the Appropriate Time Point. 
  const int kPrefetchPV,           // Prefetch V at the Appropriate Time Point. 
  const int kShareSmemQKV,         // QKV share the same shared memory, reuse QK smem for V.
  const int kPersistQs2r,          // Persist load Q s2r for headdim  < 512, more registers, but still keep O(1) SRAM.
  const int kPersistQg2s,          // Persist load Q g2s for headdim <= 320, more SRAM, but still keep register usage.
  const int kStageQK,              // <= 4, may apply different multi stages policy for QK and V (<=4)
  const int kStagePV,              // <= 4, may apply different multi stages policy for QK and V (<=4)
  const int kPadQ,                 // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
  const int kPadK,                 // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
  const int kPadV                  // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
> __global__ void // Q, K, V, O -> [B, H, N, D]
// FFPA Attention Algo: Fine-grained tiling at MMA level for large headdim (d>256), 
// which can achieve 1.8x~3x🎉 faster than SDPA EA with or without MMA Acc F32.
ffpa_mma_stages_split_q_L1_large_d_template(half* Q, half* K, half* V, half* O, ...); 
// FA-2 Attention Algo: Coarse-grained tiling at Attention level for small headdim (d<=256), 
// which can achieve 95%-105%🎉 performance as SDPA FA-2 BE with MMA Acc F32, and achieve
// almost 1.2x~1.4x🎉 faster than SDPA FA-2 via Mixed MMA Acc(Q@K^T F32 + P@V F16).
ffpa_mma_stages_split_q_L1_small_d_template(half* Q, half* K, half* V, half* O, ...); 
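
The comments on the template parameters imply how the block tile sizes follow from the MMA atom and the tiling factors. The Python sketch below just replays that arithmetic with the default values quoted in the comments; the helper names are hypothetical, and the number of headdim steps is an assumption based on one k16 MMA step per 16 columns.

```python
def block_tiles(mma_atom_m=16, mma_atom_n=8,
                mma_tile_seqlen_q=4, mma_tile_seqlen_k=1,
                warp_tile_seqlen_q=1, warp_tile_seqlen_k=8):
    # Br: rows of Q handled per thread block (M = 16 * 4 = 64, Br = 64 * 1 = 64).
    Br = mma_atom_m * mma_tile_seqlen_q * warp_tile_seqlen_q
    # Bc: rows of K/V handled per inner step (N = 8 * 1 = 8, Bc = 8 * 8 = 64).
    Bc = mma_atom_n * mma_tile_seqlen_k * warp_tile_seqlen_k
    return Br, Bc

def headdim_steps(head_dim, mma_atom_k=16):
    # Assumed: Q@K^T walks headdim in chunks of kMmaAtomK columns.
    return head_dim // mma_atom_k

print(block_tiles())        # (64, 64)
print(headdim_steps(320))   # 20
```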

📖 Prerequisites

  • Python >= 3.10
  • PyTorch >= 2.4.0, CUDA >= 12.4
  • Recommended: PyTorch 2.5.1, CUDA 12.5
  • Docker: nvcr.io/nvidia/pytorch:24.10-py3

📖 Installation

The FFPA implemented in this repo can be installed as a Python library, namely, the ffpa-attn library (optional).

git clone https://github.com/DefTruth/ffpa-attn-mma.git
# clone, then, run bash .dev/install.sh directly or run commands:
python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl # pip uninstall ffpa-attn -y

📖 FFPA L1 (Level 1): Benchmark 🎉🎉

L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, O(d/4) register complexity, the same GPU HBM memory complexity as FlashAttention. B=1, H=48, N=8192, D=320-1024 (FA2 not supported 👀). (Notes: *=MMA Acc F32, ^=MMA Acc F16, Softmax Acc dtype is always F32, T=TFLOPS, 👇Benchmark)

  • 📚 NVIDIA L20 (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~1.8x↑🎉)
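
| Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| SDPA EA | 56T | 63T | 58T | 58T | 55T | 56T | 54T | 55T | 54T | 55T | 54T | 56T |
| FFPA L1* | 102T | 102T | 103T | 104T | 103T | 95T | 95T | 95T | 95T | 96T | 95T | 94T |
| Speedup | 1.82x | 1.62x | 1.78x | 1.79x | 1.87x | 1.7x | 1.76x | 1.73x | 1.76x | 1.75x | 1.76x | 1.68x |
| FFPA L1^ | 104T | 103T | 103T | 102T | 104T | 103T | 102T | 94T | 94T | 94T | 100T | 100T |
| Speedup | 1.86x | 1.63x | 1.78x | 1.76x | 1.89x | 1.84x | 1.89x | 1.71x | 1.74x | 1.71x | 1.85x | 1.79x |
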
  • 📚 NVIDIA L20 (*=MMA Acc: QK F32 + PV F16, ^=MMA Acc F16, T=TFLOPS, ~1.9x↑🎉)
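
| Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| SDPA EA | 56T | 64T | 58T | 58T | 55T | 56T | 54T | 55T | 54T | 55T | 54T | 56T |
| FFPA L1* | 105T | 102T | 104T | 103T | 105T | 95T | 95T | 94T | 94T | 94T | 102T | 101T |
| Speedup | 1.88x | 1.59x | 1.79x | 1.78x | 1.91x | 1.7x | 1.76x | 1.71x | 1.74x | 1.71x | 1.89x | 1.8x |
| FFPA L1^ | 104T | 103T | 103T | 102T | 103T | 103T | 102T | 94T | 94T | 94T | 100T | 100T |
| Speedup | 1.86x | 1.61x | 1.78x | 1.76x | 1.87x | 1.84x | 1.89x | 1.71x | 1.74x | 1.71x | 1.85x | 1.79x |
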
  • 📚 NVIDIA A30 (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~1.8x↑🎉)
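
| Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| SDPA EA | 25T | 25T | 24T | 24T | 24T | 24T | 23T | 22T | 22T | 22T | 22T | 18T |
| FFPA L1* | 45T | 44T | 44T | 43T | 43T | 38T | 37T | 37T | 37T | 36T | 33T | 32T |
| Speedup | 1.8x | 1.76x | 1.83x | 1.79x | 1.79x | 1.58x | 1.61x | 1.68x | 1.68x | 1.64x | 1.5x | 1.78x |
| FFPA L1^ | 48T | 46T | 45T | 43T | 44T | 44T | 44T | 38T | 37T | 36T | 40T | 34T |
| Speedup | 1.92x | 1.84x | 1.88x | 1.79x | 1.83x | 1.83x | 1.91x | 1.73x | 1.68x | 1.64x | 1.82x | 1.89x |
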
  • 📚 NVIDIA A30 (*=MMA Acc: QK F32 + PV F16, ^=MMA Acc F16, T=TFLOPS, ~1.9x↑🎉)
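
| Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| SDPA EA | 25T | 25T | 24T | 24T | 24T | 24T | 23T | 22T | 22T | 22T | 22T | 18T |
| FFPA L1* | 48T | 46T | 46T | 43T | 44T | 38T | 38T | 38T | 37T | 36T | 40T | 34T |
| Speedup | 1.92x | 1.84x | 1.92x | 1.79x | 1.83x | 1.58x | 1.65x | 1.73x | 1.68x | 1.64x | 1.82x | 1.89x |
| FFPA L1^ | 48T | 46T | 45T | 43T | 44T | 44T | 44T | 38T | 37T | 36T | 39T | 34T |
| Speedup | 1.92x | 1.84x | 1.88x | 1.79x | 1.83x | 1.83x | 1.91x | 1.73x | 1.68x | 1.64x | 1.77x | 1.89x |
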
  • 📚 NVIDIA RTX 3080 Laptop (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~2.5x↑🎉)
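
| Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| SDPA EA | 13T | 16T | 11T | 16T | 15T | 15T | 15T | 15T | 14T | 14T | 14T | 14T |
| FFPA L1* | 33T | 31T | 30T | 30T | 30T | 27T | 27T | 26T | 26T | 26T | 26T | 25T |
| Speedup | 2.54x | 1.94x | 2.73x | 1.88x | 2.0x | 1.8x | 1.8x | 1.73x | 1.86x | 1.86x | 1.86x | 1.79x |
| FFPA L1^ | 43T | 41T | 39T | 39T | 39T | 39T | 39T | 36T | 34T | 33T | 31T | 33T |
| Speedup | 3.31x | 2.56x | 3.55x | 2.44x | 2.6x | 2.6x | 2.6x | 2.4x | 2.43x | 2.36x | 2.21x | 2.36x |
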
  • 📚 NVIDIA RTX 3080 Laptop (*=MMA Acc: QK F32 + PV F16, ^=MMA Acc F16, T=TFLOPS, ~2.9x↑🎉)
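
| Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| SDPA EA | 13T | 15T | 12T | 15T | 14T | 15T | 14T | 14T | 14T | 14T | 14T | 14T |
| FFPA L1* | 38T | 36T | 34T | 35T | 34T | 31T | 32T | 31T | 30T | 28T | 27T | 27T |
| Speedup | 2.92x | 2.4x | 2.83x | 2.33x | 2.43x | 2.07x | 2.29x | 2.21x | 2.14x | 2.0x | 1.93x | 1.93x |
| FFPA L1^ | 44T | 41T | 39T | 39T | 38T | 39T | 39T | 36T | 34T | 32T | 31T | 33T |
| Speedup | 3.38x | 2.73x | 3.25x | 2.6x | 2.71x | 2.6x | 2.79x | 2.57x | 2.43x | 2.29x | 2.21x | 2.36x |
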
  • 📚 NVIDIA RTX 4090 (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~1.8x↑🎉)
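
| Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| SDPA EA | 81T | 94T | 85T | 85T | 79T | 81T | 79T | 80T | 79T | 80T | 78T | 78T |
| FFPA L1* | 149T | 150T | 150T | 150T | 150T | 140T | 140T | 140T | 139T | 139T | 137T | 134T |
| Speedup | 1.84x | 1.6x | 1.76x | 1.76x | 1.9x | 1.73x | 1.77x | 1.75x | 1.76x | 1.74x | 1.76x | 1.72x |
| FFPA L1^ | 194T | 194T | 189T | 191T | 197T | 188T | 184T | 180T | 177T | 172T | 171T | 171T |
| Speedup | 2.4x | 2.06x | 2.22x | 2.25x | 2.49x | 2.32x | 2.33x | 2.25x | 2.24x | 2.15x | 2.19x | 2.19x |
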
  • 📚 NVIDIA RTX 4090 (*=MMA Acc: QK F32 + PV F16, ^=MMA Acc F16, T=TFLOPS, ~2.1x↑🎉)
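
| Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| SDPA EA | 82T | 92T | 85T | 84T | 78T | 81T | 79T | 80T | 78T | 79T | 77T | 78T |
| FFPA L1* | 176T | 170T | 171T | 171T | 171T | 161T | 160T | 161T | 160T | 158T | 165T | 164T |
| Speedup | 2.15x | 1.85x | 2.01x | 2.04x | 2.19x | 1.99x | 2.03x | 2.01x | 2.05x | 2.0x | 2.14x | 2.1x |
| FFPA L1^ | 200T | 191T | 189T | 191T | 188T | 188T | 186T | 179T | 175T | 173T | 172T | 170T |
| Speedup | 2.44x | 2.08x | 2.22x | 2.27x | 2.41x | 2.32x | 2.35x | 2.24x | 2.24x | 2.19x | 2.23x | 2.18x |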
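
The Speedup rows appear to be the ratio of the rounded TFLOPS figures in each column. The short Python sketch below reproduces the first L20 row (FFPA L1* vs SDPA EA) under that assumption; the variable and function names are illustrative only.

```python
# TFLOPS for D = 320 ... 1024 on NVIDIA L20 (*=MMA Acc F32), copied from the first table.
sdpa_ea = [56, 63, 58, 58, 55, 56, 54, 55, 54, 55, 54, 56]
ffpa_l1 = [102, 102, 103, 104, 103, 95, 95, 95, 95, 96, 95, 94]

def speedups(fast, base):
    # Assumed definition: speedup = FFPA TFLOPS / SDPA EA TFLOPS, rounded to 2 decimals.
    return [round(f / b, 2) for f, b in zip(fast, base)]

print(speedups(ffpa_l1, sdpa_ea))
# [1.82, 1.62, 1.78, 1.79, 1.87, 1.7, 1.76, 1.73, 1.76, 1.75, 1.76, 1.68]
```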

📖 Python Testing

👇You can test many custom FFPA kernels via Python and figure out the difference in their performance. The --gen-bench and --plot options help you generate a benchmark table in Markdown style and speedup bar plots on your device. Contributions of your benchmark tables and plots are welcome via a PR 🎉🎉.

  • 📚 case: B=1, H=48, N=8192, D=320(FA2 not supported)
# You can test on many devices, such as Volta, Ampere, Ada, Hopper, ...
cd tests && python3 test.py --B 1 --H 48 --N 8192 --show-all --D 320
  • 📚 case: Generate a benchmark table and speedup bar plots on your device.
cd tests && pip install matplotlib && python3 test.py --gen-bench --show-all --plot
  • 📚 case: Compare small headdim (d<=256, e.g., 64), FFPA-L1 vs SDPA FA-2 BE.
export ENABLE_FFPA_FORCE_PV_F16=1 # Mixed MMA Acc (Q@K^T F32 + P@V F16).
# Enable the ffpa-attn small-d kernel, which uses the coarse-grained tiling method.
export ENABLE_FFPA_PERSIST_Q_G2S=1 && export ENABLE_FFPA_PERSIST_KV_G2S=1 
python3 test.py --B 1 --H 8 --N 8192 --show-all --D 64 # NVIDIA RTX 3080 Laptop
--------------------------B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 5---------------------
                   (sdpa): ['0.00499344'], time:4.346418ms, TFLOPS:32.24 (+0.00 %)(~1.00x)
 (ffpa+acc+f32+L1+stage1): ['0.00500107'], time:3.538846ms, TFLOPS:39.59 (+22.82%)(~1.23x)
 (ffpa+acc+f32+L1+stage2): ['0.00500107'], time:3.539991ms, TFLOPS:39.58 (+0.00 %)(~1.23x)
 (ffpa+acc+f16+L1+stage1): ['0.00498199'], time:2.624893ms, TFLOPS:53.38 (+34.82%)(~1.66x)
 (ffpa+acc+f16+L1+stage2): ['0.00498199'], time:2.629899ms, TFLOPS:53.28 (+0.00 %)(~1.65x)
 (ffpa+acc+f32+L1+stage3): ['0.00500107'], time:3.535127ms, TFLOPS:39.64 (+0.00 %)(~1.23x)
 (ffpa+acc+f32+L1+stage4): ['0.00500107'], time:3.538227ms, TFLOPS:39.60 (+0.00 %)(~1.23x)
 (ffpa+acc+f16+L1+stage3): ['0.00498199'], time:2.627229ms, TFLOPS:53.33 (+0.00 %)(~1.65x)
 (ffpa+acc+f16+L1+stage4): ['0.00498199'], time:2.624702ms, TFLOPS:53.38 (+0.01 %)(~1.66x)
------------------------------------------------------------------------------------------
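
If you want to post-process this output, the following Python sketch parses a few of the lines above and recomputes the speedup over SDPA from the reported times. The regular expression and the function name are assumptions based only on the log format shown here, not part of test.py.

```python
import re

LOG = """\
(sdpa): ['0.00499344'], time:4.346418ms, TFLOPS:32.24 (+0.00 %)(~1.00x)
(ffpa+acc+f32+L1+stage1): ['0.00500107'], time:3.538846ms, TFLOPS:39.59 (+22.82%)(~1.23x)
(ffpa+acc+f16+L1+stage1): ['0.00498199'], time:2.624893ms, TFLOPS:53.38 (+34.82%)(~1.66x)
"""

# "(tag): ... time:<ms>ms" on a single line; the tag is the text inside the first parentheses.
LINE = re.compile(r"\(([^)]+)\):.*?time:([\d.]+)ms")

def speedups_vs_sdpa(log):
    times = {m.group(1): float(m.group(2)) for m in LINE.finditer(log)}
    base = times["sdpa"]
    # Speedup is taken as the ratio of SDPA time to kernel time.
    return {tag: round(base / t, 2) for tag, t in times.items()}

print(speedups_vs_sdpa(LOG))
```

For these three lines the ratios come out as 1.0, 1.23 and 1.66, matching the `~x` figures printed in the log.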

💡NOTE: Please check all configurable environment variables in env.py.

©️License

GNU General Public License v3.0

🎉Contribute

How to contribute? Welcome to star⭐️ this repo to support me👆🏻 ~
