import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List
from pkg_resources import packaging

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")


__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}


def _download(url: str, root: str):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    # Commenting out the expected_sha256 as we're bypassing checksum verification
    # expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        # Bypassing the SHA256 checksum verification
        return download_target

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    # Bypassing the SHA256 checksum verification on download completion
    return download_target


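# Usage sketch for the helper above (illustrative only; the cache directory
# mirrors the default used by `load()` below):
#
#   model_path = _download(_MODELS["ViT-B/32"], os.path.expanduser("~/.cache/clip"))
#
# As the comments in `_download` note, the returned file is not checked against
# the SHA256 that appears in the download URL.

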
def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


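# Preprocessing sketch (hypothetical image path, for illustration): the Compose
# above resizes/crops to n_px, converts to RGB, and applies the normalization
# constants shown, yielding a float tensor of shape [3, n_px, n_px].
#
#   preprocess = _transform(224)
#   image_tensor = preprocess(Image.open("photo.jpg"))  # shape: [3, 224, 224]

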
def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())

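# For example (ordering follows the _MODELS dict above):
#   available_models()
#   # -> ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64',
#   #     'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']

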
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model on

    jit : bool
        Whether to load the optimized JIT model or the more hackable non-JIT model (default)

    download_root : str
        Path to download the model files to; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as opened_file:
        try:
            # loading JIT archive
            model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
            state_dict = None
        except RuntimeError:
            # loading saved state dict
            if jit:
                warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
                jit = False
            state_dict = torch.load(opened_file, map_location="cpu")

    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def _node_get(node: torch._C.Node, key: str):
        """Gets attributes of a node which is polymorphic over return type.

        From https://github.com/pytorch/pytorch/pull/82628
        """
        sel = node.kindOf(key)
        return getattr(node, sel)(key)

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if _node_get(inputs[i].node(), "value") == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


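# End-to-end usage sketch (illustrative, not part of this module; it assumes the
# package is importable as `clip` and that "photo.jpg" exists, both placeholders;
# `encode_image`/`encode_text` are the model methods patched above):
#
#   import clip
#   import torch
#   from PIL import Image
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model, preprocess = clip.load("ViT-B/32", device=device)
#   image = preprocess(Image.open("photo.jpg")).unsqueeze(0).to(device)
#   text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
#   with torch.no_grad():
#       image_features = model.encode_image(image)
#       text_features = model.encode_text(text)

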
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate : bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    else:
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
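

# Tokenization sketch (hypothetical prompts, for illustration): each string is
# wrapped with <|startoftext|>/<|endoftext|> and padded to context_length.
#
#   tokens = tokenize(["a diagram", "a dog", "a cat"])
#   tokens.shape  # -> torch.Size([3, 77])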