Commit 68cb155

Load & Test models w/ new saver flavors
A small example script for loading various flavors of checkpoints; gets cosine similarity for a cat.jpg image you'll need to provide.
1 parent 5237a20 commit 68cb155
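
With the script's defaults (epoch = 1 and save_is_gmp = False, so suffix = 'weight'), the loader expects the following files; the paths follow directly from the f-strings in the script, and the cat photo is anything you supply:

ft-checkpoints/clip_ft_1_full_as-weight.pt   # full pickled model object
ft-checkpoints/clip_ft_1_dict_as-weight.pt   # state_dict only
ft-checkpoints/clip_ft_1_jit_as-weight.pt    # JIT-traced module
cat.jpg                                      # any cat photo, next to the script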

File tree

1 file changed: +115 -0

test-models-new-saver.py

@@ -0,0 +1,115 @@
"""
Example script showing how to use the different types of saved checkpoints.
To use, make sure you set ft_checkpoints_folder AND have a photo of a cat.jpg available!
"""

save_is_gmp = False  # set True if the checkpoint was saved with 'save_as_gmp = True'; otherwise leave False

if save_is_gmp:
    import gmpclip as clip
    suffix = 'gmp'
else:
    import clip
    suffix = 'weight'

import torch
from PIL import Image
import os
import torch.nn.functional as F
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)  # disable torch's FutureWarning nag about pickle spam

device = "cuda:0" if torch.cuda.is_available() else "cpu"
clipmodel = 'ViT-L/14'

ft_checkpoints_folder = 'ft-checkpoints'
epoch = 1  # epoch of the saved file; e.g. if the filename is clip_ft_1_*.pt => enter 1

# Define paths
full_model_path = os.path.join(ft_checkpoints_folder, f'clip_ft_{epoch}_full_as-{suffix}.pt')
state_dict_path = os.path.join(ft_checkpoints_folder, f'clip_ft_{epoch}_dict_as-{suffix}.pt')
jit_model_path = os.path.join(ft_checkpoints_folder, f'clip_ft_{epoch}_jit_as-{suffix}.pt')

# Make sure you have a cat image available!
image_path = 'cat.jpg'

# Load and preprocess the image
def preprocess_image(image_path):
    image = Image.open(image_path).convert("RGB")
    preprocess = clip.load(clipmodel)[1]  # reloads CLIP just to grab its preprocessing transform
    return preprocess(image).unsqueeze(0)

image_input = preprocess_image(image_path).to(device)

# Define a text prompt
text_inputs = clip.tokenize(["a photo of a cat"]).to(device)

# Function to calculate and print the cosine similarity between image and text features
def print_cosine_similarity(image_features, text_features, model_name):
    cosine_sim = F.cosine_similarity(image_features, text_features)
    print(f"{model_name} Cosine Similarity:", cosine_sim.item())

# 0. Load Original CLIP Model
original_clip = clip.load(clipmodel)[0].to(device).float()
original_clip.eval()
with torch.no_grad():
    image_features = original_clip.encode_image(image_input)
    text_features = original_clip.encode_text(text_inputs)
    logits_per_image, logits_per_text = original_clip(image_input, text_inputs)
print("Original CLIP Results:")
print("Logits per Image:", logits_per_image)
print("Logits per Text:", logits_per_text)
print_cosine_similarity(image_features, text_features, "Original CLIP")


# 1. Load the Full Model Object
print("\nLoading Full Model Object...")
full_model = torch.load(full_model_path, map_location=device).to(device).float()  # map_location keeps CUDA-saved checkpoints loadable on CPU
full_model.eval()
with torch.no_grad():
    image_features = full_model.encode_image(image_input)
    text_features = full_model.encode_text(text_inputs)
    logits_per_image, logits_per_text = full_model(image_input, text_inputs)
print("Full Model Object Results:")
print("Logits per Image:", logits_per_image)
print("Logits per Text:", logits_per_text)
print_cosine_similarity(image_features, text_features, "Full Model Object")


# 2. Load the Model from State Dictionary
print("\nLoading Model from State Dictionary...")
state_dict_model = clip.load(clipmodel)[0]  # create a model instance of the correct architecture to load the weights into
state_dict = torch.load(state_dict_path, map_location=device)
state_dict_model.load_state_dict(state_dict)
state_dict_model = state_dict_model.to(device).float()

state_dict_model.eval()
with torch.no_grad():
    image_features = state_dict_model.encode_image(image_input)
    text_features = state_dict_model.encode_text(text_inputs)
    logits_per_image, logits_per_text = state_dict_model(image_input, text_inputs)
print("State Dictionary Model Results:")
print("Logits per Image:", logits_per_image)
print("Logits per Text:", logits_per_text)
print_cosine_similarity(image_features, text_features, "State Dictionary Model")

# 3. Load the JIT-Traced Model
print("\nLoading JIT-Traced Model...")
jit_model = torch.jit.load(jit_model_path).to(device).float()
jit_model.eval()
with torch.no_grad():
    # Pass both inputs directly through the model: tracing only captures the forward pass
    logits_per_image, logits_per_text = jit_model(image_input, text_inputs)
print("JIT Model Results:")
print("Logits per Image:", logits_per_image)
print("Logits per Text:", logits_per_text)
# Create a new CLIP model instance to hold the structure
jit_injected_model = clip.load(clipmodel, jit=True)[0].to(device).float()
jit_injected_model.eval()

# Inject the fine-tuned torch.jit model's forward into the openai/clip JIT model;
# a traced RecursiveScriptModule does not have encode_image, encode_text, and so on.
jit_injected_model.forward = lambda image_input, text_inputs: jit_model(image_input, text_inputs)
# NOTE: image_features / text_features below are still the ones computed by the
# state-dict model above; the traced checkpoint exposes no encoders of its own.
print_cosine_similarity(image_features, text_features, "Injected JIT Model")
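
# A sketch of an optional extra check: CLIP's logits are logit_scale.exp() times the
# cosine similarity, so the fine-tune's own similarity can be recovered from the traced
# forward alone. This assumes the traced module exposes its logit_scale parameter as
# an attribute (traced modules normally keep parameters accessible by name).
with torch.no_grad():
    ft_logits_per_image, _ = jit_model(image_input, text_inputs)
    cos_from_logits = ft_logits_per_image / jit_model.logit_scale.exp()
print("JIT fine-tune cosine (recovered from logits):", cos_from_logits.item())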

print("\nDone. Enjoy scratching your head about the diff in floating-point numerical precision!")
