text_logs_folder = 'ft-logs'
os.makedirs(text_logs_folder, exist_ok=True)

+ # Model Saving Options; the default is 'legacy behavior' (only save full model, save as GmP)
+ save_full = True    # Save full model object
+ save_dict = False   # Save state_dict
+ save_jit = False    # Save as JIT-traced model
+ save_as_gmp = True  # True for saving in GmP format with .theta, .r; False for converting back to .weight (original OpenAI/CLIP)
+
+
+ def convert_back_to_original(state_dict):
+     """Convert a GmP state_dict (.theta / .r keys) back to original CLIP .weight keys."""
+     new_state_dict = {}
+     for key, value in state_dict.items():
+         if key.endswith(".theta"):
+             base_key = key.replace(".theta", "")
+             r_key = base_key + ".r"
+             # Reconstruct the weight: radial part (.r) times the unit direction of .theta
+             new_weight = state_dict[r_key] * F.normalize(value, p=2, dim=1)
+             new_state_dict[base_key + ".weight"] = new_weight
+         elif key.endswith(".r"):
+             continue  # Skip the .r keys; they were merged into .weight above
+         else:
+             new_state_dict[key] = value
+     return new_state_dict
+
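For orientation: in the GmP (geometric parametrization) format that this function undoes, each weight matrix is stored as a radial part .r (per-row L2 norm) and an angular part .theta (row direction). Below is a minimal round-trip sketch, inferred from the reconstruction above rather than taken from the training code; the helper name and tensor shapes are assumptions.

import torch
import torch.nn.functional as F

def decompose_to_gmp(weight):
    # Hypothetical helper, not part of this commit: split an (out, in) weight
    # matrix into r (row norms, shape (out, 1)) and theta (any tensor with the
    # same row directions; the reconstruction re-normalizes theta anyway).
    r = weight.norm(p=2, dim=1, keepdim=True)
    theta = weight.clone()
    return r, theta

W = torch.randn(8, 4)
r, theta = decompose_to_gmp(W)
W_back = r * F.normalize(theta, p=2, dim=1)  # same expression as in convert_back_to_original
assert torch.allclose(W, W_back, atol=1e-6)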
+ class GmPconverter:
+     @staticmethod
+     def convert_model(modelft):
+         # Extract the architecture parameters from the fine-tuned model
+         config = {
+             'embed_dim': modelft.text_projection.shape[1],
+             'image_resolution': modelft.visual.input_resolution,
+             'vision_layers': modelft.visual.transformer.layers,
+             'vision_width': modelft.visual.conv1.out_channels,
+             'vision_patch_size': modelft.visual.conv1.kernel_size[0],
+             'context_length': modelft.context_length,
+             'vocab_size': modelft.vocab_size,
+             'transformer_width': modelft.transformer.width,
+             'transformer_heads': modelft.transformer.resblocks[0].attn.num_heads,
+             'transformer_layers': modelft.transformer.layers
+         }
+
+         # Convert the state_dict to the original CLIP format
+         fine_tuned_state_dict = modelft.state_dict()
+         original_state_dict = convert_back_to_original(fine_tuned_state_dict)
+
+         # Instantiate an original-architecture CLIP model and load the converted weights
+         from clip.model import CLIP
+         original_model = CLIP(**config)
+         original_model.load_state_dict(original_state_dict)
+
+         return original_model
+
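The point of converting back is that the returned object is a stock clip.model.CLIP, so downstream code expecting original OpenAI-style .weight keys keeps working. A hedged usage sketch (fine_tuned_gmp_model is a placeholder, and this assumes the fine-tuned architecture matches the base model being loaded):

import clip

# Hypothetical: push converted weights into a freshly loaded base model
base_model, preprocess = clip.load("ViT-L/14", device='cpu', jit=False)
converted = GmPconverter.convert_model(fine_tuned_gmp_model)
base_model.load_state_dict(converted.state_dict())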
+ def ModelSaver(model, epoch, save_as_gmp=False):
+     model_to_save = model
+     if not save_as_gmp:
+         # Convert .theta / .r back to plain .weight (original OpenAI/CLIP format)
+         model_to_save = GmPconverter.convert_model(model)
+
+     model_to_save.to(device)
+     # File suffix based on save format
+     suffix = 'as-gmp' if save_as_gmp else 'as-weight'
+
+     # Save full model object if enabled
+     if save_full:
+         torch.save(model_to_save, f'{ft_checkpoints_folder}/clip_ft_{epoch + 1}_full_{suffix}.pt')
+
+     # Save state_dict if enabled
+     if save_dict:
+         torch.save(model_to_save.state_dict(), f'{ft_checkpoints_folder}/clip_ft_{epoch + 1}_dict_{suffix}.pt')
+
+     # Save as JIT-traced model if enabled
+     if save_jit:
+         # Grab a small example batch from the validation loader for tracing
+         sample_data = next(iter(val_dataloader))
+         images, texts = sample_data  # Unpack directly if sample_data is a tuple (images, texts)
+         images, texts = images[:2], texts[:2]
+         images, texts = images.to(device), texts.to(device)
+
+         model_to_save.eval()  # Set to evaluation mode for tracing
+         script_model = torch.jit.trace(model_to_save, (images, texts))
+         script_model.save(f'{ft_checkpoints_folder}/clip_ft_{epoch + 1}_jit_{suffix}.pt')
+
+     del model_to_save
+
+
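Depending on which flags are enabled, up to three checkpoint flavors land in ft_checkpoints_folder. A hedged loading sketch; the folder name and epoch in the paths below are made-up examples, only the _full_ / _dict_ / _jit_ naming follows the saver above:

import torch

full_ckpt = 'ft-checkpoints/clip_ft_10_full_as-weight.pt'
dict_ckpt = 'ft-checkpoints/clip_ft_10_dict_as-weight.pt'
jit_ckpt  = 'ft-checkpoints/clip_ft_10_jit_as-weight.pt'

# Full model object: unpickling needs the defining classes importable
# (newer PyTorch may additionally require weights_only=False)
model = torch.load(full_ckpt, map_location='cpu')

# state_dict: build a matching model first, then load the weights into it
# (for an 'as-weight' checkpoint that is a stock CLIP, e.g. clip.model.CLIP(**config))
state_dict = torch.load(dict_ckpt, map_location='cpu')
# some_clip_model.load_state_dict(state_dict)

# JIT-traced model: self-contained, no Python class definitions required
jit_model = torch.jit.load(jit_ckpt, map_location='cpu')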
def adjust_unfreeze_rate(epoch, adjust_after=12, increase_rate=2):
    if epoch < adjust_after:
        return 1  # Initial slower unfreeze rate

@@ -409,12 +493,12 @@ def trainloop():
    f.write("============================================================\n")

    if (epoch + 1) % 2 == 0 or epoch == EPOCHS - 1:
-       model_path = f"{ft_checkpoints_folder}/clip_ft_{epoch + 1}.pt"
        remove_hooks(hooks)  # Remove hooks
-       torch.save(model, model_path)
-       print(Fore.GREEN + f"Model saved: {model_path}" + Style.RESET_ALL)
+       print(Fore.CYAN + "Saving checkpoints..." + Style.RESET_ALL)
+       ModelSaver(model, epoch, save_as_gmp=save_as_gmp)  # NEW SAVER
+       print(Fore.GREEN + f"Model saved to {ft_checkpoints_folder}" + Style.RESET_ALL)
        hooks = register_hooks(model, modified_neurons_layers, scale_factors)  # Re-attach hooks

    remove_hooks(hooks)  # After training

- trainloop()
+ trainloop()