-
Notifications
You must be signed in to change notification settings - Fork 7
/
convert_gptneox_checkpoint.py
71 lines (51 loc) · 2.02 KB
/
convert_gptneox_checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse
import glob
import json
import os
import shutil

import torch
# Default locations used when no command-line arguments are given.
DEFAULT_INPUT_PATH = '/root/fm/models/GPT-NeoXT-20B-chat-v0.9.2'
DEFAULT_OUTPUT_PATH = '/root/fm/models/GPT-NeoXT-20B-chat-v0.9.2-shard'

# Prefix shared by all per-layer parameter names in the source checkpoint index.
_LAYER_PREFIX = 'gpt_neox.layers.'


def _infer_num_layers(weight_map):
    """Return 1 + the highest ``gpt_neox.layers.<i>.`` index in ``weight_map`` (0 if none)."""
    indices = {
        int(key[len(_LAYER_PREFIX):].split('.', 1)[0])
        for key in weight_map
        if key.startswith(_LAYER_PREFIX)
    }
    return max(indices) + 1 if indices else 0


def _gather(input_path, weight_map, names, cache):
    """Collect the tensors that make up one output file.

    Args:
        input_path: Directory holding the source shard files.
        weight_map: ``weight_map`` from the index JSON (parameter name -> shard file name).
        names: Sequence of ``(output_key, checkpoint_key)`` pairs.
        cache: Dict of shard file name -> loaded state dict. It is shared across calls
            and mutated in place.

    Returns:
        Dict mapping each ``output_key`` to its tensor.

    Raises:
        KeyError: If a ``checkpoint_key`` is absent from the index or its shard.
    """
    needed = {weight_map[key] for _, key in names}
    # Drop shards this group does not use before loading anything new. Only the
    # current group's shards stay resident, and a shard shared with the previous
    # group is not read from disk again.
    for shard in [s for s in cache if s not in needed]:
        del cache[shard]
    out = {}
    for out_key, key in names:
        shard = weight_map[key]
        if shard not in cache:
            # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
            cache[shard] = torch.load(
                os.path.join(input_path, shard),
                map_location=torch.device('cpu'),
            )
        out[out_key] = cache[shard][key]
    return out


def convert_checkpoint(input_path, output_path, num_layers=None):
    """Split a sharded GPT-NeoX checkpoint into per-module ``.pt`` files.

    Writes into ``output_path``, which is created (with parents) if missing:
      * every ``*.json`` file found directly in ``input_path``;
      * ``pytorch_embs.pt`` holding ``embed_in.weight``;
      * ``pytorch_lm_head.pt`` holding ``embed_out.weight`` and
        ``final_layer_norm.weight`` / ``final_layer_norm.bias``;
      * ``pytorch_<i>.pt`` for each layer ``i``. Its keys are the checkpoint names
        with the ``gpt_neox.layers.<i>.`` prefix stripped.

    Args:
        input_path: Directory containing ``pytorch_model.bin.index.json`` and the
            shard files it references.
        output_path: Destination directory.
        num_layers: Number of layers to export. When None, it is inferred from the
            highest layer index present in the index.

    Raises:
        KeyError: If a required parameter is missing from the index or its shard.
        OSError: If the output directory cannot be created or a file cannot be copied.
    """
    os.makedirs(output_path, exist_ok=True)
    # Copy without a shell: paths need no quoting, and failures raise instead of
    # being ignored.
    for json_file in glob.glob(os.path.join(input_path, '*.json')):
        shutil.copy(json_file, output_path)

    index_file = os.path.join(input_path, 'pytorch_model.bin.index.json')
    with open(index_file, encoding='utf-8') as f:
        weight_map = json.load(f)['weight_map']

    cache = {}

    def write(names, file_name):
        # Gather one group and save it under output_path.
        torch.save(
            _gather(input_path, weight_map, names, cache),
            os.path.join(output_path, file_name),
        )

    ## emb
    write([('embed_in.weight', 'gpt_neox.embed_in.weight')], 'pytorch_embs.pt')

    ## out
    write(
        [
            ('embed_out.weight', 'embed_out.weight'),
            ('final_layer_norm.weight', 'gpt_neox.final_layer_norm.weight'),
            ('final_layer_norm.bias', 'gpt_neox.final_layer_norm.bias'),
        ],
        'pytorch_lm_head.pt',
    )

    ## layers
    if num_layers is None:
        num_layers = _infer_num_layers(weight_map)
    for i in range(num_layers):
        layer_prefix = f'{_LAYER_PREFIX}{i}.'
        names = [
            (key[len(layer_prefix):], key)
            for key in weight_map
            if key.startswith(layer_prefix)
        ]
        write(names, f'pytorch_{i}.pt')

    cache.clear()


def main(argv=None):
    """Parse command-line arguments and run :func:`convert_checkpoint`."""
    parser = argparse.ArgumentParser(
        description='Split a sharded GPT-NeoX checkpoint into per-layer files.',
    )
    parser.add_argument('--input-path', default=DEFAULT_INPUT_PATH,
                        help='directory with pytorch_model.bin.index.json and shards')
    parser.add_argument('--output-path', default=DEFAULT_OUTPUT_PATH,
                        help='directory to write the split checkpoint into')
    parser.add_argument('--num-layers', type=int, default=None,
                        help='number of layers to export (default: infer from index)')
    args = parser.parse_args(argv)
    convert_checkpoint(args.input_path, args.output_path, num_layers=args.num_layers)


if __name__ == '__main__':
    main()