Commit 2d73b33

Linear to Conv2d transform for static attention

Differential Revision: D70726317
Pull Request resolved: #9025
1 parent: 75898bf
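
The transform relies on a standard equivalence: an nn.Linear projection gives the same result as a 1x1 nn.Conv2d whose weight is the linear weight with two trailing singleton dims, applied to activations laid out as (bsz, dim, 1, seq_len). Below is a minimal standalone sketch of that equivalence; it is illustrative only (not part of the commit), with shapes chosen to mirror the reshape/transpose calls in the diff that follows.

import torch
import torch.nn as nn

dim, head_dim, seq_len = 64, 16, 8

linear = nn.Linear(dim, head_dim, bias=False)

# 1x1 conv with the linear weight copied in, the same way transfer_weight()
# does in the diff below.
conv = nn.Conv2d(dim, head_dim, 1, bias=False)
conv.weight.data.copy_(linear.weight[:, :, None, None])

x = torch.rand(1, seq_len, dim)                         # (bsz, seq_len, dim)
x_nchw = x.reshape(1, seq_len, 1, dim).transpose(1, 3)  # (bsz, dim, 1, seq_len)

y_linear = linear(x)                                    # (bsz, seq_len, head_dim)
y_conv = conv(x_nchw).reshape(1, head_dim, seq_len).transpose(1, 2)

print(torch.allclose(y_linear, y_conv, atol=1e-5))      # expected: True

The forward() changes below apply exactly this layout change once on the input, undo it per head after the Q/K/V projections (from_conv2ds), and repeat the round trip around the output projection wo.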

File tree

2 files changed: +97 -26 lines

examples/models/llama/static_attention.py (+66 -1)
@@ -210,6 +210,7 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
         self.attention_qkv_bias = config.attention_qkv_bias
         self.use_qk_norm = config.use_qk_norm
+        self.use_conv2d = False

         assert not self.use_qk_norm, "QK norm not supported in static attention yet"
         self.wqs = nn.ModuleList(
@@ -255,9 +256,25 @@ def forward(
         in_cache_state = kwargs.get("in_cache_state")
         out_cache_state = kwargs.get("out_cache_state")

+        bsz, seq_len, dim = x.shape
+        if self.use_conv2d:
+            x = x.reshape(bsz, seq_len, 1, dim).transpose(1, 3)
+
         new_qs = [self.wqs[i](x) for i in range(self.n_heads)]
         new_ks = [self.wks[i](x) for i in range(self.n_kv_heads)]
         new_vs = [self.wvs[i](x) for i in range(self.n_kv_heads)]
+
+        if self.use_conv2d:
+
+            def from_conv2ds(ts):
+                return [
+                    t.reshape(bsz, self.head_dim, seq_len).transpose(1, 2) for t in ts
+                ]
+
+            new_qs = from_conv2ds(new_qs)
+            new_ks = from_conv2ds(new_ks)
+            new_vs = from_conv2ds(new_vs)
+
         new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
         new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
         all_ks = []
@@ -282,7 +299,14 @@ def forward(
             heads.append(attn @ all_vs[kv_idx])

         y = torch.cat(heads, dim=-1)
-        y = self.wo(y)
+        if self.use_conv2d:
+            y = (
+                self.wo(y.reshape(bsz, seq_len, 1, -1).transpose(1, 3))
+                .transpose(1, 3)
+                .reshape(bsz, seq_len, -1)
+            )
+        else:
+            y = self.wo(y)
         return y, {"out_cache_state": out_cache_state}

     def load_weights_from_attention_mha(self, other: AttentionMHA):
@@ -300,3 +324,44 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):
         )

         self.wo.weight.data.copy_(other.wo.weight)
+
+    def linear_to_conv2d(self):
+        def transfer_weight(linear, conv2d):
+            conv2d.weight.data.copy_(linear.weight[:, :, None, None])
+            return conv2d
+
+        self.wqs = nn.ModuleList(
+            [
+                transfer_weight(
+                    linear,
+                    nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
+                )
+                for linear in self.wqs
+            ]
+        )
+        self.wks = nn.ModuleList(
+            [
+                transfer_weight(
+                    linear,
+                    nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
+                )
+                for linear in self.wks
+            ]
+        )
+        self.wvs = nn.ModuleList(
+            [
+                transfer_weight(
+                    linear,
+                    nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
+                )
+                for linear in self.wvs
+            ]
+        )
+        self.wo = transfer_weight(
+            self.wo,
+            nn.Conv2d(
+                self.n_heads * self.head_dim, self.dim, 1, bias=self.attention_qkv_bias
+            ),
+        )
+
+        self.use_conv2d = True
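
Taken together, the intended call pattern appears to be: construct StaticAttention, load weights from the reference AttentionMHA, then opt in to the conv2d representation. A hedged sketch follows; the variable names mirror the test file below, and nothing beyond the methods shown in this diff is assumed.

# Assumes config, layer_id, rope, attn_mha, x, freqs_cos, freqs_sin and mask
# are set up as in test_static_attention.py below.
static_attn = StaticAttention(config, layer_id, rope).eval()
static_attn.load_weights_from_attention_mha(attn_mha)
static_attn.linear_to_conv2d()  # swap nn.Linear projections for 1x1 nn.Conv2d
y, _ = static_attn(x, freqs_cos, freqs_sin, mask=mask)  # matches the pre-transform output within tolerance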

examples/models/llama/tests/test_static_attention.py (+31 -25)
@@ -17,32 +17,38 @@ def setUp(self):
         torch.manual_seed(42)

     def test_without_cache(self):
-        config = ModelArgs(
-            dim=64,
-            n_heads=4,
-            n_kv_heads=2,
-            max_seq_len=8,
-        )
-        layer_id = 0
-        rope = Rope(config)
-        attn_mha = AttentionMHA(config, layer_id, rope).eval()
-        static_attn = StaticAttention(config, layer_id, rope).eval()
-        static_attn.load_weights_from_attention_mha(attn_mha)
+        def test(use_conv2d):
+            config = ModelArgs(
+                dim=64,
+                n_heads=4,
+                n_kv_heads=2,
+                max_seq_len=8,
+            )
+            layer_id = 0
+            rope = Rope(config)
+            attn_mha = AttentionMHA(config, layer_id, rope).eval()
+            static_attn = StaticAttention(config, layer_id, rope).eval()
+            static_attn.load_weights_from_attention_mha(attn_mha)
+            if use_conv2d:
+                static_attn.linear_to_conv2d()
+
+            x = torch.rand(1, config.max_seq_len, config.dim)
+            freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
+            expected, _ = attn_mha(x, freqs_cos, freqs_sin)
+            mask = torch.triu(
+                torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
+                diagonal=1,
+            )
+            y, _ = static_attn(
+                x,
+                freqs_cos,
+                freqs_sin,
+                mask=mask,
+            )
+            self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())

-        x = torch.rand(1, config.max_seq_len, config.dim)
-        freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
-        expected, _ = attn_mha(x, freqs_cos, freqs_sin)
-        mask = torch.triu(
-            torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
-            diagonal=1,
-        )
-        y, _ = static_attn(
-            x,
-            freqs_cos,
-            freqs_sin,
-            mask=mask,
-        )
-        self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
+        test(True)
+        test(False)

     def test_hf_rope_without_cache(self):
         config = ModelArgs(
