// matmul.cu
#include "mma.cuh"
#include <cmath>
#include <stdio.h>
#include <assert.h>
#include <cstdint>
#include <cuda_bf16.h>
#include <algorithm> // std::min (used in swizzle() below)
#define PRINT_IF(cond, ...) if (cond) printf(__VA_ARGS__);
__host__ __device__ constexpr int cdiv(int a, int b) { return (a + b - 1) / b; }
constexpr bool is_power_of_two(int x) { return x > 0 && (x & (x - 1)) == 0; } // https://stackoverflow.com/a/1804686
constexpr int WARP_SIZE = 32;
template <int BLOCK_SIZE, int HEIGHT, int WIDTH, typename T>
__device__ void load_b128(const T *in, int in_row_stride, T *out, int out_row_stride, int tid) {
// number of elements per 128-bit (16-byte) load
// e.g. FP32 -> 4 elements, BF16 -> 8 elements.
using load_type = uint4;
constexpr int num_elems = sizeof(load_type) / sizeof(T);
for (int idx = tid * num_elems; idx < HEIGHT * WIDTH; idx += BLOCK_SIZE * num_elems) {
const int row = idx / WIDTH;
const int col = idx % WIDTH;
load_type tmp = reinterpret_cast<const load_type *>(&in[row * in_row_stride + col])[0];
reinterpret_cast<load_type *>(&out[row * out_row_stride + col])[0] = tmp;
}
}
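// Worked example (a sketch for the bf16 instantiation used below): sizeof(nv_bfloat16) == 2,
// so num_elems = 16 / 2 = 8 and each iteration copies one uint4 = 8 bf16 values.
// With BLOCK_SIZE = 128, HEIGHT = 128, WIDTH = 32 (the A tile in matmul_v1a),
// that is 128*32 / 8 = 512 vectorized loads, i.e. 4 iterations per thread.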
template <typename T> __device__ ushort f32_to_b16(float x);
template <> __device__ ushort f32_to_b16<half>(float x) { return __half_as_ushort(__float2half(x)); }
template <> __device__ ushort f32_to_b16<nv_bfloat16>(float x) { return __bfloat16_as_ushort(__float2bfloat16(x)); }
template <
int BLOCK_M, int BLOCK_N, int BLOCK_K,
int WARP_M, int WARP_N, int WARP_K,
int MMA_M, int MMA_N, int MMA_K,
bool PAD_SHMEM_A, bool PAD_SHMEM_B,
typename T>
__global__ void matmul_v1_kernel(const T *A, const T *B, T *C, int M, int N, int K) {
static_assert(BLOCK_M % WARP_M == 0);
static_assert(BLOCK_N % WARP_N == 0);
static_assert(BLOCK_K % WARP_K == 0);
static_assert(WARP_M % MMA_M == 0);
static_assert(WARP_N % MMA_N == 0);
static_assert(WARP_K % MMA_K == 0);
constexpr int BLOCK_SIZE = (BLOCK_M * BLOCK_N) / (WARP_M * WARP_N) * WARP_SIZE;
constexpr int NUM_MMA_M = WARP_M / MMA_M;
constexpr int NUM_MMA_N = WARP_N / MMA_N;
constexpr int NUM_MMA_K = WARP_K / MMA_K;
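// Worked numbers for the configuration used by matmul_v1a/v1b below
// (BLOCK 128x128x32, WARP 64x64x16, MMA m16n8k8):
//   BLOCK_SIZE = (128*128)/(64*64) * 32 = 4 warps = 128 threads
//   NUM_MMA_M = 64/16 = 4, NUM_MMA_N = 64/8 = 8, NUM_MMA_K = 16/8 = 2
// i.e. each warp computes a 64x64 output tile as a 4x8 grid of 16x8 MMA tiles.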
const int tid = threadIdx.x;
const int block_id = blockIdx.x;
const int warp_id = tid / WARP_SIZE;
const int lane_id = tid % WARP_SIZE;
const int num_blocks_per_row = cdiv(N, BLOCK_N);
const int block_id_m = block_id / num_blocks_per_row;
const int block_id_n = block_id % num_blocks_per_row;
const int offset_m = block_id_m * BLOCK_M;
const int offset_n = block_id_n * BLOCK_N;
constexpr int num_warps_per_row = BLOCK_N / WARP_N;
const int warp_id_m = warp_id / num_warps_per_row;
const int warp_id_n = warp_id % num_warps_per_row;
const int warp_tile_offset_m = warp_id_m * WARP_M;
const int warp_tile_offset_n = warp_id_n * WARP_N;
// A is row-major, B is column-major
A += offset_m * K;
B += offset_n * K;
// padding must be a multiple of 8 elements (= 16 bytes for bf16) so every row stays 16-byte aligned, as required by ldmatrix
constexpr int A_shared_width = BLOCK_K + (PAD_SHMEM_A ? 8 : 0);
constexpr int B_shared_width = BLOCK_K + (PAD_SHMEM_B ? 8 : 0);
__shared__ T A_shared[BLOCK_M * A_shared_width];
__shared__ T B_shared[BLOCK_N * B_shared_width];
// 32-bit (4-byte) registers
constexpr int num_acc_regs = MMA_M * MMA_N / WARP_SIZE;
constexpr int num_A_regs = MMA_M * MMA_K * sizeof(T) / 4 / WARP_SIZE;
constexpr int num_B_regs = MMA_N * MMA_K * sizeof(T) / 4 / WARP_SIZE;
float acc[NUM_MMA_M][NUM_MMA_N][num_acc_regs] = {0.0f}; // for m16n8k8, each thread holds 4 output floats
uint32_t A_reg[NUM_MMA_M][NUM_MMA_K][num_A_regs]; // each thread holds 2 registers, each packing 2x 16-bit inputs
uint32_t B_reg[NUM_MMA_N][NUM_MMA_K][num_B_regs]; // each thread holds 1 register packing 2x 16-bit inputs
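// Register-count arithmetic for the bf16 m16n8k8 case (sizeof(T) = 2):
//   num_acc_regs = 16*8/32     = 4 floats per thread
//   num_A_regs   = 16*8*2/4/32 = 2 uint32 per thread (each packs 2 bf16)
//   num_B_regs   = 8*8*2/4/32  = 1 uint32 per thread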
// first A and B warp-tile along BLOCK_K dim (we will iterate along BLOCK_K with step_size=WARP_K)
const T *A_warp_tile = reinterpret_cast<const T *>(A_shared) + warp_tile_offset_m * A_shared_width;
const T *B_warp_tile = reinterpret_cast<const T *>(B_shared) + warp_tile_offset_n * B_shared_width;
for (int block_k = 0; block_k < K; block_k += BLOCK_K) {
load_b128<BLOCK_SIZE, BLOCK_M, BLOCK_K>(A, K, A_shared, A_shared_width, tid);
load_b128<BLOCK_SIZE, BLOCK_N, BLOCK_K>(B, K, B_shared, B_shared_width, tid);
__syncthreads();
for (int warp_k = 0; warp_k < BLOCK_K; warp_k += WARP_K) {
// load data from shared memory to registers using ldmatrix
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
// convert generic address to .shared state space address expected by inline PTX
// thread 0 holds address of row 0
// thread 1 holds address of row 1, and so on
uint32_t A_tile_addr = cvta_shared(A_warp_tile + lane_id * A_shared_width + warp_k);
uint32_t B_tile_addr = cvta_shared(B_warp_tile + lane_id * B_shared_width + warp_k);
// load A to registers
// ldmatrix loads 8x8 tiles of 16-bit elements; for the 16x8 A tile we use the .x2 variant
// works for both m16n8k8 and m16n8k16
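// note (per the PTX ldmatrix spec): for the 16x8 A tile (.x2), lanes 0-15 supply the 16 row
// addresses; for the 8x8 B tile (.x1), only the addresses from lanes 0-7 are consumed and the
// remaining lanes' addresses are ignored.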
for (int mma_tile_id_m = 0; mma_tile_id_m < NUM_MMA_M; mma_tile_id_m++)
for (int mma_tile_id_k = 0; mma_tile_id_k < NUM_MMA_K; mma_tile_id_k++) {
uint32_t A_local = A_tile_addr + (mma_tile_id_m * MMA_M * A_shared_width + mma_tile_id_k * MMA_K) * sizeof(T);
ldmatrix<num_A_regs>(A_reg[mma_tile_id_m][mma_tile_id_k], A_local);
}
// load B to registers
for (int mma_tile_id_n = 0; mma_tile_id_n < NUM_MMA_N; mma_tile_id_n++)
for (int mma_tile_id_k = 0; mma_tile_id_k < NUM_MMA_K; mma_tile_id_k++) {
uint32_t B_local = B_tile_addr + (mma_tile_id_n * MMA_N * B_shared_width + mma_tile_id_k * MMA_K) * sizeof(T);
ldmatrix<num_B_regs>(B_reg[mma_tile_id_n][mma_tile_id_k], B_local);
}
// call mma
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-1688
for (int mma_tile_id_m = 0; mma_tile_id_m < NUM_MMA_M; mma_tile_id_m++)
for (int mma_tile_id_n = 0; mma_tile_id_n < NUM_MMA_N; mma_tile_id_n++)
for (int mma_tile_id_k = 0; mma_tile_id_k < NUM_MMA_K; mma_tile_id_k++)
mma<MMA_M, MMA_N, MMA_K, T>(A_reg[mma_tile_id_m][mma_tile_id_k],
B_reg[mma_tile_id_n][mma_tile_id_k],
acc[mma_tile_id_m][mma_tile_id_n]);
}
__syncthreads();
A += BLOCK_K;
B += BLOCK_K;
}
const int C_offset_m = offset_m + warp_tile_offset_m;
const int C_offset_n = offset_n + warp_tile_offset_n;
C += C_offset_m * N + C_offset_n;
// check output layout here
// https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-1688-c-f16-f32
// m16n8k16 has the same layout
const int a0_row = lane_id >> 2;
const int a0_col = (lane_id % 4) * 2;
C += a0_row * N + a0_col;
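// Accumulator-to-output mapping for one m16n8 tile (see the PTX link above):
// lane L owns elements (row, col) = (L/4,     (L%4)*2 + {0,1}) for acc[0], acc[1]
//                          and      (L/4 + 8, (L%4)*2 + {0,1}) for acc[2], acc[3],
// which is why each thread stores two ushort2 values 8 rows apart.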
for (int mma_tile_id_m = 0; mma_tile_id_m < NUM_MMA_M; mma_tile_id_m++)
for (int mma_tile_id_n = 0; mma_tile_id_n < NUM_MMA_N; mma_tile_id_n++) {
T *C_local = C + mma_tile_id_m * MMA_M * N + mma_tile_id_n * MMA_N;
float *acc_frag = acc[mma_tile_id_m][mma_tile_id_n];
ushort2 tmp;
// write a0 and a1
tmp.x = f32_to_b16<T>(acc_frag[0]);
tmp.y = f32_to_b16<T>(acc_frag[1]);
reinterpret_cast<ushort2 *>(C_local)[0] = tmp;
// write a2 and a3
tmp.x = f32_to_b16<T>(acc_frag[2]);
tmp.y = f32_to_b16<T>(acc_frag[3]);
reinterpret_cast<ushort2 *>(C_local + 8 * N)[0] = tmp;
}
}
void matmul_v1a(const nv_bfloat16 *A, const nv_bfloat16 *B, nv_bfloat16 *C, int M, int N, int K) {
assert(is_power_of_two(M) && "M must be a power of 2");
assert(is_power_of_two(N) && "N must be a power of 2");
assert(is_power_of_two(K) && "K must be a power of 2");
const int BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 32;
const int WARP_M = 64, WARP_N = 64, WARP_K = 16;
const int MMA_M = 16, MMA_N = 8, MMA_K = 8;
const int BLOCK_SIZE = (BLOCK_M * BLOCK_N) / (WARP_M * WARP_N) * WARP_SIZE;
const int grid_size = cdiv(M * N, BLOCK_M * BLOCK_N);
matmul_v1_kernel<
BLOCK_M, BLOCK_N, BLOCK_K,
WARP_M, WARP_N, WARP_K,
MMA_M, MMA_N, MMA_K,
false, false><<<grid_size, BLOCK_SIZE>>>(A, B, C, M, N, K);
}
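// Example usage (a minimal host-side sketch, not part of the original file; the sizes and
// initialization step are illustrative assumptions):
//
//   int M = 4096, N = 4096, K = 4096;             // must be powers of two
//   nv_bfloat16 *A, *B, *C;
//   cudaMalloc(&A, M * K * sizeof(nv_bfloat16));  // A: row-major,    M x K
//   cudaMalloc(&B, N * K * sizeof(nv_bfloat16));  // B: column-major, K x N (stored as N rows of K)
//   cudaMalloc(&C, M * N * sizeof(nv_bfloat16));  // C: row-major,    M x N
//   /* ...fill A and B, e.g. with an init kernel or cudaMemcpy... */
//   matmul_v1a(A, B, C, M, N, K);
//   cudaDeviceSynchronize();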
void matmul_v1b(const nv_bfloat16 *A, const nv_bfloat16 *B, nv_bfloat16 *C, int M, int N, int K) {
assert(is_power_of_two(M) && "M must be a power of 2");
assert(is_power_of_two(N) && "N must be a power of 2");
assert(is_power_of_two(K) && "K must be a power of 2");
const int BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 32;
const int WARP_M = 64, WARP_N = 64, WARP_K = 16;
const int MMA_M = 16, MMA_N = 8, MMA_K = 8;
const int BLOCK_SIZE = (BLOCK_M * BLOCK_N) / (WARP_M * WARP_N) * WARP_SIZE;
const int grid_size = cdiv(M * N, BLOCK_M * BLOCK_N);
matmul_v1_kernel<
BLOCK_M, BLOCK_N, BLOCK_K,
WARP_M, WARP_N, WARP_K,
MMA_M, MMA_N, MMA_K,
true, false><<<grid_size, BLOCK_SIZE>>>(A, B, C, M, N, K);
}
constexpr __device__ int log2_int(int x) { return x == 1 ? 0 : 1 + log2_int(x >> 1); }
// https://github.com/NVIDIA/cutlass/blob/main/include/cute/swizzle.hpp
template <int WIDTH, typename T>
__device__ int swizzle(int x) {
constexpr int num_elems = 16 / sizeof(T);
constexpr int stride = WIDTH / num_elems; // row stride measured in 16-byte words
constexpr int MBase = log2_int(num_elems); // don't touch the lowest MBase bits: they index within a single 16-byte vector (8x 16-bit elements)
// TODO: seems like we have to add 1 to BBits? bug in logic?
constexpr int BBits = std::min(log2_int(stride), 3); // we permute BBits, which is the no. of non-overlapping bits between row index and 4-bank-group index.
constexpr int SShift = log2_int(stride); // relative difference from 4-bank-group index to row index.
constexpr int mask = ((1 << BBits) - 1) << MBase; // BBits ones followed by MBase zeros
if constexpr (BBits == 0) return x;
else return x ^ ((x >> SShift) & mask);
}
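// Worked example for the bf16 tiles used by matmul_v2 (WIDTH = BLOCK_K = 32, sizeof(T) = 2):
// num_elems = 8, stride = 4, MBase = 3, BBits = 2, SShift = 2, mask = 0b11000,
// so swizzle() XORs the low 2 bits of the row index into bits 3-4 (the 16-byte chunk index
// within the row). E.g. element (row 1, col 0) -> x = 32,
// swizzle(32) = 32 ^ ((32 >> 2) & 24) = 32 ^ 8 = 40, i.e. (row 1, col 8),
// so consecutive rows start in different shared-memory bank groups.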
template <int BLOCK_SIZE, int HEIGHT, int WIDTH, typename T>
__device__ void load_shared_swizzle(const T *in, int in_row_stride, T *out, int tid) {
constexpr int num_elems = 16 / sizeof(T);
for (int idx = tid * num_elems; idx < HEIGHT * WIDTH; idx += BLOCK_SIZE * num_elems) {
const int row = idx / WIDTH;
const int col = idx % WIDTH;
uint4 tmp = reinterpret_cast<const uint4 *>(&in[row * in_row_stride + col])[0];
int swizzled_idx = swizzle<WIDTH, T>(row * WIDTH + col);
reinterpret_cast<uint4 *>(&out[swizzled_idx])[0] = tmp;
}
}
template <
int BLOCK_M, int BLOCK_N, int BLOCK_K,
int WARP_M, int WARP_N, int WARP_K,
int MMA_M, int MMA_N, int MMA_K,
typename T>
__global__ void matmul_v2_kernel(const T *A, const T *B, T *C, int M, int N, int K) {
static_assert(BLOCK_M % WARP_M == 0);
static_assert(BLOCK_N % WARP_N == 0);
static_assert(BLOCK_K % WARP_K == 0);
static_assert(WARP_M % MMA_M == 0);
static_assert(WARP_N % MMA_N == 0);
static_assert(WARP_K % MMA_K == 0);
constexpr int BLOCK_SIZE = (BLOCK_M * BLOCK_N) / (WARP_M * WARP_N) * WARP_SIZE;
constexpr int NUM_MMA_M = WARP_M / MMA_M;
constexpr int NUM_MMA_N = WARP_N / MMA_N;
constexpr int NUM_MMA_K = WARP_K / MMA_K;
const int tid = threadIdx.x;
const int block_id = blockIdx.x;
const int warp_id = tid / WARP_SIZE;
const int lane_id = tid % WARP_SIZE;
const int num_blocks_per_row = cdiv(N, BLOCK_N);
const int block_id_m = block_id / num_blocks_per_row;
const int block_id_n = block_id % num_blocks_per_row;
const int offset_m = block_id_m * BLOCK_M;
const int offset_n = block_id_n * BLOCK_N;
constexpr int num_warps_per_row = BLOCK_N / WARP_N;
const int warp_id_m = warp_id / num_warps_per_row;
const int warp_id_n = warp_id % num_warps_per_row;
const int warp_tile_offset_m = warp_id_m * WARP_M;
const int warp_tile_offset_n = warp_id_n * WARP_N;
// A is row-major, B is column-major
A += offset_m * K;
B += offset_n * K;
__shared__ T A_shared[BLOCK_M * BLOCK_K];
__shared__ T B_shared[BLOCK_N * BLOCK_K];
// 32-bit (4-byte) registers
constexpr int num_acc_regs = MMA_M * MMA_N / WARP_SIZE;
constexpr int num_A_regs = MMA_M * MMA_K * sizeof(T) / 4 / WARP_SIZE;
constexpr int num_B_regs = MMA_N * MMA_K * sizeof(T) / 4 / WARP_SIZE;
float acc[NUM_MMA_M][NUM_MMA_N][num_acc_regs] = {0.0f};
uint32_t A_reg[NUM_MMA_M][NUM_MMA_K][num_A_regs];
uint32_t B_reg[NUM_MMA_N][NUM_MMA_K][num_B_regs];
for (int block_k = 0; block_k < K; block_k += BLOCK_K)
{
load_shared_swizzle<BLOCK_SIZE, BLOCK_M, BLOCK_K>(A, K, A_shared, tid);
load_shared_swizzle<BLOCK_SIZE, BLOCK_N, BLOCK_K>(B, K, B_shared, tid);
__syncthreads();
for (int warp_k = 0; warp_k < BLOCK_K; warp_k += WARP_K)
{
// load A to registers
for (int mma_tile_id_m = 0; mma_tile_id_m < NUM_MMA_M; mma_tile_id_m++)
for (int mma_tile_id_k = 0; mma_tile_id_k < NUM_MMA_K; mma_tile_id_k++)
{
const int A_offset = (warp_tile_offset_m + lane_id + mma_tile_id_m * MMA_M) * BLOCK_K + (warp_k + mma_tile_id_k * MMA_K);
const T *A_local = reinterpret_cast<const T *>(A_shared) + swizzle<BLOCK_K, T>(A_offset);
ldmatrix<num_A_regs>(A_reg[mma_tile_id_m][mma_tile_id_k], cvta_shared(A_local));
}
// load B to registers
for (int mma_tile_id_n = 0; mma_tile_id_n < NUM_MMA_N; mma_tile_id_n++)
for (int mma_tile_id_k = 0; mma_tile_id_k < NUM_MMA_K; mma_tile_id_k++)
{
const int B_offset = (warp_tile_offset_n + lane_id + mma_tile_id_n * MMA_N) * BLOCK_K + (warp_k + mma_tile_id_k * MMA_K);
const T *B_local = reinterpret_cast<const T *>(B_shared) + swizzle<BLOCK_K, T>(B_offset);
ldmatrix<num_B_regs>(B_reg[mma_tile_id_n][mma_tile_id_k], cvta_shared(B_local));
}
// call mma
for (int mma_tile_id_m = 0; mma_tile_id_m < NUM_MMA_M; mma_tile_id_m++)
for (int mma_tile_id_n = 0; mma_tile_id_n < NUM_MMA_N; mma_tile_id_n++)
for (int mma_tile_id_k = 0; mma_tile_id_k < NUM_MMA_K; mma_tile_id_k++)
mma<MMA_M, MMA_N, MMA_K, T>(A_reg[mma_tile_id_m][mma_tile_id_k],
B_reg[mma_tile_id_n][mma_tile_id_k],
acc[mma_tile_id_m][mma_tile_id_n]);
}
__syncthreads();
A += BLOCK_K;
B += BLOCK_K;
}
const int C_offset_m = offset_m + warp_tile_offset_m;
const int C_offset_n = offset_n + warp_tile_offset_n;
C += C_offset_m * N + C_offset_n;
// check output layout here
// https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-1688-c-f16-f32
const int a0_row = lane_id >> 2;
const int a0_col = (lane_id % 4) * 2;
C += a0_row * N + a0_col;
for (int mma_tile_id_m = 0; mma_tile_id_m < NUM_MMA_M; mma_tile_id_m++)
for (int mma_tile_id_n = 0; mma_tile_id_n < NUM_MMA_N; mma_tile_id_n++)
{
T *C_local = C + mma_tile_id_m * MMA_M * N + mma_tile_id_n * MMA_N;
float *acc_frag = acc[mma_tile_id_m][mma_tile_id_n];
ushort2 tmp;
// write a0 and a1
tmp.x = f32_to_b16<T>(acc_frag[0]);
tmp.y = f32_to_b16<T>(acc_frag[1]);
reinterpret_cast<ushort2 *>(C_local)[0] = tmp;
// write a2 and a3
tmp.x = f32_to_b16<T>(acc_frag[2]);
tmp.y = f32_to_b16<T>(acc_frag[3]);
reinterpret_cast<ushort2 *>(C_local + 8 * N)[0] = tmp;
}
}
void matmul_v2(const nv_bfloat16 *A, const nv_bfloat16 *B, nv_bfloat16 *C, int M, int N, int K) {
assert(is_power_of_two(M) && "M must be a power of 2");
assert(is_power_of_two(N) && "N must be a power of 2");
assert(is_power_of_two(K) && "K must be a power of 2");
const int BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 32;
const int WARP_M = 64, WARP_N = 64, WARP_K = 32;
const int MMA_M = 16, MMA_N = 8, MMA_K = 8;
const int BLOCK_SIZE = (BLOCK_M * BLOCK_N) / (WARP_M * WARP_N) * WARP_SIZE;
const int grid_size = cdiv(M * N, BLOCK_M * BLOCK_N);
matmul_v2_kernel<
BLOCK_M, BLOCK_N, BLOCK_K,
WARP_M, WARP_N, WARP_K,
MMA_M, MMA_N, MMA_K><<<grid_size, BLOCK_SIZE>>>(A, B, C, M, N, K);
}