@@ -210,6 +210,7 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
        self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
        self.attention_qkv_bias = config.attention_qkv_bias
        self.use_qk_norm = config.use_qk_norm
+        self.use_conv2d = False

        assert not self.use_qk_norm, "QK norm not supported in static attention yet"
        self.wqs = nn.ModuleList(
@@ -255,9 +256,25 @@ def forward(
        in_cache_state = kwargs.get("in_cache_state")
        out_cache_state = kwargs.get("out_cache_state")

+        bsz, seq_len, dim = x.shape
+        if self.use_conv2d:
+            x = x.reshape(bsz, seq_len, 1, dim).transpose(1, 3)
+
        new_qs = [self.wqs[i](x) for i in range(self.n_heads)]
        new_ks = [self.wks[i](x) for i in range(self.n_kv_heads)]
        new_vs = [self.wvs[i](x) for i in range(self.n_kv_heads)]
+
+        if self.use_conv2d:
+
+            def from_conv2ds(ts):
+                return [
+                    t.reshape(bsz, self.head_dim, seq_len).transpose(1, 2) for t in ts
+                ]
+
+            new_qs = from_conv2ds(new_qs)
+            new_ks = from_conv2ds(new_ks)
+            new_vs = from_conv2ds(new_vs)
+
        new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
        new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
        all_ks = []
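
The channels-first reshape in this hunk relies on a pointwise (1×1) `nn.Conv2d` over an NCHW tensor computing the same per-token projection as the `nn.Linear` it replaces. A minimal standalone sketch of that equivalence (not part of the diff; sizes are made up for illustration):

```python
import torch
import torch.nn as nn

bsz, seq_len, dim, head_dim = 2, 5, 16, 4
x = torch.randn(bsz, seq_len, dim)

linear = nn.Linear(dim, head_dim, bias=False)
conv = nn.Conv2d(dim, head_dim, 1, bias=False)
# Same weight transfer as linear_to_conv2d below: (out, in) -> (out, in, 1, 1)
conv.weight.data.copy_(linear.weight[:, :, None, None])

y_linear = linear(x)  # (bsz, seq_len, head_dim)

# (bsz, seq_len, dim) -> (bsz, dim, 1, seq_len), i.e. tokens along the W axis
x_nchw = x.reshape(bsz, seq_len, 1, dim).transpose(1, 3)
y_conv = conv(x_nchw).reshape(bsz, head_dim, seq_len).transpose(1, 2)

assert torch.allclose(y_linear, y_conv, atol=1e-6)
```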
@@ -282,7 +299,14 @@ def forward(
            heads.append(attn @ all_vs[kv_idx])

        y = torch.cat(heads, dim=-1)
-        y = self.wo(y)
+        if self.use_conv2d:
+            y = (
+                self.wo(y.reshape(bsz, seq_len, 1, -1).transpose(1, 3))
+                .transpose(1, 3)
+                .reshape(bsz, seq_len, -1)
+            )
+        else:
+            y = self.wo(y)
        return y, {"out_cache_state": out_cache_state}

    def load_weights_from_attention_mha(self, other: AttentionMHA):
@@ -300,3 +324,44 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):
            )

        self.wo.weight.data.copy_(other.wo.weight)
+
+    def linear_to_conv2d(self):
+        def transfer_weight(linear, conv2d):
+            conv2d.weight.data.copy_(linear.weight[:, :, None, None])
+            return conv2d
+
+        self.wqs = nn.ModuleList(
+            [
+                transfer_weight(
+                    linear,
+                    nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
+                )
+                for linear in self.wqs
+            ]
+        )
+        self.wks = nn.ModuleList(
+            [
+                transfer_weight(
+                    linear,
+                    nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
+                )
+                for linear in self.wks
+            ]
+        )
+        self.wvs = nn.ModuleList(
+            [
+                transfer_weight(
+                    linear,
+                    nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
+                )
+                for linear in self.wvs
+            ]
+        )
+        self.wo = transfer_weight(
+            self.wo,
+            nn.Conv2d(
+                self.n_heads * self.head_dim, self.dim, 1, bias=self.attention_qkv_bias
+            ),
+        )
+
+        self.use_conv2d = True
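
`linear_to_conv2d` copies from the existing `nn.Linear` modules in place, so it is meant to run after the weights have been loaded. A rough usage sketch; `static_attn` and `mha` are hypothetical instances standing in for the attention module defined in this file and the `AttentionMHA` it mirrors (construction not shown in this diff):

```python
# Load the per-head Linear weights from the reference attention module first.
static_attn.load_weights_from_attention_mha(mha)

# Then swap every nn.Linear projection for an equivalent 1x1 nn.Conv2d and set
# use_conv2d = True, so forward() takes the NCHW code path added in this diff.
static_attn.linear_to_conv2d()
```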