Commit 6ca3543

Add Geometric Parameterization (GmP) for CLIP ft
Extremely experimental Geometric Parameterization (GmP): .weight -> .theta, .r
1 parent d097ede commit 6ca3543

8 files changed: +1442, -0 lines changed

exp-ft-B-GmP-finetune-OpenAI-ViT-L-14.py (+418)
Large diffs are not rendered by default.
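
The fine-tuning script itself is not rendered above. As a rough illustration of the GmP idea named in the commit message (.weight -> .theta, .r), the sketch below shows a hypothetical GmP linear layer whose effective weight is r * theta / ||theta||, computed row-wise. This mirrors the inverse used in the conversion script further down, but the class name, initialization, and forward pass are my assumptions, not code from this commit (the commit's actual GmP modules live in gmpclip/model.py, which is not shown on this page).

import torch
import torch.nn as nn
import torch.nn.functional as F


class GeometricLinear(nn.Module):
    # Hypothetical GmP linear layer: effective weight = r * normalize(theta), per output row.
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        w = torch.empty(out_features, in_features)
        nn.init.kaiming_uniform_(w, a=5 ** 0.5)  # same default init as nn.Linear
        # Decompose the initial weight into a direction (theta) and a per-row magnitude (r).
        self.theta = nn.Parameter(F.normalize(w, p=2, dim=1))
        self.r = nn.Parameter(w.norm(p=2, dim=1, keepdim=True))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    def forward(self, x):
        # Reconstruct the weight on the fly from its geometric components.
        weight = self.r * F.normalize(self.theta, p=2, dim=1)
        return F.linear(x, weight, self.bias)


# Recovering a plain weight matrix from (theta, r), as the conversion script below does:
layer = GeometricLinear(768, 768)
weight = layer.r * F.normalize(layer.theta, p=2, dim=1)
print(weight.shape)  # torch.Size([768, 768])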
+63
@@ -0,0 +1,63 @@

import torch
import torch.nn.functional as F
from gmpclip.modeloriginal import CLIP  # Make sure this imports the original CLIP class definition


def convert_back_to_original(state_dict):
    new_state_dict = {}

    for key, value in state_dict.items():
        if key.endswith(".theta"):
            base_key = key.replace(".theta", "")
            r_key = base_key + ".r"
            # GmP inverse: effective weight = r * (theta / ||theta||), row-wise
            new_weight = state_dict[r_key] * F.normalize(value, p=2, dim=1)
            new_state_dict[base_key + ".weight"] = new_weight
        elif key.endswith(".r"):
            continue  # .r is consumed together with its .theta above
        else:
            new_state_dict[key] = value

    return new_state_dict


# Example usage
# Load the fine-tuned model object
modelft = torch.load("ft-checkpoints/clip_ft_20.pt")

# Extract model hyperparameters from the fine-tuned model
embed_dim = modelft.text_projection.shape[1]
image_resolution = modelft.visual.input_resolution
vision_layers = modelft.visual.transformer.layers
vision_width = modelft.visual.conv1.out_channels
vision_patch_size = modelft.visual.conv1.kernel_size[0]
context_length = modelft.context_length
vocab_size = modelft.vocab_size
transformer_width = modelft.transformer.width
transformer_heads = modelft.transformer.resblocks[0].attn.num_heads
transformer_layers = modelft.transformer.layers

# Convert the fine-tuned model to a state_dict
fine_tuned_state_dict = modelft.state_dict()

# Convert back to original weights
original_state_dict = convert_back_to_original(fine_tuned_state_dict)

# Rebuild the original model using the converted state_dict
original_model = CLIP(
    embed_dim=embed_dim,
    image_resolution=image_resolution,
    vision_layers=vision_layers,
    vision_width=vision_width,
    vision_patch_size=vision_patch_size,
    context_length=context_length,
    vocab_size=vocab_size,
    transformer_width=transformer_width,
    transformer_heads=transformer_heads,
    transformer_layers=transformer_layers
)

# Load the converted state_dict into the original model
original_model.load_state_dict(original_state_dict)

# Save the original model object
torch.save(original_model, "ft-checkpoints/full_model_converted_model.pth")

print("Model has been successfully converted back to the original format and saved.")
+75
@@ -0,0 +1,75 @@

import torch
import gmpclip as clip
import matplotlib.pyplot as plt
import seaborn as sns
import concurrent.futures
from itertools import chain
import gc
import numpy as np


def load_model_checkpoint(checkpoint_path):
    model = torch.load(checkpoint_path)
    return model


def extract_geometric_params(model):
    geometric_params = {'theta': [], 'r': []}
    for name, param in model.named_parameters():
        if name.endswith('.theta'):
            geometric_params['theta'].append(param.detach().cpu().numpy())
        elif name.endswith('.r'):
            geometric_params['r'].append(param.detach().cpu().numpy())
    return geometric_params


# Downsampling factor - without it, this takes >>100 GB RAM + an hour or so. Factor 10 => 5 minutes (Ryzen 9).
# Adjust the factor as needed.
def downsample(data, factor=10):
    return data[::factor]


def visualize_params(params, title, iteration):
    theta_flat = list(chain.from_iterable([item.flatten() for sublist in params['theta'] for item in sublist]))
    r_flat = list(chain.from_iterable([item.flatten() for sublist in params['r'] for item in sublist]))

    # Downsample data to reduce memory usage
    theta_flat = downsample(theta_flat)
    r_flat = downsample(r_flat)

    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    sns.histplot(theta_flat, bins=50, kde=True)
    plt.title('Distribution of Theta Components')

    plt.subplot(1, 2, 2)
    sns.histplot(r_flat, bins=50, kde=True)
    plt.title('Distribution of R Components')

    plt.suptitle(title)
    plt.savefig(f'geometric_params_visualization_{iteration}.png')
    plt.close()


def process_checkpoint(checkpoint, iteration):
    model = load_model_checkpoint(checkpoint)
    params = extract_geometric_params(model)
    visualize_params(params, f'Checkpoint: {checkpoint}', iteration)
    # Explicitly call garbage collection
    gc.collect()


if __name__ == '__main__':
    # List of checkpoints - fine-tuned model saves:
    checkpoints = ["ft-checkpoints/clip_ft_5.pt", "ft-checkpoints/clip_ft_10.pt",
                   "ft-checkpoints/clip_ft_15.pt", "ft-checkpoints/clip_ft_20.pt"]

    # Split the list of checkpoints into smaller batches if 4 at once consumes too much RAM (RAM, not VRAM!)
    batch_size = 4
    batches = [checkpoints[i:i + batch_size] for i in range(0, len(checkpoints), batch_size)]

    for batch in batches:
        # Parallel processing using concurrent.futures
        with concurrent.futures.ProcessPoolExecutor() as executor:
            futures = [executor.submit(process_checkpoint, checkpoint, i + 1) for i, checkpoint in enumerate(batch)]
            # Wait for all futures to complete
            concurrent.futures.wait(futures)
        # Explicitly call garbage collection
        gc.collect()

    print("All visualizations have been generated and saved.")

gmpclip/__init__.py (+1)
@@ -0,0 +1 @@
from .clip import *

gmpclip/bpe_simple_vocab_16e6.txt.gz (1.29 MB)
Binary file not shown.

gmpclip/clip.py (+247)
@@ -0,0 +1,247 @@

import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List
from pkg_resources import packaging

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")


__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}


def _download(url: str, root: str):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    # Commenting out the expected_sha256 as we're bypassing checksum verification
    # expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        # Bypassing the SHA256 checksum verification
        return download_target

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    # Bypassing the SHA256 checksum verification on download completion
    return download_target


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as opened_file:
        try:
            # loading JIT archive
            #model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
            model = torch.load(opened_file, map_location=device if jit else "cpu").eval()
            state_dict = None
        except RuntimeError:
            # loading saved state dict
            if jit:
                warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
                jit = False
            state_dict = torch.load(opened_file, map_location="cpu")

    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def _node_get(node: torch._C.Node, key: str):
        """Gets attributes of a node which is polymorphic over return type.

        From https://github.com/pytorch/pytorch/pull/82628
        """
        sel = node.kindOf(key)
        return getattr(node, sel)(key)

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if _node_get(inputs[i].node(), "value") == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    else:
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
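
For reference, a hedged usage sketch of this patched loader (the checkpoint path is an assumed example): load() above calls torch.load rather than torch.jit.load, so it is geared toward full model objects saved with torch.save, such as the GmP fine-tune checkpoints, and it relies on gmpclip.model.build_model, which is part of this commit but not rendered on this page.

import torch
import gmpclip as clip

# Assumed example path to a GmP fine-tune checkpoint (a full model object saved with torch.save).
model, preprocess = clip.load("ft-checkpoints/clip_ft_20.pt", device="cpu")

tokens = clip.tokenize(["a photo of a cat", "a photo of a dog"])
with torch.no_grad():
    text_features = model.encode_text(tokens)
print(text_features.shape)  # (2, embed_dim)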
