Commit 59d6193

Add new saver with torch.jit.trace

1 parent 3e3059d commit 59d6193

1 file changed: +88, -4 lines

exp-acts-ft-finetune-OpenAI-CLIP-ViT-L-14-GmP-manipulate-neurons.py

@@ -37,6 +37,90 @@
 text_logs_folder = 'ft-logs'
 os.makedirs(text_logs_folder, exist_ok=True)
 
+# Model Saving Options; the default is 'legacy behavior' (only save the full model, in GmP format)
+save_full = True    # Save full model object
+save_dict = False   # Save state_dict
+save_jit = False    # Save as JIT-traced model
+save_as_gmp = True  # True: keep GmP format with .theta, .r; False: convert back to .weight (original OpenAI/CLIP)
+
+
+def convert_back_to_original(state_dict):
+    new_state_dict = {}
+    for key, value in state_dict.items():
+        if key.endswith(".theta"):
+            # Recover the original weight: W = r * theta / ||theta|| (row-wise L2 norm)
+            base_key = key.replace(".theta", "")
+            r_key = base_key + ".r"
+            new_weight = state_dict[r_key] * F.normalize(value, p=2, dim=1)
+            new_state_dict[base_key + ".weight"] = new_weight
+        elif key.endswith(".r"):
+            continue  # .r is already folded into the reconstructed weight
+        else:
+            new_state_dict[key] = value
+    return new_state_dict
+
+
+class GmPconverter:
+    @staticmethod
+    def convert_model(modelft):
+        # Extract architecture parameters from the fine-tuned model
+        config = {
+            'embed_dim': modelft.text_projection.shape[1],
+            'image_resolution': modelft.visual.input_resolution,
+            'vision_layers': modelft.visual.transformer.layers,
+            'vision_width': modelft.visual.conv1.out_channels,
+            'vision_patch_size': modelft.visual.conv1.kernel_size[0],
+            'context_length': modelft.context_length,
+            'vocab_size': modelft.vocab_size,
+            'transformer_width': modelft.transformer.width,
+            'transformer_heads': modelft.transformer.resblocks[0].attn.num_heads,
+            'transformer_layers': modelft.transformer.layers
+        }
+
+        # Convert the state_dict to the original CLIP format
+        fine_tuned_state_dict = modelft.state_dict()
+        original_state_dict = convert_back_to_original(fine_tuned_state_dict)
+
+        # Instantiate the original model and load the converted weights
+        from clip.model import CLIP
+        original_model = CLIP(**config)
+        original_model.load_state_dict(original_state_dict)
+
+        return original_model
+
+
+def ModelSaver(model, epoch, save_as_gmp=False):
+    model_to_save = model
+    if not save_as_gmp:
+        model_to_save = GmPconverter.convert_model(model)
+
+    model_to_save.to(device)
+    # File suffix based on save format
+    suffix = 'as-gmp' if save_as_gmp else 'as-weight'
+
+    # Save the full model object if enabled
+    if save_full:
+        torch.save(model_to_save, f'{ft_checkpoints_folder}/clip_ft_{epoch+1}_full_{suffix}.pt')
+
+    # Save the state_dict if enabled
+    if save_dict:
+        torch.save(model_to_save.state_dict(), f'{ft_checkpoints_folder}/clip_ft_{epoch+1}_dict_{suffix}.pt')
+
+    # Save as a JIT-traced model if enabled
+    if save_jit:
+        # Trace with a small real batch from the validation loader
+        sample_data = next(iter(val_dataloader))
+        images, texts = sample_data  # Unpack directly if sample_data is a tuple (images, texts)
+        images, texts = images[:2], texts[:2]
+        images, texts = images.to(device), texts.to(device)
+
+        model_to_save.eval()  # Set to evaluation mode for tracing
+        script_model = torch.jit.trace(model_to_save, (images, texts))
+        script_model.save(f'{ft_checkpoints_folder}/clip_ft_{epoch+1}_jit_{suffix}.pt')
+
+    del model_to_save
+
 def adjust_unfreeze_rate(epoch, adjust_after=12, increase_rate=2):
     if epoch < adjust_after:
         return 1  # Initial slower unfreeze rate
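
Aside: the reconstruction in convert_back_to_original() inverts the GmP (Geometric Parametrization) split, where each weight matrix is stored as a per-row magnitude .r and a direction .theta, so the original weight is W = r * theta / ||theta|| with a row-wise L2 norm. A minimal round-trip sketch; the (out, 1) shape for r is an assumption about how .r is stored:

    import torch
    import torch.nn.functional as F

    W = torch.randn(4, 8)                 # stand-in for a Linear layer's .weight

    # Assumed GmP-style decomposition: per-row magnitude plus direction carrier
    r = W.norm(p=2, dim=1, keepdim=True)  # like .r, shape (4, 1)
    theta = W.clone()                     # like .theta; only its direction matters

    # Reconstruction, exactly as in convert_back_to_original()
    W_rec = r * F.normalize(theta, p=2, dim=1)
    print(torch.allclose(W, W_rec, atol=1e-6))  # True: the round trip is lossless
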
@@ -409,12 +493,12 @@ def trainloop():
             f.write("============================================================\n")
 
         if (epoch + 1) % 2 == 0 or epoch == EPOCHS - 1:
-            model_path = f"{ft_checkpoints_folder}/clip_ft_{epoch+1}.pt"
             remove_hooks(hooks)  # Remove hooks
-            torch.save(model, model_path)
-            print(Fore.GREEN + f"Model saved: {model_path}" + Style.RESET_ALL)
+            print(Fore.CYAN + "Saving checkpoints..." + Style.RESET_ALL)
+            ModelSaver(model, epoch, save_as_gmp=save_as_gmp)  # NEW SAVER
+            print(Fore.GREEN + f"Model saved to {ft_checkpoints_folder}" + Style.RESET_ALL)
             hooks = register_hooks(model, modified_neurons_layers, scale_factors)  # Re-attach hooks
 
     remove_hooks(hooks)  # After training
 
-trainloop()
+trainloop()
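
Each checkpoint format loads differently downstream. A hedged sketch (the ft-checkpoints folder name and epoch number are illustrative; 'as-weight' checkpoints match the stock OpenAI/CLIP architecture, while 'as-gmp' checkpoints need the GmP model class importable):

    import torch
    import clip

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Full pickled module (save_full=True); the defining classes must be importable.
    # On recent PyTorch versions, pass weights_only=False to unpickle full modules.
    model = torch.load("ft-checkpoints/clip_ft_10_full_as-weight.pt", map_location=device)

    # state_dict (save_dict=True): build a matching architecture first, then load.
    base, preprocess = clip.load("ViT-L/14", device=device)
    base.load_state_dict(torch.load("ft-checkpoints/clip_ft_10_dict_as-weight.pt", map_location=device))

    # JIT-traced module (save_jit=True): self-contained, no class definitions needed.
    jit_model = torch.jit.load("ft-checkpoints/clip_ft_10_jit_as-weight.pt", map_location=device)
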
