# Blackwell SM100 GEMMs

TL;DR: jump to the block scaled GEMM example below.

Blackwell SM100 introduces the tcgen05.mma instructions. tcgen05.mma instructions support all legacy types (tfloat32_t, half_t, bfloat16_t, int8_t, uint8_t) and the new 4-, 6-, and 8-bit floating point data types, with and without scale factors. This document explains the new tcgen05.mma instructions supported by CUTLASS and how one can leverage CUTLASS to create efficient SM100 GEMM kernels targeting these new MMA instructions.

Blackwell SM100 has 7 new tcgen05.mma instructions. These instructions are 2x to 4x faster than the Hopper architecture's WGMMA instructions.

| PTX Instruction | Throughput | Notes |
|---|---|---|
| tcgen05.mma.cta_group::[1\|2].kind::tf32 | 2x Hopper Tf32 Tensor Core | MMA with A={tf32} x B={tf32}, TN, NT, TT, NN layouts |
| tcgen05.mma.cta_group::[1\|2].kind::f16 | 2x Hopper Fp16 Tensor Core | MMA with A={f16} x B={f16} or A={bf16} x B={bf16}, TN, NT, TT, NN layouts |
| tcgen05.mma.cta_group::[1\|2].kind::i8 | 2x Hopper I8 Tensor Core | MMA with A={i8} x B={i8} or A={u8} x B={u8}, TN, NT, TT, NN layouts |
| tcgen05.mma.cta_group::[1\|2].kind::f8f6f4 | 2x Hopper Fp8 Tensor Core | Mixed precision MMA with A={f4,f6,f8} x B={f4,f6,f8}, TN, NT, TT, NN layouts |
| tcgen05.mma.cta_group::[1\|2].kind::mxf8f6f4.block_scale | 2x Hopper Fp8 Tensor Core | Block scaled mixed precision MMA with A={mxf4,mxf6,mxf8} x B={mxf4,mxf6,mxf8}, TN, NT, TT, NN layouts |
| tcgen05.mma.cta_group::[1\|2].kind::mxf4.block_scale | 4x Hopper Fp8 Tensor Core | Block scaled MMA with A={mxf4} x B={mxf4}, TN layouts |
| tcgen05.mma.cta_group::[1\|2].kind::mxf4nvf4.block_scale.scale_vec_size::[2X\|4X] | 4x Hopper Fp8 Tensor Core | Block scaled MMA with A={mxf4} x B={mxf4} or A={nvf4} x B={nvf4}, TN layouts |

For more detailed information see tcgen05.mma PTX documentation.

## New in Blackwell SM100

### Block Scaled GEMMs

Instructions with kind modifiers mxf8f6f4, mxf4, and mxf4nvf4 perform matrix multiplication operations with scale factors of the form $D = C + (A \times SFA) \times (B \times SFB)$. Scale factors are applied along the GEMM-K dimension such that every 16 or 32 elements of the $A$ and $B$ matrices in the K dimension share an associated scale factor. For example, an $M \times K$ $A$ matrix has an associated $M \times \lceil K/32 \rceil$ SFA matrix, and an $N \times K$ $B$ matrix has an associated $N \times \lceil K/32 \rceil$ SFB matrix. For block scaled GEMMs, an entry of the output $D$ matrix is, in index notation, $D_{ij} = C_{ij} + \sum_{k} (A_{i,k} \times SFA_{i,\lfloor k/SV \rfloor}) \times (B_{j,k} \times SFB_{j,\lfloor k/SV \rfloor})$, where $SV$ is the scale factor vector size (16 or 32). Further details can be found in the PTX documentation on block scaling.
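To make the formula concrete, here is a minimal reference computation of a single output element in plain C++ (not part of CUTLASS); the function name and the use of float storage are illustrative assumptions.

```c++
#include <vector>

// Reference computation of one entry of D = C + (A x SFA) * (B x SFB).
// A is M x K and B is N x K (both K-major here); SFA is M x ceil(K/SV) and
// SFB is N x ceil(K/SV). Everything is held as float for clarity; real kernels
// store fp4/fp6/fp8 data and ue8m0/ue4m3 scale factors.
float block_scaled_gemm_ref(int i, int j, int K, int SV,
                            std::vector<float> const& A,    // M x K
                            std::vector<float> const& B,    // N x K
                            std::vector<float> const& SFA,  // M x ceil(K/SV)
                            std::vector<float> const& SFB,  // N x ceil(K/SV)
                            float C_ij) {
  int sf_k = (K + SV - 1) / SV;   // number of scale factors along K
  float acc = 0.f;
  for (int k = 0; k < K; ++k) {
    float a = A[i * K + k] * SFA[i * sf_k + k / SV];  // every SV elements of A share one SF
    float b = B[j * K + k] * SFB[j * sf_k + k / SV];  // every SV elements of B share one SF
    acc += a * b;
  }
  return C_ij + acc;
}
```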

### Blackwell Narrow Precision Data Types

Narrow-precision tcgen05.mma instructions can operate on several 4-, 6-, and 8-bit data types. Blackwell MMAs can operate on the four 8-bit floating point types listed below, of which only two (float_ue8m0_t and float_ue4m3_t) can be used as scale factor data types. There are two 6-bit floating point types and one 4-bit floating point data type. See the PTX documentation for narrow precision data types for details.

Blackwell Narrow Precision Data Types

| Data Type | Exponent Bits | Mantissa Bits | Signed | Bit Size |
|---|---|---|---|---|
| float_e4m3_t | 4 | 3 | Yes | 8 |
| float_e5m2_t | 5 | 2 | Yes | 8 |
| float_e2m3_t | 2 | 3 | Yes | 6 |
| float_e3m2_t | 3 | 2 | Yes | 6 |
| float_e2m1_t | 2 | 1 | Yes | 4 |
| float_ue8m0_t[^1] | 8 | 0 | No | 8 |
| float_ue4m3_t[^1] | 4 | 3 | No | 8 |

Block scaled MMAs use mx and nv types, each of which pairs a float8_t, float6_t, or float4_t data type with one of the two scale factor data types and a predetermined scale factor vector size. The mx types follow the OCP specification (see OCP Specification). The following types provided by CUTLASS can be used as inputs to the collective builders to generate block scaled kernels:

Blackwell Block Scaled Narrow Precision Data Types

| Mx/Nv Data Type | Scale Factor Type | SF Vector Size | OCP Compliant |
|---|---|---|---|
| `mx_float8_t<Any F8type>` | float_ue8m0_t | 32 | Yes |
| `mx_float6_t<Any F6Type>` | float_ue8m0_t | 32 | Yes |
| mx_float4_t | float_ue8m0_t | 32 | Yes |
| nv_float4_t | float_ue4m3_t | 16 | No |
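As a rough sketch, these types appear in kernel code as element-type aliases that pair the narrow data type with its scale factor type. The specific payload choices (float_e2m1_t, float_e4m3_t) below are illustrative assumptions; check the CUTLASS narrow-precision headers for the exact template spellings.

```c++
// Hypothetical element choices for a block scaled kernel (payload types are assumptions):
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;  // fp4 data + float_ue4m3_t SF, SF vector size 16
using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;  // fp8 data + float_ue8m0_t SF, SF vector size 32
```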

## Layouts and Tensor Alignment Requirements to Target tcgen05.mma Instructions

The tables below list valid data type, alignment, and AB layout combinations. Note that alignment is reported as a number of elements. A and B matrix layouts are represented with T and N: T represents a row-major layout and N represents a column-major layout. For instance, TN is a row-major A matrix with a column-major B matrix.

For legacy types (tf32, f16, bf16, i8, and u8), the alignment requirements for A and B matrices are the same as on Hopper. All four layouts (TN, NN, NT, TT) are supported for all legacy data types.

Table 1: Valid Data Type, Alignment, and Layout Combinations For MMAs with Legacy Types

| # | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test |
|---|---|---|---|---|---|---|---|
| 1 | tfloat32_t | tfloat32_t | TN, NN, NT, TT | 4 | 4 | tf32 | |
| 2 | half_t | half_t | TN, NN, NT, TT | 8 | 8 | f16 | Unit tests |
| 3 | bfloat16_t | bfloat16_t | TN, NN, NT, TT | 8 | 8 | f16 | Similar to half_t unit tests |
| 4 | int8_t | int8_t | TN, NN, NT, TT | 16 | 16 | i8 | Unit tests |
| 5 | uint8_t | uint8_t | TN, NN, NT, TT | 16 | 16 | i8 | Similar to int8_t unit tests |
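For example, here is a sketch of how row 2 of Table 1 (half_t x half_t with a TN layout) translates into the builder type parameters used later in this document.

```c++
// TN layout: T = row-major A, N = column-major B (Table 1, row 2)
using ElementA       = cutlass::half_t;
using GmemLayoutA    = cutlass::layout::RowMajor;     // T
constexpr int AlignA = 8;                             // 8 half_t elements, per Table 1
using ElementB       = cutlass::half_t;
using GmemLayoutB    = cutlass::layout::ColumnMajor;  // N
constexpr int AlignB = 8;
```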

For narrow precision MMAs, not all A/B type and A/B layout combinations are supported by every tcgen05.mma instruction. Furthermore, tensor copy instructions for sub-byte types impose additional alignment requirements when loading narrow-precision tensors from global memory to shared memory (see the PTX documentation for details).

The tables below list the valid layout and alignment values for each A and B data type combination and their target tcgen05.mma instructions supported by CUTLASS.

Table 2: Valid Data Type, Alignment, and Layout Combinations For Narrow Precision MMAs Without Block Scaling

| # | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test |
|---|---|---|---|---|---|---|---|
| 1 | float4_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | TN, NT, NN, TT unit tests |
| 2 | float4_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | TN, NT, NN, TT unit tests |
| 3 | float6_t | float4_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | TN, NT, NN, TT unit tests |
| 4 | float4_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | TN, NT unit tests |
| 5 | float8_t | float4_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | TN, NT unit tests |
| 6 | float6_t | float6_t | TN, NN, NT, TT | 128 | 128 | f8f6f4 | TN, NT, NN, TT unit tests |
| 7 | float6_t | float8_t | TN, NN, NT, TT | 128 | 16 | f8f6f4 | TN, NT unit tests |
| 8 | float8_t | float6_t | TN, NN, NT, TT | 16 | 128 | f8f6f4 | TN, NT unit tests |
| 9 | float8_t | float8_t | TN, NN, NT, TT | 16 | 16 | f8f6f4 | Unit tests |

Table 3: Valid Data Type, Alignment, and Layout Combinations for Block Scaled Narrow Precision MMAs

| # | A Type | B Type | AB Layout | A Alignment | B Alignment | Target tcgen05.mma.kind | Unit Test |
|---|---|---|---|---|---|---|---|
| 1 | nv_float4_t | nv_float4_t | TN | 32 | 32 | mxf4nvf4 | TN unit tests |
| 2 | mx_float4_t | mx_float4_t | TN | 32 | 32 | mxf4, mxf4nvf4 | TN unit tests |
| 3 | mx_float4_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 | NT unit tests |
| 4 | mx_float4_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 | TN, NT unit tests |
| 5 | mx_float6_t | mx_float4_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 | TN, NT unit tests |
| 6 | mx_float4_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 | TN, NT unit tests |
| 7 | mx_float8_t | mx_float4_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 | TN, NT unit tests |
| 8 | mx_float6_t | mx_float6_t | TN, NN, NT, TT | 128 | 128 | mxf8f6f4 | TN, NT unit tests |
| 9 | mx_float6_t | mx_float8_t | TN, NN, NT, TT | 128 | 16 | mxf8f6f4 | TN, NT unit tests |
| 10 | mx_float8_t | mx_float6_t | TN, NN, NT, TT | 16 | 128 | mxf8f6f4 | TN, NT unit tests |
| 11 | mx_float8_t | mx_float8_t | TN, NN, NT, TT | 16 | 16 | mxf8f6f4 | TN, NT unit tests |

### MMA tile shapes supported

The alignment restrictions also limit the options for MMA tile shapes. The tables below list the supported MmaTileShape, layout, and dispatch policy combinations for each row of Table 1, Table 2, and Table 3.

Table 4: Valid Tile Shapes and Dispatch Policies for legacy types (all rows of Table 1)

| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|---|---|---|---|---|---|---|
| 1SM | 64x64x(4*MMA-K) | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 64x128x(4*MMA-K) | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 64x192x(4*MMA-K) | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 64x256x(4*MMA-K) | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x64x(4*MMA-K) | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x128x(4*MMA-K) | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x192x(4*MMA-K) | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x256x(4*MMA-K) | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 2SM | 128x64x(4*MMA-K) | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 128x128x(4*MMA-K) | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 128x192x(4*MMA-K) | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 128x256x(4*MMA-K) | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x64x(4*MMA-K) | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x128x(4*MMA-K) | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x192x(4*MMA-K) | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x256x(4*MMA-K) | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |

Table 5: Valid Tile Shapes and Dispatch Policies for {float4_t, float6_t} x {float4_t, float6_t} (Rows 1,2,3,6 of Table 2)

| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|---|---|---|---|---|---|---|
| 1SM | 64x64x128 | Y | N | N | N | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 64x128x128 | Y | Y | N | N | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 64x192x128 | Y | N | N | N | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 64x256x128 | Y | Y | N | N | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x64x128 | Y | N | N | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x128x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x192x128 | Y | N | N | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x256x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 2SM | 128x64x128 | Y | N | N | N | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 128x128x128 | Y | N | N | N | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 128x192x128 | Y | N | N | N | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 128x256x128 | Y | Y | N | N | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x64x128 | Y | N | N | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x128x128 | Y | N | N | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x192x128 | Y | N | N | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x256x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |

Table 6: Valid Tile Shapes and Dispatch Policies for float8_t x {float4_t, float6_t} (Rows 5,8 of Table 2)

| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|---|---|---|---|---|---|---|
| 1SM | 64x64x128 | Y | N | N | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 64x128x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 64x192x128 | Y | N | N | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 64x256x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x64x128 | Y | N | N | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x128x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x192x128 | Y | N | N | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x256x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 2SM | 128x64x128 | Y | N | N | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 128x128x128 | Y | N | N | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 128x192x128 | Y | N | N | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 128x256x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x64x128 | Y | N | N | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x128x128 | Y | N | N | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x192x128 | Y | N | N | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x256x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |

Table 7: Valid Tile Shapes and Dispatch Policies for {float4_t, float6_t} x float8_t (Rows 4,7 of Table 2)

| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|---|---|---|---|---|---|---|
| 1SM | 64x64x128 | Y | Y | N | N | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 64x128x128 | Y | Y | N | N | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 64x192x128 | Y | Y | N | N | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 64x256x128 | Y | Y | N | N | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x64x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x128x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x192x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x256x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 2SM | 128x64x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 128x128x128 | Y | Y | N | N | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 128x192x128 | Y | Y | N | N | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 128x256x128 | Y | Y | N | N | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x64x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x128x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x192x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x256x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |

Table 8: Valid Tile Shapes and Dispatch Policies for float8_t x float8_t (Row 9 of Table 2)

| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|---|---|---|---|---|---|---|
| 1SM | 64x64x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 64x128x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 64x192x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 64x256x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x64x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x128x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x192x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 1SM | 128x256x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmSm100 |
| 2SM | 128x64x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 128x128x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 128x192x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 128x256x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x64x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x128x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x192x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |
| 2SM | 256x256x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmSm100 |

Table 9: Valid Tile Shapes for nv_float4_t x nv_float4_t (Row 1 of Table 3)

| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|---|---|---|---|---|---|---|
| 1SM | 128x128x256 | Y | N | N | N | KernelTmaWarpSpecialized1SmNvf4Sm100 |
| 1SM | 128x192x256 | Y | N | N | N | KernelTmaWarpSpecialized1SmNvf4Sm100 |
| 1SM | 128x256x256 | Y | N | N | N | KernelTmaWarpSpecialized1SmNvf4Sm100 |
| 2SM | 256x128x256 | Y | N | N | N | KernelTmaWarpSpecialized2SmNvf4Sm100 |
| 2SM | 256x192x256 | Y | N | N | N | KernelTmaWarpSpecialized2SmNvf4Sm100 |
| 2SM | 256x256x256 | Y | N | N | N | KernelTmaWarpSpecialized2SmNvf4Sm100 |

Table 10: Valid Tile Shapes and Dispatch Policies for mx_float4_t x mx_float4_t (Row 2 of Table 3)

| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|---|---|---|---|---|---|---|
| 1SM | 128x128x256 | Y | N | N | N | KernelTmaWarpSpecialized1SmMxf4Sm100 |
| 1SM | 128x192x256 | Y | N | N | N | KernelTmaWarpSpecialized1SmMxf4Sm100 |
| 1SM | 128x256x256 | Y | N | N | N | KernelTmaWarpSpecialized1SmMxf4Sm100 |
| 2SM | 256x128x256 | Y | N | N | N | KernelTmaWarpSpecialized2SmMxf4Sm100 |
| 2SM | 256x192x256 | Y | N | N | N | KernelTmaWarpSpecialized2SmMxf4Sm100 |
| 2SM | 256x256x256 | Y | N | N | N | KernelTmaWarpSpecialized2SmMxf4Sm100 |

Table 11: Valid Tile Shapes and Dispatch Policies for mx_float4_t x mx_float4_t (Row 3 of Table 3)

| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|---|---|---|---|---|---|---|
| 1SM | 128x128x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 |
| 1SM | 128x192x128 | Y | N | N | Y | KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 |
| 1SM | 128x256x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 |
| 2SM | 256x128x128 | Y | N | N | Y | KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 |
| 2SM | 256x192x128 | Y | N | N | Y | KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 |
| 2SM | 256x256x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 |

Table 12: Valid Tile Shapes and Dispatch Policies for {mx_float4_t, mx_float6_t, mx_float8_t} x {mx_float4_t, mx_float6_t} (Rows 4, 5, 7, 8, 10 of Table 3)

| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|---|---|---|---|---|---|---|
| 1SM | 128x128x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 |
| 1SM | 128x192x128 | Y | N | N | Y | KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 |
| 1SM | 128x256x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 |
| 2SM | 256x128x128 | Y | N | N | Y | KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 |
| 2SM | 256x192x128 | Y | N | N | Y | KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 |
| 2SM | 256x256x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 |

Table 13: Valid Tile Shapes and Dispatch Policies for {mx_float4_t, mx_float6_t, mx_float8_t} x mx_float8_t (Rows 6, 9, 11 of Table 3)

| 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy |
|---|---|---|---|---|---|---|
| 1SM | 128x128x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 |
| 1SM | 128x192x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 |
| 1SM | 128x256x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 |
| 2SM | 256x128x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 |
| 2SM | 256x192x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 |
| 2SM | 256x256x128 | Y | Y | Y | Y | KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 |

### Epilogue config supported

Table 14: Epilogue Dispatch Policy

| 1/2 SM | Epilogue Dispatch Policy |
|---|---|
| 1SM | cutlass::epilogue::TmaWarpSpecialized1Sm |
| 2SM | cutlass::epilogue::TmaWarpSpecialized2Sm |

Table 15: Epilogue PerSmTileShape_MNK

| 1/2 SM | MMA Tile Shape | PerSmTileShape_MNK |
|---|---|---|
| 1SM | 64x64xMMA_TileShape_K | 64x64xMMA_TileShape_K |
| 1SM | 64x128xMMA_TileShape_K | 64x128xMMA_TileShape_K |
| 1SM | 64x192xMMA_TileShape_K | 64x192xMMA_TileShape_K |
| 1SM | 64x256xMMA_TileShape_K | 64x256xMMA_TileShape_K |
| 1SM | 128x64xMMA_TileShape_K | 128x64xMMA_TileShape_K |
| 1SM | 128x128xMMA_TileShape_K | 128x128xMMA_TileShape_K |
| 1SM | 128x192xMMA_TileShape_K | 128x192xMMA_TileShape_K |
| 1SM | 128x256xMMA_TileShape_K | 128x256xMMA_TileShape_K |
| 2SM | 128x64xMMA_TileShape_K | 64x64xMMA_TileShape_K |
| 2SM | 128x128xMMA_TileShape_K | 64x128xMMA_TileShape_K |
| 2SM | 128x192xMMA_TileShape_K | 64x192xMMA_TileShape_K |
| 2SM | 128x256xMMA_TileShape_K | 64x256xMMA_TileShape_K |
| 2SM | 256x64xMMA_TileShape_K | 128x64xMMA_TileShape_K |
| 2SM | 256x128xMMA_TileShape_K | 128x128xMMA_TileShape_K |
| 2SM | 256x192xMMA_TileShape_K | 128x192xMMA_TileShape_K |
| 2SM | 256x256xMMA_TileShape_K | 128x256xMMA_TileShape_K |

MMA_TileShape_K is generally 4 * MMA-Instruction-K; it depends on the configuration chosen in the MMA tile shapes supported section. For example, the f8f6f4/mxf8f6f4 tile shapes in Tables 5-8 and 11-13 use MMA_TileShape_K = 128, while the mxf4/mxf4nvf4 tile shapes in Tables 9-10 use 256.

## Auto Kernel Dispatch Policies

In addition to the direct dispatch policies listed above, the user can also use auto policies for both non-block scaled and block scaled narrow-precision GEMMs.

CUTLASS will do its best to find the most efficient kernel for the given parameters; however, the preferred method for building these kernels is to use the direct kernel dispatch policies shown in the tables above.

- cutlass::gemm::collective::KernelScheduleAuto: For a given MMA tile size, data type, and layout combination, automatically choose the instruction kind (mxf8f6f4, mxf4, mxf4nvf4) and the 1 or 2 SM tcgen05.mma variant.
- KernelTmaWarpSpecialized1SmBlockScaledSm100: Use the 1 SM tcgen05.mma instruction and choose the instruction kind (mxf8f6f4, mxf4, mxf4nvf4) automatically.
- KernelTmaWarpSpecialized2SmBlockScaledSm100: Use the 2 SM tcgen05.mma instruction and choose the instruction kind (mxf8f6f4, mxf4, mxf4nvf4) automatically.

Similarly for epilogues, we can use cutlass::epilogue::collective::EpilogueScheduleAuto.
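As a sketch, using the auto policies only changes the schedule types passed to the builders in the example below; everything else stays the same.

```c++
// Let CUTLASS pick the tcgen05.mma kind and the 1 or 2 SM variant for the mainloop
using KernelMainloopPolicy = cutlass::gemm::collective::KernelScheduleAuto;

// Let CUTLASS pick a matching epilogue schedule
using EpilogueSchedulePolicy = cutlass::epilogue::collective::EpilogueScheduleAuto;
```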

## Building a Block Scaled Kernel

For non-block-scaled dense GEMMs, refer to the quick start page. An example dense GEMM can be found here:

  1. Blackwell FP16 GEMM example.

Narrow precision and block scaled narrow precision kernels can be built using the CUTLASS 3.x collective builder interface (as described in the CUTLASS 3.0 GEMM API). However, to obtain a functionally correct and performant kernel, special attention needs to be given to the A and B matrix layouts, alignment requirements, and dispatch policies listed above.

Several examples of block scaled kernels can be found in the examples/72_blackwell_narrow_precision_gemm directory:

  1. NVF4 Gemm with block scaling
  2. NVF4 Gemm with block scaling and NVF4 output matrix
  3. Mixed precision Nvf4 x Mxf8 GEMM with block scaling

The collective builder interface expects the same arguments as any other CUTLASS 3.x kernel, as described here, with a small difference in the collective MMA builder interface. As in all Blackwell kernels, the TileShape_MNK argument expects the MmaTileShape_MNK, which is the tile shape needed by the 1 or 2 SM tcgen05.mma instruction.

Let's consider building a block scaled GEMM where the A matrix is of type mx_float4_t and column-major (N), and the B matrix is of type mx_float4_t and row-major (T). We first need to describe the A and B tensors, and find the instruction that can support the selected A and B type and layout pair. Then, we will choose the performance parameters.

The skeleton C++ code is shown below:

```c++
  ///////////////////////////////////////////////////////////
  //                Mainloop Builder Setup
  ///////////////////////////////////////////////////////////
  
  ///////////////////////////////////////////
  // 1. Describe A and B tensors
  ///////////////////////////////////////////
  using ElementA       = // TBD
  constexpr int AlignA = // TBD
  using GmemLayoutA    = // TBD
  using ElementB       = // TBD
  constexpr int AlignB = // TBD
  using GmemLayoutB    = // TBD

  // Mma's accumulator type
  using ElementAccumulator = float;           // Always float for block scaled tcgen05.mma instructions

  //////////////////////////////////////////
  // 2. Choose Performance Parameters
  //////////////////////////////////////////

  // Tile and cluster shapes
  // Collective MMA takes tile shape of the MMA operation as input
  using KernelMainloopPolicy     = // TBD
  using MmaTileShape_MNK         = // TBD
  using ClusterShape_MNK         = // TBD

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,      // Arch and Tensorop spec
      ElementA, GmemLayoutA, AlignA,                                        // A tensor elem type, layout and alignment requirement
      ElementB, GmemLayoutB, AlignB,                                        // B tensor elem type, layout and alignment requirement
      ElementAccumulator,                                                   // Mma instruction accumulator type
      MmaTileShape_MNK, ClusterShape_MNK,                                   // Mma instruction tile shape, cluster shape
      // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity 
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelMainloopPolicy                                                  // Kernel schedule policy.
                                                                            // Auto or using targeted scheduling policy
    >::CollectiveOp;
```

From the valid type and layout combinations in Table 3, we see that only row 3 supports the mx_float4_t x mx_float4_t combination with an NT layout. As a result, we need to use the tcgen05.mma.kind::mxf8f6f4 instruction. Additionally, in order to use tcgen05.mma.kind::mxf8f6f4, both the A and B tensors should be 128-element aligned. Thus, we can describe the A and B tensors as follows:

```c++
  ///////////////////////////////////////////////////////////
  //                Mainloop Builder Setup
  ///////////////////////////////////////////////////////////
  
  ///////////////////////////////////////////
  // 1. Describe A and B tensors
  ///////////////////////////////////////////
  using ElementA       = mx_float4_t;
  constexpr int AlignA = 128;
  using GmemLayoutA    = cutlass::layout::ColumnMajor;
  using ElementB       = mx_float4_t;
  constexpr int AlignB = 128;
  using GmemLayoutB    = cutlass::layout::RowMajor;
```

Next, we need to choose the performance parameters such as MmaTileShape_MNK, KernelMainloopPolicy, and ClusterShape_MNK.

The MmaTileShape_MNK options supported for mx_float4_t x mx_float4_t with mxf8f6f4 are listed in Table 11. For the NT layout, we see that three MmaTileShape_MNK options are supported: 128x128x128 and 128x256x128 with the 1SM instruction, and 256x256x128 with the 2SM instruction. Let's say we expect to get the best performance with the 256x256x128 MMA tile shape for our GEMM problem. Then, we need to set the KernelMainloopPolicy to KernelTmaWarpSpecialized2SmMxf8f6f4Sm100. Next, we need to choose the ClusterShape_MNK. Since we have selected a 2SM MMA instruction, the ClusterShape_MNK should be compatible with it: its first mode should be a multiple of 2. ClusterShape_MNK = cute::Shape<_2, [_1|_2|_4], _1> or ClusterShape_MNK = cute::Shape<_4, [_1|_2|_4], _1> would be valid options. Let's choose cute::Shape<_4,_4,_1>. Our performance parameters look like this:

```c++
  //////////////////////////////////////////
  // 2. Choose Performance Parameters
  //////////////////////////////////////////

  // Tile and cluster shapes
  // Collective MMA takes tile shape of the MMA operation as input
  using KernelMainloopPolicy     = cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100;
  using MmaTileShape_MNK         = cute::Shape<_256,_256,_128>;
  using ClusterShape_MNK         = cute::Shape<_4,_4,_1>;
```

After configuring the mainloop, let's set up the epilogue. A typical epilogue is shown below: we need to specify the output layout, data type, alignment, and PerSmTileShape_MNK, and let the rest default to auto.

PerSmTileShape_MNK should be deduced from the mainloop setup. For example, in the mainloop setup above, the MmaTileShape_MNK is 256x256x128 and the KernelMainloopPolicy is a 2SM policy. This means each CTA is computing a (256 / 2SM) x 256 x 128 output, so the PerSmTileShape_MNK is 128x256x128. The possible PerSmTileShape_MNK values are listed in Table 15.

The epilogue scheduling policy is configurable, and it is common to set cutlass::epilogue::collective::EpilogueScheduleAuto to allow the epilogue builder to automatically select the appropriate policy. However, it can also be set explicitly to a policy matching the 1SM or 2SM MMA instruction. The available explicit policies are listed in Table 14.

```c++
  // Describe C and D tensors
  using ElementC = cutlass::half_t;
  constexpr int AlignC = 8;
  using GmemLayoutC = cutlass::layout::RowMajor;
  using ElementD = cutlass::float_e2m1_t;
  constexpr int AlignD = 32;
  using GmemLayoutD = cutlass::layout::RowMajor;
  // Mma's accumulator type
  using ElementAccumulator = float;
  // Epilogue computation's precision type
  using ElementCompute = float;
  // Cluster size for multicast
  using ClusterShape_MNK = Shape<_4,_4,_1>;
  // Collective Epilogue takes the output tile shape for 1 CTA
  using PerSmTileShape_MNK = Shape<_128,_256,_128>;
  
  //
  // Construct CollectiveEpilogue
  //

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,      // Arch and Tensorop spec
      PerSmTileShape_MNK, ClusterShape_MNK,                                 // Epilogue tile shape, and cluster shape
      cutlass::epilogue::collective::EpilogueTileAuto,                      // Epilogue subtile shape. Auto will find a suitable tile shape
      ElementAccumulator, ElementCompute,                                   // Mma instr's accumulator type and compute precision for epilogue
      ElementC, GmemLayoutC, AlignC,                                        // C tensor description
      ElementD, GmemLayoutD, AlignD,                                        // D tensor description
      cutlass::epilogue::TmaWarpSpecialized2Sm                              // Epilogue schedule policy
    >::CollectiveOp;
```

If we want the epilogue to generate mxf4/nvf4/mxf6/mxf8 output (i.e. elements plus block scale factors), we need to set up the epilogue fusion in the builder. First, we choose an SFDVectorSize, which indicates how many elements share the same block scale factor. Then, we choose ElementSFD and GmemLayoutSFD, which indicate the scale factor output data type and which output dimension is used to generate the block scale factors. Typically, GmemLayoutSFD is the same as GmemLayoutD.

```c++
  //
  // Construct FusionOperation
  //
  constexpr int SFDVectorSize = 16;
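  // ElementSFD and GmemLayoutSFD are not defined elsewhere in this snippet; the choices
  // below are illustrative assumptions (ue8m0 scale factors laid out like the D tensor).
  using ElementSFD    = cutlass::float_ue8m0_t;
  using GmemLayoutSFD = GmemLayoutD;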
  // Define the fusion operation applied during epilogue
  using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
      SFDVectorSize,
      ElementD, ElementCompute, 
      ElementSFD, GmemLayoutSFD,
      ElementC
    >;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,      // Arch and Tensorop spec
      PerSmTileShape_MNK, ClusterShape_MNK,                                 // Epilogue tile shape, and cluster shape
      cutlass::epilogue::collective::EpilogueTileAuto,                      // Epilogue subtile shape. Auto will find a suitable tile shape
      ElementAccumulator, ElementCompute,                                   // Mma instr's accumulator type and compute precision for epilogue
      ElementC, GmemLayoutC, AlignC,                                        // C tensor description
      ElementD, GmemLayoutD, AlignD,                                        // D tensor description
      cutlass::epilogue::collective::EpilogueScheduleAuto,                  // Epilogue schedule policy
      FusionOperation                                                       // <================================== Pass the fusion config into epilogue builder.
    >::CollectiveOp;
```

The example above is a gentle introduction to using fusion operations in the epilogue. For a more detailed example, see Blackwell GEMM with collective builder.

Note that, for clarity, we discussed the CollectiveMainloop first and then the CollectiveEpilogue. However, the CollectiveMainloop needs to know the SMEM utilization of the epilogue, so the CollectiveEpilogue must be set up before the CollectiveMainloop. See the examples/72_blackwell_narrow_precision_gemm directory for the full kernel and run setup.
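Once both collectives are defined, assembling them into a runnable GEMM follows the usual CUTLASS 3.x pattern, sketched below; treat this as a hedged outline and consult the examples/72_blackwell_narrow_precision_gemm sources for the authoritative setup.

```c++
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"

// Compose the mainloop and epilogue into a kernel, then wrap it in a device-side adapter.
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    cute::Shape<int, int, int, int>,   // runtime problem shape (M, N, K, L)
    CollectiveMainloop,
    CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

// Typical usage (arguments and workspace handling elided):
//   Gemm gemm;
//   gemm.initialize(arguments, workspace_ptr);
//   gemm.run(stream);
```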

## Scale Factor Layouts

The scale factor layout consists of 512B basic-block structures, as illustrated in the diagram below. Each block covers 128 rows along the M/N dimension and 4 scale factors (SF) along the K dimension. The byte order of the basic storage chunk is row-major, meaning that M0SF0 to M0SF3, M32SF0 to M32SF3, M64SF0 to M64SF3, and M96SF0 to M96SF3 are stored consecutively in GMEM.

(Figure: M128xK4_scalefactor_gmem.png — scale factor basic block layout in GMEM)

If the scale factor tensor exceeds M128xSF4, it indicates that there are multiple basic blocks along both the M and SFK dimensions. The arrangement of these basic blocks follows a K-major order. Here is a diagram illustrating the scenario where M equals 512 and the SFK is 16.

(Figure: narrow_precison_multiple_block_sf_layout.png — basic block arrangement for M = 512, SFK = 16)

Creating the scale factor tensors' layouts by hand is tedious. CUTLASS provides Sm100BlockScaledConfig to create these layouts easily (see sm100_blockscaled_layout.hpp). The interface to create the SFA and SFB tensor layouts is as follows:

```c++
auto problem_shape = make_shape(M, N, K, L);
using SfConfig = Sm100BlockScaledConfig<SFVecSize>;

// SFA shape: ((32,4), ceil(M/128)), ((SFVecSize,4), ceil(K/4)), L
auto layout_sfa = SfConfig::tile_atom_to_shape_SFA(problem_shape);
// SFB shape: ((32,4), ceil(N/128)), ((SFVecSize,4), ceil(K/4)), L
auto layout_sfb = SfConfig::tile_atom_to_shape_SFB(problem_shape);

auto tensor_sfa = make_tensor(aptr, layout_sfa);
auto tensor_sfb = make_tensor(bptr, layout_sfb);
// Access the SF for element (m,k) of the A tensor
auto val_a_mk = tensor_sfa(make_coord(m,k,0));
```
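One sizing note, as a hedged sketch: because these layouts pad M/N up to multiples of 128 rows and K up to multiples of 4 scale factors, the backing buffers should be sized from the layouts themselves rather than from M * ceil(K/SFVecSize).

```c++
// Allocate scale factor buffers based on the padded layouts created above
auto sfa_count = cute::cosize(layout_sfa);  // number of SFA elements to allocate
auto sfb_count = cute::cosize(layout_sfb);  // number of SFB elements to allocate
```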

## Copyright

Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. SPDX-License-Identifier: BSD-3-Clause

  Redistribution and use in source and binary forms, with or without
  modification, are permitted provided that the following conditions are met:

  1. Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

  2. Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

  3. Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

[^1]: Only valid as scale factor data types.