convert_t5_checkpoint.py
import os
import json

import torch
if __name__ == '__main__':
    # Output directory for the per-layer checkpoint files.
    os.makedirs('t5-11b-new', exist_ok=True)

    with open('t5-11b/pytorch_model.bin.index.json') as f:
        index = json.load(f)
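    # index['weight_map'] maps each parameter name to the shard file that
    # contains it (for a sharded Hugging Face checkpoint, typically files
    # named like 'pytorch_model-00001-of-000NN.bin'). Every shard is loaded
    # with map_location='cpu' below, so the conversion needs no GPU.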
    ## shared embedding
    item = {}
    item['shared.weight'] = torch.load(
        't5-11b/' + index['weight_map']['shared.weight'],
        map_location=torch.device('cpu'),
    )['shared.weight']
    torch.save(item, 't5-11b-new/pytorch_embs.pt')
    ## encoder final layer norm
    item = {}
    item['final_layer_norm.weight'] = torch.load(
        't5-11b/' + index['weight_map']['encoder.final_layer_norm.weight'],
        map_location=torch.device('cpu'),
    )['encoder.final_layer_norm.weight']
    torch.save(item, 't5-11b-new/pytorch_enc_head.pt')
    ## lm head and decoder final layer norm
    item = {}
    item['lm_head.weight'] = torch.load(
        't5-11b/' + index['weight_map']['lm_head.weight'],
        map_location=torch.device('cpu'),
    )['lm_head.weight']
    item['final_layer_norm.weight'] = torch.load(
        't5-11b/' + index['weight_map']['decoder.final_layer_norm.weight'],
        map_location=torch.device('cpu'),
    )['decoder.final_layer_norm.weight']
    torch.save(item, 't5-11b-new/pytorch_dec_head.pt')
    ## encoder layers
    for i in range(24):
        layer_prefix = f'encoder.block.{i}.'
        item = {}
        layer_maps = {k: v for k, v in index['weight_map'].items() if k.startswith(layer_prefix)}
        # The relative attention bias is stored only in block 0 but is shared
        # by every block, so copy it into each per-layer file.
        bias_key = 'encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'
        layer_maps[bias_key] = index['weight_map'][bias_key]
        for k, v in layer_maps.items():
            # Strip the block prefix; the extra replace handles the bias key
            # borrowed from block 0 when i != 0.
            new_k = k.replace(layer_prefix, '').replace('encoder.block.0.', '')
            item[new_k] = torch.load(
                't5-11b/' + v,
                map_location=torch.device('cpu'),
            )[k]
        torch.save(item, f't5-11b-new/pytorch_enc_{i}.pt')
        del item
    ## decoder layers
    for i in range(24):
        layer_prefix = f'decoder.block.{i}.'
        item = {}
        layer_maps = {k: v for k, v in index['weight_map'].items() if k.startswith(layer_prefix)}
        # As above: the decoder's relative attention bias lives only in block 0.
        bias_key = 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'
        layer_maps[bias_key] = index['weight_map'][bias_key]
        for k, v in layer_maps.items():
            new_k = k.replace(layer_prefix, '').replace('decoder.block.0.', '')
            item[new_k] = torch.load(
                't5-11b/' + v,
                map_location=torch.device('cpu'),
            )[k]
        torch.save(item, f't5-11b-new/pytorch_dec_{i}.pt')
        del item
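
    ## optional sanity check (a minimal sketch added here, not part of the
    ## original conversion): reload one converted encoder layer and print its
    ## parameter names and shapes to confirm the split looks right. Assumes
    ## the loops above ran to completion.
    sample = torch.load(
        't5-11b-new/pytorch_enc_0.pt',
        map_location=torch.device('cpu'),
    )
    for name, tensor in sample.items():
        print(name, tuple(tensor.shape))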