"""
Example script showing how to use the different types of saved checkpoints.
To use, make sure you specify ft_checkpoints_folder AND have a photo of a cat saved as cat.jpg!
"""

save_is_gmp = False  # set True if the checkpoints were saved with 'save_as_gmp = True', otherwise leave False

if save_is_gmp:
    import gmpclip as clip
    suffix = 'gmp'
else:
    import clip
    suffix = 'weight'

import torch
from PIL import Image
import os
import torch.nn.functional as F
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)  # disable torch's FutureWarning spam about pickle

device = "cuda:0" if torch.cuda.is_available() else "cpu"
clipmodel = 'ViT-L/14'

ft_checkpoints_folder = 'ft-checkpoints'
epoch = 1  # epoch of the saved file; e.g. if the filename is clip_ft_1_*.pt => enter 1

# Define paths
full_model_path = os.path.join(ft_checkpoints_folder, f'clip_ft_{epoch}_full_as-{suffix}.pt')
state_dict_path = os.path.join(ft_checkpoints_folder, f'clip_ft_{epoch}_dict_as-{suffix}.pt')
jit_model_path = os.path.join(ft_checkpoints_folder, f'clip_ft_{epoch}_jit_as-{suffix}.pt')

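# For reference, the three checkpoint flavors above were presumably produced
# roughly like this by the fine-tuning script (a sketch, not the exact code):
#   torch.save(model, full_model_path)                # full pickled model object
#   torch.save(model.state_dict(), state_dict_path)   # state dict (weights only)
#   torch.jit.save(torch.jit.trace(model, (img, txt)), jit_model_path)  # traced forward pass
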
# Make sure you have a cat image available!
image_path = 'cat.jpg'

# Load and preprocess the image
def preprocess_image(image_path):
    image = Image.open(image_path).convert("RGB")
    preprocess = clip.load(clipmodel)[1]  # index [1] of clip.load() is the preprocessing transform
    return preprocess(image).unsqueeze(0)  # add a batch dimension

image_input = preprocess_image(image_path).to(device)

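# Minor efficiency note: clip.load() is called again inside preprocess_image just
# to get the transform. A sketch of loading once and reusing both return values:
#   model_once, preprocess_once = clip.load(clipmodel)
#   image_input = preprocess_once(Image.open(image_path).convert("RGB")).unsqueeze(0).to(device)
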
# Define a text prompt
text_inputs = clip.tokenize(["a photo of a cat"]).to(device)

# Function to calculate and print cosine similarity
def print_cosine_similarity(image_features, text_features, model_name):
    cosine_sim = F.cosine_similarity(image_features, text_features)
    print(f"{model_name} Cosine Similarity:", cosine_sim.item())

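# Note: CLIP computes its logits from L2-normalized features, so for a single
# image/text pair, logits_per_image = logit_scale.exp() * cosine_sim. The printed
# logits below should therefore just be a scaled version of the cosine similarity.
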
# 0. Load Original CLIP Model
original_clip = clip.load(clipmodel)[0].to(device).float()
original_clip.eval()
with torch.no_grad():
    image_features = original_clip.encode_image(image_input)
    text_features = original_clip.encode_text(text_inputs)
    logits_per_image, logits_per_text = original_clip(image_input, text_inputs)
    print("Original CLIP Results:")
    print("Logits per Image:", logits_per_image)
    print("Logits per Text:", logits_per_text)
    print_cosine_similarity(image_features, text_features, "Original CLIP")


# 1. Load the Full Model Object
print("\nLoading Full Model Object...")
full_model = torch.load(full_model_path).to(device)
full_model.eval().float()
with torch.no_grad():
    image_features = full_model.encode_image(image_input)
    text_features = full_model.encode_text(text_inputs)
    logits_per_image, logits_per_text = full_model(image_input, text_inputs)
    print("Full Model Object Results:")
    print("Logits per Image:", logits_per_image)
    print("Logits per Text:", logits_per_text)
    print_cosine_similarity(image_features, text_features, "Full Model Object")

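# Compatibility note (hedged): since torch 2.6, torch.load() defaults to
# weights_only=True, which refuses to unpickle a full model object. If the load
# above fails on a newer version, pass the flag explicitly:
#   full_model = torch.load(full_model_path, weights_only=False).to(device)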

# 2. Load the Model from State Dictionary
print("\nLoading Model from State Dictionary...")
state_dict_model = clip.load(clipmodel)[0]  # create a model instance of the correct architecture to load the weights into
state_dict = torch.load(state_dict_path, map_location=device)
state_dict_model.load_state_dict(state_dict)
state_dict_model = state_dict_model.to(device).float()

state_dict_model.eval()
with torch.no_grad():
    image_features = state_dict_model.encode_image(image_input)
    text_features = state_dict_model.encode_text(text_inputs)
    logits_per_image, logits_per_text = state_dict_model(image_input, text_inputs)
    print("State Dictionary Model Results:")
    print("Logits per Image:", logits_per_image)
    print("Logits per Text:", logits_per_text)
    print_cosine_similarity(image_features, text_features, "State Dictionary Model")

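# Optional sanity check (a sketch): load_state_dict() raises on key mismatches
# with strict=True (the default) and returns any missing/unexpected keys:
#   result = state_dict_model.load_state_dict(state_dict, strict=True)
#   assert not result.missing_keys and not result.unexpected_keys
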
# 3. Load the JIT-Traced Model
print("\nLoading JIT-Traced Model...")
jit_model = torch.jit.load(jit_model_path).to(device).float()
jit_model.eval()
with torch.no_grad():
    # Directly pass both inputs through the model, as tracing only captures the forward pass
    logits_per_image, logits_per_text = jit_model(image_input, text_inputs)
    print("JIT Model Results:")
    print("Logits per Image:", logits_per_image)
    print("Logits per Text:", logits_per_text)

# Create a new CLIP model instance to hold the structure
jit_injected_model = clip.load(clipmodel, jit=True)[0].to(device).float()
jit_injected_model.eval()

# Inject the fine-tuned torch.jit model's forward function into the openai/clip JIT model;
# a traced RecursiveScriptModule does not have encode_image, encode_text, and so on.
jit_injected_model.forward = lambda image_input, text_inputs: jit_model(image_input, text_inputs)
# Note: image_features / text_features here are still the ones computed by the
# state-dict model above, since the traced model only exposes forward().
print_cosine_similarity(image_features, text_features, "Injected JIT Model")

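# Optional sanity check (a sketch, assuming the traced module still exposes its
# logit_scale parameter): recover the cosine similarity from the JIT logits, since
# logits_per_image = logit_scale.exp() * cosine_sim for a single image/text pair:
#   with torch.no_grad():
#       print("JIT cosine (from logits):", (logits_per_image / jit_model.logit_scale.exp()).item())
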
print("\nDone. Enjoy scratching your head about the diff in floating-point numerical precision!")