from typing import Tuple

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config(
            # B_T, H_D (8192), D (2048)
            {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "BLOCK_K": BLOCK_K},
            num_stages=num_stages,
            num_warps=num_warps,
        )
        for BLOCK_M in [32]
        for BLOCK_N in [128, 256]
        for BLOCK_K in [128, 256]
        for num_stages in [2, 4, 8]
        for num_warps in [8]
    ],
    key=["B_T", "D", "H_D"],
)
@triton.jit
def fused_ffn_fwd(
    x_ptr,
    w13_ptr,
    w2_ptr,
    output_ptr,
    p_ptr,
    B_T,
    stride_xa,
    stride_xb,
    stride_w13a,
    stride_w13b,
    stride_w2a,
    stride_w2b,
    stride_oa,
    stride_ob,
    stride_pa,
    stride_pb,
    D: tl.constexpr,
    H_D: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Each program computes one BLOCK_M-row tile of the output. No boundary
    # checks are emitted, so B_T is assumed to be a multiple of BLOCK_M.
    pid_m = tl.program_id(axis=0)
    dtype = x_ptr.dtype.element_ty

    X_block_ptr = tl.make_block_ptr(
        base=x_ptr,
        shape=(B_T, D),
        strides=(stride_xa, stride_xb),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    O_block_ptr = tl.make_block_ptr(
        base=output_ptr,
        shape=(B_T, D),
        strides=(stride_oa, stride_ob),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )

    for start_n in range(0, H_D, BLOCK_N):
        P_block_ptr = tl.make_block_ptr(
            base=p_ptr,
            shape=(B_T, H_D),
            strides=(stride_pa, stride_pb),
            offsets=(pid_m * BLOCK_M, start_n),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
        # w13 is stored as [2 * H_D, D] and read here as its transpose
        # (D, 2 * H_D): w1 occupies columns [0, H_D), w3 columns [H_D, 2 * H_D).
        w1t_bptr = tl.make_block_ptr(
            base=w13_ptr,
            shape=(D, 2 * H_D),
            strides=(stride_w13b, stride_w13a),
            offsets=(0, start_n),
            block_shape=(BLOCK_K, BLOCK_N),
            order=(0, 1),
        )
        w3t_bptr = tl.make_block_ptr(
            base=w13_ptr,
            shape=(D, 2 * H_D),
            strides=(stride_w13b, stride_w13a),
            offsets=(0, H_D + start_n),
            block_shape=(BLOCK_K, BLOCK_N),
            order=(0, 1),
        )
        # This hidden-dim slice multiplies against rows [start_n, start_n + BLOCK_N)
        # of w2 in the second GEMM.
        w2_bptr = tl.make_block_ptr(
            base=w2_ptr,
            shape=(H_D, D),
            strides=(stride_w2a, stride_w2b),
            offsets=(start_n, 0),
            block_shape=(BLOCK_N, BLOCK_K),
            order=(1, 0),
        )

        x_bptr = X_block_ptr
        o_bptr = O_block_ptr
        acc_1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        acc_3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # first GEMM: accumulate x @ w1.T and x @ w3.T over the D dimension
        w1t_bptr_inner = w1t_bptr
        w3t_bptr_inner = w3t_bptr
        w2_bptr_inner = w2_bptr
        for _ in range(0, D, BLOCK_K):
            x = tl.load(x_bptr)
            w1t = tl.load(w1t_bptr_inner)
            w3t = tl.load(w3t_bptr_inner)
            acc_1 = tl.dot(x, w1t, acc_1)
            acc_3 = tl.dot(x, w3t, acc_3)
            x_bptr = tl.advance(x_bptr, (0, BLOCK_K))
            w1t_bptr_inner = tl.advance(w1t_bptr_inner, (BLOCK_K, 0))
            w3t_bptr_inner = tl.advance(w3t_bptr_inner, (BLOCK_K, 0))
        # acc_1 = acc_1.to(dtype).to(tl.float32)
        # acc_3 = acc_3.to(dtype).to(tl.float32)
        # SwiGLU activation: silu(x @ w1.T) * (x @ w3.T)
        p = acc_1 * tl.sigmoid(acc_1) * acc_3
        p = p.to(dtype)
        tl.store(P_block_ptr, p)
        # second GEMM: add this slice's contribution, p @ w2[start_n:start_n + BLOCK_N, :],
        # into the output tile, sweeping over the D dimension
        for _ in range(0, D, BLOCK_K):
            w2 = tl.load(w2_bptr_inner)
            o = tl.load(o_bptr)
            tl.store(o_bptr, (tl.dot(p, w2) + o).to(dtype))
            w2_bptr_inner = tl.advance(w2_bptr_inner, (0, BLOCK_K))
            o_bptr = tl.advance(o_bptr, (0, BLOCK_K))

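

# fused_ffn launches fused_ffn_fwd, which fuses the SwiGLU feed-forward block
#     p   = silu(x @ w1.T) * (x @ w3.T)   # [B_T, H_D]
#     out = p @ w2                        # [B_T, D]
# into a single kernel, where w13 stacks w1 and w3 row-wise as [2 * H_D, D].
# The eager `_ffn` further below computes the same thing with separate matmuls.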
def fused_ffn(
    x: torch.Tensor, w13: torch.Tensor, w2: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    # x: [B_T, D]
    # w13: [H_D * 2, D]
    # w2: [H_D, D]
    # output: [B_T, D], p: [B_T, H_D]
    B_T, D = x.shape
    H_D_2, D = w13.shape
    H_D = w2.shape[0]
    assert H_D_2 == 2 * H_D, f"H_D_2 must be 2 * H_D but got {H_D_2=} and {H_D=}"

    def grid(META):
        return (triton.cdiv(B_T, META["BLOCK_M"]),)

    # The kernel accumulates into output with load-add-store, so it must start at zero.
    output = torch.zeros_like(x)
    p = torch.empty((B_T, H_D), dtype=x.dtype, device=x.device)

    fused_ffn_fwd[grid](
        x,
        w13,
        w2,
        output,
        p,
        B_T,
        x.stride(0),
        x.stride(1),
        w13.stride(0),
        w13.stride(1),
        w2.stride(0),
        w2.stride(1),
        output.stride(0),
        output.stride(1),
        p.stride(0),
        p.stride(1),
        D,
        H_D,
    )

    return output, p

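

# A minimal smoke-test sketch for the fused path. The helper name and sizes
# below are illustrative assumptions (not part of the original file); it
# assumes a CUDA device with bfloat16 support.
def _fused_ffn_example() -> Tuple[torch.Tensor, torch.Tensor]:
    B_T, H_D, D = 1024, 8192, 2048
    x = torch.randn((B_T, D), dtype=torch.bfloat16, device="cuda")
    w13 = torch.randn((2 * H_D, D), dtype=torch.bfloat16, device="cuda")
    w2 = torch.randn((H_D, D), dtype=torch.bfloat16, device="cuda")
    out, p = fused_ffn(x, w13, w2)
    assert out.shape == (B_T, D) and p.shape == (B_T, H_D)
    return out, p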


@triton.jit
# pyre-fixme[3]: Return type must be annotated.
def _silu_mul_kernel(
    # pyre-fixme[2]: Parameter must be annotated.
    x1_ptr,
    x1_stride: tl.constexpr,
    # pyre-fixme[2]: Parameter must be annotated.
    x2_ptr,
    x2_stride: tl.constexpr,
    # pyre-fixme[2]: Parameter must be annotated.
    y_ptr,
    D: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    b = tl.program_id(0).to(tl.int64)

    x1_start = x1_ptr + b * x1_stride
    x2_start = x2_ptr + b * x2_stride
    y_start = y_ptr + b * D

    for offset in range(0, D, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < D
        x1v = tl.load(x1_start + cols, mask=mask, other=0).to(tl.float32)
        x2v = tl.load(x2_start + cols, mask=mask, other=0).to(tl.float32)
        yv = (x1v * tl.sigmoid(x1v) * x2v).to(tl.bfloat16)
        tl.store(y_start + cols, yv, mask=mask)


sigmoid = torch.nn.Sigmoid()


def silu_mul(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    assert x1.shape == x2.shape
    (B_T, D) = x1.shape
    out = torch.empty_like(x1)
    assert x1.stride(1) == x2.stride(1) == 1
    assert out.is_contiguous()
    grid = (B_T,)
    _silu_mul_kernel[grid](x1, x1.stride(0), x2, x2.stride(0), out, D, BLOCK_SIZE=1024)
    return out

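

# Usage sketch for silu_mul: it should agree with the unfused PyTorch expression
# F.silu(x1) * x2 up to bfloat16 rounding. The helper name, sizes, and tolerances
# below are illustrative assumptions, not part of the original file.
def _silu_mul_example() -> bool:
    x1 = torch.randn((32, 1024), dtype=torch.bfloat16, device="cuda")
    x2 = torch.randn((32, 1024), dtype=torch.bfloat16, device="cuda")
    y = silu_mul(x1, x2)
    y_ref = torch.nn.functional.silu(x1.float()) * x2.float()
    return torch.allclose(y.float(), y_ref, atol=1e-2, rtol=1e-2)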
def _ffn(x, w13, w2):
    # Eager reference implementation of the same SwiGLU FFN.
    H_D_2, D = w13.shape
    H_D = H_D_2 // 2
    p = x @ w13.T  # B_T, 2 * H_D
    p1 = p[:, :H_D]  # B_T, H_D
    p2 = p[:, H_D:]  # B_T, H_D
    p_out = silu_mul(p1, p2)  # B_T, H_D
    out = p_out @ w2  # B_T, D
    return out, p_out

def numerics_check(shape):
    B_T, H_D, D = shape
    x = torch.randn((B_T, D), dtype=torch.bfloat16, device="cuda")
    w13 = torch.randn((H_D * 2, D), dtype=torch.bfloat16, device="cuda")
    w2 = torch.randn((H_D, D), dtype=torch.bfloat16, device="cuda")
    triton_out, triton_p = fused_ffn(x, w13, w2)
    eager_out, eager_p = _ffn(x, w13, w2)

    print("P numeric check: ", torch.allclose(triton_p, eager_p, atol=1e-2, rtol=0))
    print("Output numeric check: ", torch.allclose(triton_out, eager_out, atol=1e-2, rtol=0))
    # print(triton_p[-1])
    # print(eager_p[-1])

def do_benchmark():
    D = 2048
    H_D = 8192

    configs = []
    configs.append(
        triton.testing.Benchmark(
            x_names=[
                "B_T",
                "H_D",
                "D",
            ],  # Argument names to use as an x-axis for the plot
            x_vals=[
                (i, H_D, D)
                for H_D, D in [(128, 256), (1024, 512), (8192, 2048)]
                for i in [1024, 2048, 4096, 8192, 16384]
            ],  # Different possible values for `x_names`
            line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
            line_vals=["eager", "fused"],  # Possible values for `line_arg`
            line_names=["Eager", "Fused"],
            styles=[("green", "-"), ("blue", "-")],
            ylabel="Latency (ms)",  # Label name for the y-axis
            plot_name="fused_ffn-benchmark",
            args={},
        )
    )

    @triton.testing.perf_report(configs)
    def benchmark(B_T, H_D, D, provider):
        x = torch.randn((B_T, D), dtype=torch.bfloat16, device="cuda")
        w13 = torch.randn((H_D * 2, D), dtype=torch.bfloat16, device="cuda")
        w2 = torch.randn((H_D, D), dtype=torch.bfloat16, device="cuda")
        quantiles = [0.5, 0.2, 0.8]
        if provider == "eager":
            return triton.testing.do_bench(
                lambda: _ffn(x, w13, w2), quantiles=quantiles
            )
        if provider == "fused":
            return triton.testing.do_bench(
                lambda: fused_ffn(x, w13, w2), quantiles=quantiles
            )

    benchmark.run(show_plots=True, print_data=True)


if __name__ == "__main__":
    # B_T, H_D, D
    # numerics_check((16, 128, 128))
    # numerics_check((256, 8192, 2048))
    do_benchmark()