# test_flash_attn_fwd.py (forked from aws-neuron/nki-samples)
"""
Copyright (c) 2023, Amazon.com. All Rights Reserved
"""
import pytest
from neuronxcc.nki.kernels.attention import flash_fwd, FlashConfig
from neuronxcc.nki import benchmark, baremetal
import neuronxcc.nki.language as nl
import numpy as np
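
# `baremetal` runs the NKI kernel directly on device for numerical checks, while `benchmark`
# wraps it with warmup and timed iterations and exposes latency statistics (consumed below via
# bench_func.benchmark_result.nc_latency).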
numeric_func = baremetal(flash_fwd)
bench_func = benchmark(warmup=5, iters=10)(flash_fwd)


def softmax(x: np.ndarray, dim: int, zero_max_mode=False,
            mixed_precision=False, return_max_reduce=False):
    max_value = np.amax(x, axis=dim, keepdims=True)
    max_value = np.maximum(0, max_value) if zero_max_mode else max_value

    exp = np.exp(x - max_value)
    if mixed_precision:
        reduce = np.add.reduce(exp.astype(np.float32), axis=dim, keepdims=True).astype(x.dtype)
    else:
        reduce = np.add.reduce(exp, axis=dim, keepdims=True)

    if return_max_reduce:
        return exp / reduce, -max_value, np.reciprocal(reduce)
    return exp / reduce
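
# With return_max_reduce=True, softmax also returns -max(x) and 1/sum(exp(x - max(x))) along the
# reduced axis, so the row-wise logsumexp can be recovered as
#   lse = max + log(sum) = -((-max) + log(1/sum)),
# which is how lse_golden is reconstructed in the numerical test below.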


def cpu_attention_forward(q, k, v, use_causal_mask=True, mixed_precision=True):
    def mixed_precision_matmul(a, b):
        input_dtype = a.dtype
        a, b = a.astype(np.float32), b.astype(np.float32)
        c = np.matmul(a, b)
        return c.astype(input_dtype)

    _, _, d, _ = q.shape

    # Compute golden output
    softmax_scale = 1.0 / (d ** 0.5)
    q_scaled = q * softmax_scale
    nheads = q.shape[1]
    kv_heads = k.shape[1]
    if nheads > kv_heads:
        k = np.repeat(k, nheads // kv_heads, axis=1)
        v = np.repeat(v, nheads // kv_heads, axis=1)
    raw_score = mixed_precision_matmul(q_scaled.transpose(0, 1, 3, 2), k)

    if use_causal_mask:
        # raw_score has the K sequence in the innermost dim. We want to mask every element whose
        # Q index is smaller than its K index with -inf, which maps to the upper triangle of the
        # final two axes.
        for i in range(raw_score.shape[0]):
            for j in range(raw_score.shape[1]):
                # -inf triggers an invalid-input error in the softmax implementation, so use a
                # large negative value instead. k=1 excludes the diagonal, because each token can
                # still attend to itself.
                raw_score[i, j][np.triu_indices_from(raw_score[i, j], k=1)] = -9984.0

    norm_score, cached_negative_max, cached_sum_reciprocal = \
        softmax(raw_score, dim=-1, mixed_precision=mixed_precision, return_max_reduce=True)

    # Transpose the result so it has the same layout as the kernel output
    out_golden = mixed_precision_matmul(norm_score, v.transpose(0, 1, 3, 2)).transpose(0, 1, 3, 2)

    return out_golden, cached_negative_max, cached_sum_reciprocal
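
# The tests below call cpu_attention_forward with q, k, and v in (batch, heads, d_head, seqlen)
# layout and get the golden output back in that same layout; when the query has more heads than
# k/v, the reference repeats the k/v heads, so grouped-query configurations are covered as well.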


class TestAttention:
    @pytest.mark.parametrize("bs, nheads, seqlen, d, dtype, use_causal_mask,\
                              mixed_precision, training, tile_size, kv_heads, should_transpose_v, latency", [
        [1, 6, 32*1024, 96, nl.bfloat16, True, True, True, 2048, 3, False, 87000000000],
        [1, 1, 32*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 15100000000],
    ])
    def test_flash_attn_fwd_perf(self, bs, nheads, seqlen, d, dtype, use_causal_mask,
                                 mixed_precision, training, tile_size, kv_heads, should_transpose_v, latency):
        q = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
        k = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
        if should_transpose_v:
            v = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
        else:
            v = (np.random.random_sample([bs, nheads, seqlen, d]) - 0.5) * 2

        o_proj = np.zeros(shape=[bs, nheads, seqlen, d], dtype=dtype)
        out_lse = np.zeros(shape=[bs, nheads, int(nl.tile_size.pmax), seqlen // nl.tile_size.pmax],
                           dtype=nl.float32 if mixed_precision else dtype) if training else None
        seed = None

        q = nl.static_cast(q, dtype)
        k = nl.static_cast(k, dtype)
        v = nl.static_cast(v, dtype)
        o_proj = nl.static_cast(o_proj, dtype)

        config = FlashConfig(seq_tile_size=tile_size, training=training, should_transpose_v=should_transpose_v)
        heads = nheads if kv_heads is None else kv_heads
        bench_func[bs, heads](q, k, v, seed, o_proj, out_lse,
                              use_causal_mask=use_causal_mask, mixed_precision=mixed_precision, config=config)
        latency_res = bench_func.benchmark_result.nc_latency
        p99 = latency_res.get_latency_percentile(99)
        assert p99 <= latency
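
    # The numerical test below checks the kernel output (and, when training, the logsumexp it
    # returns) against the NumPy reference above; the tolerance is loose (atol=1e-2) because the
    # kernel is invoked with mixed_precision=True.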
    @pytest.mark.parametrize("bs, nheads, seqlen, d, dtype, use_causal_mask,\
                              training, tile_size, kv_heads, should_transpose_v", [
        [1, 6, 4096, 128, np.float32, True, True, 2048, 3, False],
        [1, 1, 4096, 128, np.float32, True, False, 2048, None, False],
    ])
    def test_flash_attn_fwd_numerical(self, bs, nheads, seqlen, d, dtype, use_causal_mask,
                                      training, tile_size, kv_heads, should_transpose_v):
        q = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
        k = (np.random.random_sample([bs, kv_heads or nheads, d, seqlen]) - 0.5) * 2
        if should_transpose_v:
            v = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
            cpu_permute = (0, 1, 2, 3)
        else:
            v = (np.random.random_sample([bs, kv_heads or nheads, seqlen, d]) - 0.5) * 2
            cpu_permute = (0, 1, 3, 2)

        o_proj = np.zeros(shape=[bs, nheads, seqlen, d], dtype=dtype)
        q = nl.static_cast(q, dtype)
        k = nl.static_cast(k, dtype)
        v = nl.static_cast(v, dtype)
        seed = None
        out_lse = np.zeros(shape=[bs, nheads, int(nl.tile_size.pmax), seqlen // nl.tile_size.pmax],
                           dtype=np.float32) if training else None

        o_proj_golden, cached_negative_max, cached_sum_reciprocal = \
            cpu_attention_forward(q, k, v.transpose(cpu_permute), use_causal_mask=use_causal_mask, mixed_precision=True)
        o_proj_golden = o_proj_golden.transpose(0, 1, 3, 2)  # (bs, nheads, seqlen, d), matching o_proj
        cached_negative_max = cached_negative_max.reshape(bs, nheads, seqlen // nl.tile_size.pmax,
                                                          nl.tile_size.pmax).transpose(0, 1, 3, 2)
        cached_sum_reciprocal = cached_sum_reciprocal.reshape(bs, nheads, seqlen // nl.tile_size.pmax,
                                                              nl.tile_size.pmax).transpose(0, 1, 3, 2)
        lse_golden = -1.0 * (cached_negative_max + np.log(cached_sum_reciprocal)) if training else None

        config = FlashConfig(seq_tile_size=tile_size, training=training, should_transpose_v=should_transpose_v)
        heads = nheads if kv_heads is None else kv_heads
        numeric_func[bs, heads](q, k, v, seed, o_proj, out_lse,
                                use_causal_mask=use_causal_mask, mixed_precision=True, config=config)

        assert np.allclose(o_proj, o_proj_golden, atol=1e-2)
        if training:
            assert np.allclose(out_lse, lse_golden, atol=1e-2)