-
Notifications
You must be signed in to change notification settings - Fork 881
/
convert.py
398 lines (332 loc) · 13.8 KB
/
convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
# Copyright © 2023-2024 Apple Inc.
import argparse
import copy
import hashlib
import json
import os
import urllib
import warnings
from dataclasses import asdict
from pathlib import Path
from typing import List
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from mlx_whisper import torch_whisper
from mlx_whisper.whisper import ModelDimensions, Whisper
from tqdm import tqdm
# Names of the dtypes accepted by the `--dtype` command-line option; each name is
# looked up as an attribute of `mx` when the script runs.
_VALID_DTYPES = {"float16", "float32"}
# Maps each supported model name to the URL of its PyTorch checkpoint.
# The second-to-last path component of every URL is the hex SHA256 digest that
# `_download` extracts and compares against the downloaded file.
# "large" points at the same file as "large-v3", and "turbo" at the same file as
# "large-v3-turbo".
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
# Keyed by the same model names as `_MODELS`. The entry for an official model is handed,
# still encoded, to the model's `set_alignment_heads` method after loading.
# The alias pairs "large"/"large-v3" and "turbo"/"large-v3-turbo" carry identical values.
_ALIGNMENT_HEADS = {
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
}
def _download(url: str, root: str) -> str:
    """Fetch the checkpoint at ``url`` into directory ``root`` and return its path.

    The expected SHA256 digest is read from the second-to-last path component
    of ``url``. If a file with the same base name already exists in ``root``
    and its digest matches, it is reused without any network access;
    otherwise the file is (re-)downloaded and verified.

    Parameters
    ----------
    url : str
        Download URL of the form ``.../<sha256-hex>/<file name>``.
    root : str
        Directory that receives the file; created if it does not exist.

    Returns
    -------
    str
        Path of the verified local file.

    Raises
    ------
    RuntimeError
        If the target path exists but is not a regular file, or if the
        downloaded file does not match the expected checksum.
    """
    # A bare ``import urllib`` does not guarantee that the ``request``
    # submodule has been loaded, so import it explicitly where it is used.
    import urllib.request

    def sha256_of(path):
        # Hash in 1 MiB chunks so multi-gigabyte checkpoints are never held
        # in memory all at once.
        digest = hashlib.sha256()
        with open(path, "rb") as fid:
            for chunk in iter(lambda: fid.read(1 << 20), b""):
                digest.update(chunk)
        return digest.hexdigest()

    os.makedirs(root, exist_ok=True)
    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if sha256_of(download_target) == expected_sha256:
            return download_target
        warnings.warn(
            f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
        )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        # The header may be absent (e.g. chunked transfer); tqdm accepts
        # ``total=None`` and then shows progress without a known total.
        content_length = source.info().get("Content-Length")
        with tqdm(
            total=int(content_length) if content_length is not None else None,
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))

    if sha256_of(download_target) != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
        )
    return download_target
def available_models() -> List[str]:
    """Return the model names (the keys of ``_MODELS``) as a new list."""
    return [model_name for model_name in _MODELS]
def hf_to_pt(weights, config):
    """Translate a Hugging Face Whisper checkpoint to the OpenAI naming scheme.

    Parameters
    ----------
    weights : dict
        Mapping from Hugging Face parameter names to their values. It is not
        modified; a new mapping is returned.
    config : dict
        Hugging Face model configuration. Only the keys read below are used.

    Returns
    -------
    (dict, dict)
        The weights under their renamed keys (without ``proj_out.weight``),
        and the model dimensions under the ``n_*`` key names.

    Raises
    ------
    KeyError
        If ``config`` lacks one of the required entries.
    """
    pt_config = {
        "n_mels": config["num_mel_bins"],
        "n_audio_ctx": config["max_source_positions"],
        "n_audio_state": config["d_model"],
        "n_audio_head": config["encoder_attention_heads"],
        "n_audio_layer": config["encoder_layers"],
        "n_vocab": config["vocab_size"],
        "n_text_ctx": config["max_target_positions"],
        "n_text_state": config["d_model"],
        "n_text_head": config["decoder_attention_heads"],
        "n_text_layer": config["decoder_layers"],
    }

    # Substring rewrites, applied in this exact order. The order is
    # significant: ".self_attn" must become ".attn" before ".attn_layer_norm"
    # is rewritten, so that ".self_attn_layer_norm" ends up as ".attn_ln".
    replacements = (
        ("model.", ""),
        (".layers", ".blocks"),
        (".self_attn", ".attn"),
        (".attn_layer_norm", ".attn_ln"),
        (".encoder_attn.", ".cross_attn."),
        (".encoder_attn_layer_norm", ".cross_attn_ln"),
        (".final_layer_norm", ".mlp_ln"),
        (".q_proj", ".query"),
        (".k_proj", ".key"),
        (".v_proj", ".value"),
        (".out_proj", ".out"),
        (".fc1", ".mlp1"),
        (".fc2", ".mlp2"),
        ("embed_positions.weight", "positional_embedding"),
        ("decoder.embed_tokens", "decoder.token_embedding"),
        ("encoder.layer_norm", "encoder.ln_post"),
        ("decoder.layer_norm", "decoder.ln"),
    )

    def remap(k):
        for old, new in replacements:
            k = k.replace(old, new)
        return k

    # token embeddings are shared with output projection, so the projection
    # weight is dropped. Filtering here (instead of popping) leaves the
    # caller's dictionary untouched.
    remapped = {remap(k): v for k, v in weights.items() if k != "proj_out.weight"}
    return remapped, pt_config
def load_torch_weights_and_config(
    name_or_path: str,
    download_root: str = None,
):
    """Resolve a model identifier and load its PyTorch weights and config.

    ``name_or_path`` is resolved in this order:

    1. a key of ``_MODELS``: the official checkpoint is downloaded (or reused
       from ``download_root``) and its alignment heads are returned;
    2. an existing local path: used as is;
    3. anything else: treated as a Hugging Face repository id and fetched
       with ``huggingface_hub.snapshot_download``.

    A resolved path ending in ``.pt`` is read as a single checkpoint holding
    ``"model_state_dict"`` and ``"dims"``. Any other path is read as a
    directory containing ``config.json`` plus either ``pytorch_model.bin`` or
    ``model.safetensors``, and is converted with ``hf_to_pt``.

    Parameters
    ----------
    name_or_path : str
        Official model name, local path, or Hugging Face repository id.
    download_root : str, optional
        Directory for official checkpoints; defaults to ``~/.cache/whisper``.

    Returns
    -------
    (weights, config, alignment_heads)
        ``alignment_heads`` is ``None`` unless an official model name was given.

    Notes
    -----
    Errors raised by ``snapshot_download`` (e.g. for an unknown repository)
    propagate unchanged.
    """
    if download_root is None:
        download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")

    # Accept os.PathLike as well as str; the suffix test below needs a str.
    name_or_path = os.fspath(name_or_path)

    # todo: accept alignment_heads of local Pytorch checkpoint
    alignment_heads = None
    if name_or_path in _MODELS:
        alignment_heads = _ALIGNMENT_HEADS[name_or_path]
        name_or_path = _download(_MODELS[name_or_path], download_root)
    elif not Path(name_or_path).exists():
        # Try downloading from HF
        from huggingface_hub import snapshot_download

        name_or_path = snapshot_download(
            repo_id=name_or_path,
            allow_patterns=[
                "*.json",
                "pytorch_model.bin",
                "model.safetensors",
                "*.txt",
            ],
        )
    # Otherwise ``name_or_path`` is an existing local file or directory and is
    # loaded directly. (Previously this case raised a "not found" error, which
    # made local checkpoints unusable.)

    if name_or_path.endswith(".pt"):
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects. Official checkpoints are SHA256-verified by `_download`,
        # but a local ``.pt`` file must come from a trusted source.
        checkpoint = torch.load(name_or_path, map_location="cpu", weights_only=False)
        weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
    else:
        name_or_path = Path(name_or_path)
        pt_path = name_or_path / "pytorch_model.bin"
        if pt_path.is_file():
            weights = torch.load(pt_path, map_location="cpu")
        else:
            weights = mx.load(str(name_or_path / "model.safetensors"))
        with open(name_or_path / "config.json", "r") as fp:
            config = json.load(fp)
        weights, config = hf_to_pt(weights, config)

    return weights, config, alignment_heads
def load_torch_model(
    name_or_path: str,
    download_root: str = None,
) -> torch_whisper.Whisper:
    """
    Build a PyTorch Whisper model and fill it with pretrained weights

    Parameters
    ----------
    name_or_path : str
        model identifier, forwarded unchanged to `load_torch_weights_and_config`,
        which decides how it is resolved
    download_root: str
        directory for downloaded model files; "~/.cache/whisper" when omitted

    Returns
    -------
    model : Whisper
        The `torch_whisper.Whisper` instance with its state dict loaded and,
        when the loader supplied them, its alignment heads set
    """
    cache_dir = download_root
    if cache_dir is None:
        cache_dir = os.path.join(os.path.expanduser("~"), ".cache/whisper")

    state_dict, dims_kwargs, heads = load_torch_weights_and_config(
        name_or_path, cache_dir
    )

    model = torch_whisper.Whisper(torch_whisper.ModelDimensions(**dims_kwargs))
    model.load_state_dict(state_dict)

    # Alignment heads are optional; leave the model untouched when absent.
    if heads is None:
        return model
    model.set_alignment_heads(heads)
    return model
def convert(name_or_path: str, dtype: mx.Dtype = mx.float16):
    """Load PyTorch Whisper weights and return an MLX ``Whisper`` model.

    ``name_or_path`` is passed to ``load_torch_weights_and_config``. Every
    weight is renamed, converted to an ``mx.array`` if it is a torch tensor,
    and cast to ``dtype`` before being loaded non-strictly into the model.
    """

    def remap(key, value):
        # The sequential MLP indices of the source naming become named layers.
        key = key.replace("mlp.0", "mlp1").replace("mlp.2", "mlp2")
        # 3-D conv weights get their last two axes swapped -- presumably to go
        # from the PyTorch to the MLX convolution layout; TODO confirm.
        is_conv_kernel = "conv" in key and value.ndim == 3
        if is_conv_kernel:
            value = value.swapaxes(1, 2)
        if isinstance(value, torch.Tensor):
            value = mx.array(value.detach())
        return key, value.astype(dtype)

    torch_weights, dims_kwargs, heads = load_torch_weights_and_config(name_or_path)
    # This entry is discarded rather than loaded (load_weights is non-strict).
    torch_weights.pop("encoder.positional_embedding", None)

    mlx_weights = {}
    for name, tensor in torch_weights.items():
        new_name, new_value = remap(name, tensor)
        mlx_weights[new_name] = new_value

    model = Whisper(ModelDimensions(**dims_kwargs), dtype)
    model.load_weights(list(mlx_weights.items()), strict=False)

    if heads is not None:
        model.set_alignment_heads(heads)
    return model
def upload_to_hub(path: str, name: str, torch_name_or_path: str):
    """Write a model card into ``path`` and upload the folder to the Hub.

    Parameters
    ----------
    path : str
        Local folder holding the converted model; ``README.md`` is written
        into it before the upload.
    name : str
        Repository name; the target repository is ``mlx-community/{name}``.
    torch_name_or_path : str
        Identifier of the source PyTorch model, mentioned in the model card.
    """
    import os

    from huggingface_hub import HfApi, ModelCard, logging

    repo_id = f"mlx-community/{name}"

    # The repo id is emitted as a quoted string so that the usage snippet in
    # the generated README is valid Python.
    text = f"""
---
library_name: mlx
---
# {name}
This model was converted to MLX format from [`{torch_name_or_path}`]().
## Use with mlx
```bash
pip install mlx-whisper
```
```python
import mlx_whisper
result = mlx_whisper.transcribe(
    "FILE_NAME",
    path_or_hf_repo="{repo_id}",
)
```
"""
    card = ModelCard(text)
    card.save(os.path.join(path, "README.md"))

    logging.set_verbosity_info()

    api = HfApi()
    # exist_ok=True makes re-uploading to an existing repository possible.
    api.create_repo(repo_id=repo_id, exist_ok=True)
    api.upload_folder(
        folder_path=path,
        repo_id=repo_id,
        repo_type="model",
    )
def quantize(weights, config, args):
    """Quantize ``weights`` and return ``(quantized_weights, quantized_config)``.

    A ``Whisper`` model is built from ``config``, populated with ``weights``
    and quantized with ``nn.quantize`` using ``args.q_group_size`` and
    ``args.q_bits``. The returned config is a deep copy of ``config`` with an
    added ``"quantization"`` entry; ``config`` itself is not modified.
    """
    new_config = copy.deepcopy(config)

    # Rebuild the model so that quantization can operate on its modules.
    whisper = Whisper(ModelDimensions(**config))
    arrays = tree_map(mx.array, weights)
    flat_items = list(arrays.items())
    whisper.update(tree_unflatten(flat_items))

    nn.quantize(whisper, args.q_group_size, args.q_bits)

    # Record the settings alongside the model dimensions.
    new_config["quantization"] = {
        "group_size": args.q_group_size,
        "bits": args.q_bits,
    }

    flat_params = tree_flatten(whisper.parameters())
    return dict(flat_params), new_config
# Command-line entry point: convert a PyTorch Whisper model to MLX, optionally
# quantize it, save weights and config, and optionally upload the result.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Whisper weights to MLX.")
    parser.add_argument(
        "--torch-name-or-path",
        type=str,
        default="tiny",
        help="The name or path to the PyTorch model.",
    )
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_models",
        help="The path to save the MLX model.",
    )
    # Validated by argparse rather than by an `assert`, which would be
    # stripped when Python runs with -O.
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        choices=sorted(_VALID_DTYPES),
        help="The dtype to save the MLX model.",
    )
    parser.add_argument(
        "-q",
        "--quantize",
        help="Generate a quantized model.",
        action="store_true",
    )
    parser.add_argument(
        "--q-group-size",
        help="Group size for quantization.",
        type=int,
        default=64,
    )
    parser.add_argument(
        "--q-bits",
        help="Bits per weight for quantization.",
        type=int,
        default=4,
    )
    parser.add_argument(
        "--upload-name",
        help="The name of model to upload to Hugging Face MLX Community",
        type=str,
        default=None,
    )
    args = parser.parse_args()

    dtype = getattr(mx, args.dtype)

    print("[INFO] Loading")
    model = convert(args.torch_name_or_path, dtype)
    config = asdict(model.dims)
    weights = dict(tree_flatten(model.parameters()))

    if args.quantize:
        print("[INFO] Quantizing")
        weights, config = quantize(weights, config, args)

    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    # Save weights
    print("[INFO] Saving")
    mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)

    # Save config.json with model_type
    with open(str(mlx_path / "config.json"), "w") as f:
        config["model_type"] = "whisper"
        json.dump(config, f, indent=4)

    if args.upload_name is not None:
        upload_to_hub(mlx_path, args.upload_name, args.torch_name_or_path)