# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
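
# This script copies the weights of an original OpenAI CLIP checkpoint (loaded
# through the `clip` package) into a Hugging Face `CLIPModel`, checks that both
# models produce matching logits, and saves the converted model.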

import argparse

import torch
from clip import load

from transformers import CLIPConfig, CLIPModel

def copy_attn_layer(hf_attn_layer, pt_attn_layer):
    # The original attention module stores the query/key/value projections as a
    # single fused matrix, so split it into the three separate projections used
    # by the Hugging Face implementation.
    q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0)
    q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0)

    out_proj_weights = pt_attn_layer.out_proj.weight
    out_proj_bias = pt_attn_layer.out_proj.bias

    hf_attn_layer.q_proj.weight.data = q_proj
    hf_attn_layer.q_proj.bias.data = q_proj_bias

    hf_attn_layer.k_proj.weight.data = k_proj
    hf_attn_layer.k_proj.bias.data = k_proj_bias

    hf_attn_layer.v_proj.weight.data = v_proj
    hf_attn_layer.v_proj.bias.data = v_proj_bias

    hf_attn_layer.out_proj.weight = out_proj_weights
    hf_attn_layer.out_proj.bias = out_proj_bias


def copy_mlp(hf_mlp, pt_mlp):
    copy_linear(hf_mlp.fc1, pt_mlp.c_fc)
    copy_linear(hf_mlp.fc2, pt_mlp.c_proj)


def copy_linear(hf_linear, pt_linear):
    hf_linear.weight = pt_linear.weight
    hf_linear.bias = pt_linear.bias


def copy_layer(hf_layer, pt_layer):
    # copy layer norms
    copy_linear(hf_layer.layer_norm1, pt_layer.ln_1)
    copy_linear(hf_layer.layer_norm2, pt_layer.ln_2)

    # copy MLP
    copy_mlp(hf_layer.mlp, pt_layer.mlp)

    # copy attn
    copy_attn_layer(hf_layer.self_attn, pt_layer.attn)


def copy_layers(hf_layers, pt_layers):
    for hf_layer, pt_layer in zip(hf_layers, pt_layers):
        copy_layer(hf_layer, pt_layer)


def copy_encoder(hf_encoder, pt_model):
    # copy embeds
    hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight
    hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding

    # copy layer norm
    copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final)

    # copy hidden layers
    copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks)


def copy_text_model_and_projection(hf_model, pt_model):
    # copy projection
    hf_model.text_projection.weight.data = pt_model.text_projection.data.T.contiguous()

    # copy text encoder
    copy_encoder(hf_model.text_model, pt_model)


def copy_vison_model_and_projection(hf_model, pt_model):
    # copy projection
    hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T.contiguous()

    # copy layer norms
    copy_linear(hf_model.vision_model.pre_layrnorm, pt_model.visual.ln_pre)
    copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post)

    # copy embeds
    hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data
    hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding
    hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data

    # copy encoder
    copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks)


@torch.no_grad()
def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
    """
    Copy/paste/tweak the original model's weights to match the Transformers design.
    """
    if config_path is not None:
        config = CLIPConfig.from_pretrained(config_path)
    else:
        config = CLIPConfig(projection_dim=512, text_config={}, vision_config={})

    hf_model = CLIPModel(config).eval()

    pt_model, _ = load(checkpoint_path, device="cpu", jit=False)
    pt_model = pt_model.eval()

    copy_text_model_and_projection(hf_model, pt_model)
    copy_vison_model_and_projection(hf_model, pt_model)
    hf_model.logit_scale = pt_model.logit_scale

    # Use `eos_token` so the example is more meaningful
    input_ids = torch.tensor(
        [
            [config.text_config.bos_token_id]
            + list(range(3, 77))
            + [config.text_config.eos_token_id]
            + [config.text_config.pad_token_id]
        ]
    )
    pixel_values = torch.randn(1, 3, 224, 224)

    hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)
    hf_logits_per_image = hf_outputs.logits_per_image
    hf_logits_per_text = hf_outputs.logits_per_text
    pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids)

    assert torch.allclose(hf_logits_per_image, pt_logits_per_image, atol=1e-3)
    assert torch.allclose(hf_logits_per_text, pt_logits_per_text, atol=1e-3)

    hf_model.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the original CLIP checkpoint.")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert.")
    args = parser.parse_args()

    convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
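
# Example invocation (a sketch: the script name and output folder are illustrative,
# and `clip.load` accepts either a checkpoint path or a model name such as "ViT-B/32"):
#
#   python convert_clip_original_pytorch_to_hf.py \
#       --checkpoint_path ViT-B/32 \
#       --pytorch_dump_folder_path ./clip-vit-base-patch32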