-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathllama2.py
380 lines (298 loc) · 14.5 KB
/
llama2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
# based on model.py from https://github.com/karpathy/llama2.c by Andrej Karpathy, MIT licensed
# modifications by okuvshynov include:
# - no weight tying
# - using blackbox offloadable modules
# - simplify init/generation as we only use it for fine-tuning experiments
# - manual backprop
# - support for ffn_dim_multiplier which llama2-70b uses
# - LoRA
import logging
import math
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from blackbox import BlackboxDisk
from utils import save_rng_state, restore_rng_state, device_map, cleanup_cache
from model_config import ModelArgs
import logging
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last dimension with a learned gain."""

    def __init__(self, dim: int, eps: float):
        super().__init__()
        # per-feature multiplicative gain, starts as all ones
        self.weight = nn.Parameter(torch.ones(dim))
        # added to the mean square before rsqrt to avoid division by zero
        self.eps = eps

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalize in float32, cast back to the dtype of ``x``, then apply the gain."""
        normalized = self._norm(x.float())
        return normalized.type_as(x) * self.weight
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the cosine and sine tables used by the rotary embedding.

    Args:
        dim: size of one attention head; one frequency is produced per pair
            of features, so the tables have ``dim // 2`` columns.
        end: number of positions (rows) to tabulate.
        theta: base of the geometric frequency progression.

    Returns:
        A ``(cos, sin)`` pair of float tensors, each of shape ``(end, dim // 2)``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    # angle[p, k] = position p times frequency k
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    return torch.cos(angles), torch.sin(angles)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(shape)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by the precomputed rotary angles.

    Consecutive feature pairs of the last dimension are treated as the
    (real, imaginary) parts of a complex number and rotated using real
    arithmetic only. Computation happens in float32; the results are cast
    back to the dtypes of ``xq`` and ``xk`` respectively.
    """

    def _split_pairs(t):
        # (..., d) -> two tensors of shape (..., d // 2): real and imaginary parts
        return t.float().reshape(t.shape[:-1] + (-1, 2)).unbind(-1)

    def _rotate(real, imag, cos, sin):
        rotated_real = real * cos - imag * sin
        rotated_imag = real * sin + imag * cos
        # interleave the parts again and merge the trailing (d // 2, 2) axes
        return torch.stack([rotated_real, rotated_imag], dim=-1).flatten(3)

    q_real, q_imag = _split_pairs(xq)
    k_real, k_imag = _split_pairs(xk)

    # both tables are shaped against the query; they broadcast over the other axes
    cos = reshape_for_broadcast(freqs_cos, q_real)
    sin = reshape_for_broadcast(freqs_sin, q_real)

    q_rotated = _rotate(q_real, q_imag, cos, sin)
    k_rotated = _rotate(k_real, k_imag, cos, sin)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each head along axis 2 ``n_rep`` times.

    Equivalent to ``torch.repeat_interleave(x, dim=2, repeats=n_rep)`` for a
    4-D input of shape ``(bs, slen, n_kv_heads, head_dim)``. When ``n_rep`` is
    1 the input tensor itself is returned.
    """
    batch, seq_len, kv_heads, dim_per_head = x.shape
    if n_rep != 1:
        tiled = x.unsqueeze(3).expand(batch, seq_len, kv_heads, n_rep, dim_per_head)
        return tiled.reshape(batch, seq_len, kv_heads * n_rep, dim_per_head)
    return x
class Attention(nn.Module):
    """Pre-norm causal self-attention with a residual connection.

    The block normalizes its input, projects to queries/keys/values, applies
    rotary embeddings, runs causal attention and adds the projected result
    back onto the un-normalized input. LoRA adapters for the query and value
    projections are not owned by this module; they are passed to ``forward``
    and their outputs are added to the ``wq`` / ``wv`` outputs.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # fall back to one kv head per query head when n_kv_heads is unset
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.n_heads = args.n_heads
        # how many query heads share each kv head
        self.n_rep = self.n_heads // self.n_kv_heads
        self.head_dim = args.dim // args.n_heads

        # here's where we inject LoRA
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        # here's where we inject LoRA
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        # TODO: probably don't need dropout here as we don't plan to do full finetune
        # or maybe we do.
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout

        # use flash attention or a manual implementation?
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            # logging.warn is a deprecated alias; logging.warning is the supported call
            logging.warning("using slow attention. Flash Attention requires PyTorch >= 2.0")
            # additive causal mask: -inf strictly above the diagonal, 0 elsewhere
            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            self.register_buffer("mask", mask)

        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        q_lora: nn.Module,
        v_lora: nn.Module
    ):
        """Run attention on ``x`` and return ``x`` plus the attention output.

        Args:
            x: 3-D input; it is unpacked as ``(bsz, seqlen, _)``.
            freqs_cos: rotary cosine table passed through to ``apply_rotary_emb``.
            freqs_sin: rotary sine table passed through to ``apply_rotary_emb``.
            q_lora: module whose output on the normalized input is added to ``wq``'s.
            v_lora: module whose output on the normalized input is added to ``wv``'s.

        Returns:
            ``x + resid_dropout(wo(attention))``.
        """
        bsz, seqlen, _ = x.shape
        x_base = x  # kept for the residual connection
        x = self.attention_norm(x)

        # QKV; the LoRA deltas are computed from the same normalized input
        xq, xk, xv = self.wq(x) + q_lora(x), self.wk(x), self.wv(x) + v_lora(x)
        xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # grouped multiquery attention: expand out keys and values
        xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_heads, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_heads, head_dim)

        # make heads into a batch dimension
        xq = xq.transpose(1, 2)  # (bs, n_heads, seqlen, head_dim)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        if self.flash:
            # flash implementation; attention dropout only while training
            output = torch.nn.functional.scaled_dot_product_attention(
                xq,
                xk,
                xv,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True,
            )
        else:
            # manual implementation
            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
            assert hasattr(self, 'mask')
            scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_heads, seqlen, seqlen)
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores, xv)  # (bs, n_heads, seqlen, head_dim)

        # restore time as batch dimension and concat heads
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        # final projection into the residual stream
        output = self.wo(output)
        output = self.resid_dropout(output)
        return x_base + output
class FeedForward(nn.Module):
    """Pre-norm gated (SiLU) feed-forward block with a residual connection."""

    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float, ffn_dim_multiplier: Optional[float], args: ModelArgs):
        super().__init__()
        # shrink the requested width to two thirds, optionally rescale it,
        # then round up to the next multiple of `multiple_of`
        width = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            width = int(ffn_dim_multiplier * width)
        chunks = (width + multiple_of - 1) // multiple_of
        width = multiple_of * chunks

        self.w1 = nn.Linear(dim, width, bias=False)
        self.w2 = nn.Linear(width, dim, bias=False)
        self.w3 = nn.Linear(dim, width, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x):
        """Return ``x + dropout(w2(silu(w1(n)) * w3(n)))`` where ``n = ffn_norm(x)``."""
        normed = self.ffn_norm(x)
        gate = F.silu(self.w1(normed))
        up = self.w3(normed)
        projected = self.w2(gate * up)
        return x + self.dropout(projected)
class TransformerBlock(nn.Module):
    """One transformer layer: attention followed by feed-forward.

    Both sub-modules are wrapped in ``BlackboxDisk``; each already contains
    its own normalization and residual connection, so this block only chains
    them.
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.layer_id = layer_id
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads

        attention = Attention(args)
        self.attention = BlackboxDisk(attention, args)

        feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            dropout=args.dropout,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
            args=args,
        )
        self.feed_forward = BlackboxDisk(feed_forward, args)

    def forward(self, x, freqs_cos, freqs_sin, lora_q, lora_v):
        """Apply attention (with the given LoRA adapters), then feed-forward."""
        attended = self.attention(x, freqs_cos, freqs_sin, lora_q, lora_v)
        return self.feed_forward(attended)
class LoRA(nn.Module):
    """Low-rank adapter ``B(A(x)) * (alpha / rank)`` sized after an existing linear layer.

    ``B`` starts at zero, so a freshly created adapter contributes nothing.
    """

    def __init__(self, original_layer, rank, alpha, dropout):
        super().__init__()
        # the wrapped layer's weight is (out_features, in_features)
        out_features, in_features = original_layer.weight.shape
        self.A = nn.Linear(in_features, rank, bias=False)
        self.B = nn.Linear(rank, out_features, bias=False)
        nn.init.zeros_(self.B.weight)
        self.dropout = nn.Dropout(dropout)
        self.scale = alpha / rank

    def expanded(self):
        """Return the dense matrix to add to the original weight."""
        return torch.mm(self.B.weight, self.A.weight) * self.scale

    def forward(self, x):
        """Return the scaled low-rank update for ``x``; dropout is applied to the update."""
        down = self.A(x)
        update = self.B(down)
        return self.dropout(update) * self.scale
class Transformer(nn.Module):
    """Llama-style decoder whose heavy modules are wrapped in ``BlackboxDisk``.

    Only the per-layer LoRA adapters (registered as ``q_lora_<i>`` /
    ``v_lora_<i>``) are meant to be trained. ``forward`` is a plain forward
    pass returning logits for the last position; ``manual_loop`` runs a
    forward pass and a hand-written, layer-by-layer backward pass and returns
    the loss value.
    """

    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = BlackboxDisk(nn.Embedding(params.vocab_size, params.dim), params)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = torch.nn.ModuleList()

        # we create LoRA adapters separately, as we don't want to load/save them continuously
        self.lora_layers = []
        for layer_id in range(params.n_layers):
            block = TransformerBlock(layer_id, params)
            # TODO: remove this one
            attn = block.attention.loaded_inner()
            q_lora = LoRA(attn.wq, rank=params.lora_rank, alpha=params.lora_alpha, dropout=params.lora_dropout).to(params.compute_dtype)
            v_lora = LoRA(attn.wv, rank=params.lora_rank, alpha=params.lora_alpha, dropout=params.lora_dropout).to(params.compute_dtype)
            self.lora_layers.append({'q_lora': q_lora, 'v_lora': v_lora})
            # register the adapters on this module so they appear in parameters()
            self.add_module(f'q_lora_{layer_id}', q_lora)
            self.add_module(f'v_lora_{layer_id}', v_lora)
            self.layers.append(block)
            logging.debug(f'created transformer block {layer_id}')

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        # Freeze the final norm. Assigning `self.norm.requires_grad = False`
        # only sets a plain attribute on the module and leaves its parameter
        # trainable; requires_grad_() is the call that updates the parameters.
        self.norm.requires_grad_(False)
        self.output = BlackboxDisk(nn.Linear(params.dim, params.vocab_size, bias=False), params)

        # some useful precompute for the RoPE relative positional embeddings
        freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len, theta=params.rope_theta)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Return logits for the last position of each sequence in ``tokens``.

        ``tokens`` is a 2-D tensor (it is unpacked as ``(_bsz, seqlen)``).
        """
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        h = self.dropout(h)
        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]

        for layer, lora in zip(self.layers, self.lora_layers):
            h = layer(h, freqs_cos, freqs_sin, lora['q_lora'], lora['v_lora'])

        h = self.norm(h)
        # indexing with [-1] (a list) keeps the sequence axis with length 1
        return self.output(h[:, [-1], :])

    def backprop_w_lora(self, blackbox_module, output_grad, *args):
        """Re-run one blackbox module and backpropagate ``output_grad`` through it.

        The module's own parameters are frozen first, so only tensors that
        still require grad (the module input and any LoRA modules passed in
        ``args``) accumulate gradients.

        Args:
            blackbox_module: wrapper exposing ``load(device)`` and
                ``load_input(device)``; presumably the latter returns the
                input recorded during the forward pass (confirm in blackbox.py).
            output_grad: gradient of the loss w.r.t. the module's output.
            *args: extra positional arguments forwarded to the module.

        Returns:
            Gradient of the loss w.r.t. the module's input.
        """
        device = output_grad.device
        module = blackbox_module.load(device)

        # we use LoRA and only update attached low-rank modules;
        # no part of original model is getting any updates, so no need for gradient
        for param in module.parameters():
            param.requires_grad = False

        module_input = blackbox_module.load_input(device)
        module_input.requires_grad = True
        output = module(module_input, *args)
        output.backward(output_grad)
        return module_input.grad

    # this is a manual implementation of forward/backward passes
    def manual_loop(self, tokens, targets):
        """Run one forward pass and a manual layer-by-layer backward pass.

        Every stage output is detached, so autograd never spans more than one
        module at a time; gradients are handed from stage to stage explicitly.
        Gradients end up on the LoRA adapters passed into the attention
        modules.

        Args:
            tokens: 2-D tensor of token ids.
            targets: target ids for cross-entropy; entries equal to -1 are ignored.

        Returns:
            The loss as a Python float.
        """
        logging.debug('starting manual loop')
        device = device_map(tokens.device)

        embd_out = self.tok_embeddings(tokens)
        embd_out = embd_out.detach()
        embd_out.requires_grad = True
        logging.debug('done embedding')

        _, seqlen = tokens.shape
        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]

        current = self.dropout(embd_out)
        del embd_out

        # RNG state captured before each block, restored before it is re-run
        # in the backward pass.
        # NOTE(review): the state is captured before attention + feed_forward
        # but restored before re-running feed_forward first; with dropout > 0
        # the recomputed masks may not match the forward pass unless
        # BlackboxDisk accounts for this. Confirm.
        rng_before = []
        for i, (layer, lora) in enumerate(zip(self.layers, self.lora_layers)):
            rng_before.append(save_rng_state(device))
            current = layer(current, freqs_cos, freqs_sin, lora['q_lora'], lora['v_lora'])
            logging.debug(f'forward: transformer block {i} done')

        current = current.detach()
        current.requires_grad = True

        norm_out = self.norm(current)
        norm_out = norm_out.detach()
        norm_out.requires_grad = True

        # TODO: micro-optimization: as output is last layer, we can skip loading and running it second time
        logging.debug('output layer')
        logits = self.output(norm_out)
        del norm_out
        logging.debug('output layer done')

        # the loss is computed in float32 regardless of the compute dtype
        if self.params.compute_dtype != torch.float32:
            logits = logits.to(torch.float32)

        logits = logits.detach()
        logits.requires_grad = True
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        logging.debug('forward: computed loss')

        # logits is a detached leaf, so this only populates logits.grad
        loss.backward()

        norm_out_grad = self.backprop_w_lora(self.output, logits.grad.to(self.params.compute_dtype))
        del logits
        logging.debug('combined: output layer done')

        # recompute the final norm to push the gradient back onto `current`
        norm_out2 = self.norm(current)
        norm_out2.backward(norm_out_grad)
        del norm_out_grad
        del norm_out2
        last_grad = current.grad
        del current

        for i, (layer, rng_state, lora) in enumerate(zip(reversed(self.layers), reversed(rng_before), reversed(self.lora_layers))):
            cleanup_cache(device)
            restore_rng_state(rng_state, device=device)
            # first, do feed_forward
            last_grad = self.backprop_w_lora(layer.feed_forward, last_grad)

            # now, do attention
            cleanup_cache(device)
            last_grad = self.backprop_w_lora(layer.attention, last_grad, freqs_cos, freqs_sin, lora['q_lora'], lora['v_lora'])
            logging.debug(f'combined: transformer block {i} done')

        # no need to backpropagate through embeddings: no LoRA layers there.
        return loss.item()