# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
| 8 | +from enum import Enum |
| 9 | + |
| 10 | + |
| 11 | +class Model(str, Enum): |
| 12 | + Mul = "mul" |
| 13 | + Linear = "linear" |
| 14 | + Add = "add" |
| 15 | + AddMul = "add_mul" |
| 16 | + Softmax = "softmax" |
| 17 | + Dl3 = "dl3" |
| 18 | + Edsr = "edsr" |
| 19 | + EmformerTranscribe = "emformer_transcribe" |
| 20 | + EmformerPredict = "emformer_predict" |
| 21 | + EmformerJoin = "emformer_join" |
| 22 | + Llama2 = "llama2" |
| 23 | + Llama = "llama" |
| 24 | + Llama32VisionEncoder = "llama3_2_vision_encoder" |
| 25 | + Lstm = "lstm" |
| 26 | + MobileBert = "mobilebert" |
| 27 | + Mv2 = "mv2" |
| 28 | + Mv2Untrained = "mv2_untrained" |
| 29 | + Mv3 = "mv3" |
| 30 | + Vit = "vit" |
| 31 | + W2l = "w2l" |
| 32 | + Ic3 = "ic3" |
| 33 | + Ic4 = "ic4" |
| 34 | + ResNet18 = "resnet18" |
| 35 | + ResNet50 = "resnet50" |
| 36 | + Llava = "llava" |
| 37 | + EfficientSam = "efficient_sam" |
| 38 | + Qwen25 = "qwen2_5" |
| 39 | + Phi4Mini = "phi-4-mini" |
| 40 | + |
| 41 | + def __str__(self) -> str: |
| 42 | + return self.value |
| 43 | + |
| 44 | + |
class Backend(str, Enum):
    """Canonical CLI names of the supported backend recipes.

    A ``str`` mixin, like ``Model``: members compare equal to their raw
    string values and can be used wherever a backend-name string is expected.
    """

    XnnpackQuantizationDelegation = "xnnpack-quantization-delegation"

    def __str__(self) -> str:
        """Render as the bare backend name, not ``Backend.Xxx``."""
        return self.value
| 50 | + |
| 51 | + |
# Maps a model's CLI name to ``(module_name, class_name)``. Keys are the
# plain string values of ``Model`` — ``member.value`` is identical to
# ``str(member)`` here, since ``Model.__str__`` returns ``self.value``.
MODEL_NAME_TO_MODEL = {
    Model.Mul.value: ("toy_model", "MulModule"),
    Model.Linear.value: ("toy_model", "LinearModule"),
    Model.Add.value: ("toy_model", "AddModule"),
    Model.AddMul.value: ("toy_model", "AddMulModule"),
    Model.Softmax.value: ("toy_model", "SoftmaxModule"),
    Model.Dl3.value: ("deeplab_v3", "DeepLabV3ResNet50Model"),
    Model.Edsr.value: ("edsr", "EdsrModel"),
    Model.EmformerTranscribe.value: ("emformer_rnnt", "EmformerRnntTranscriberModel"),
    Model.EmformerPredict.value: ("emformer_rnnt", "EmformerRnntPredictorModel"),
    Model.EmformerJoin.value: ("emformer_rnnt", "EmformerRnntJoinerModel"),
    # "llama" is an alias for "llama2": both resolve to the same class.
    Model.Llama2.value: ("llama", "Llama2Model"),
    Model.Llama.value: ("llama", "Llama2Model"),
    Model.Llama32VisionEncoder.value: ("llama3_2_vision", "FlamingoVisionEncoderModel"),
    # TODO: This takes too long to export on both Linux and MacOS (> 6 hours)
    # "llama3_2_text_decoder": ("llama3_2_vision", "Llama3_2Decoder"),
    Model.Lstm.value: ("lstm", "LSTMModel"),
    Model.MobileBert.value: ("mobilebert", "MobileBertModelExample"),
    Model.Mv2.value: ("mobilenet_v2", "MV2Model"),
    Model.Mv2Untrained.value: ("mobilenet_v2", "MV2UntrainedModel"),
    Model.Mv3.value: ("mobilenet_v3", "MV3Model"),
    Model.Vit.value: ("torchvision_vit", "TorchVisionViTModel"),
    Model.W2l.value: ("wav2letter", "Wav2LetterModel"),
    Model.Ic3.value: ("inception_v3", "InceptionV3Model"),
    Model.Ic4.value: ("inception_v4", "InceptionV4Model"),
    Model.ResNet18.value: ("resnet", "ResNet18Model"),
    Model.ResNet50.value: ("resnet", "ResNet50Model"),
    Model.Llava.value: ("llava", "LlavaModel"),
    Model.EfficientSam.value: ("efficient_sam", "EfficientSAM"),
    Model.Qwen25.value: ("qwen2_5", "Qwen2_5Model"),
    Model.Phi4Mini.value: ("phi-4-mini", "Phi4MiniModel"),
}
|
40 | 84 |
|
41 | 85 | __all__ = [
|
|
0 commit comments