@@ -158,51 +158,28 @@ def train(
     # Check if parameter passed or if set within environ
     use_wandb = wandb_check(wandb_project, wandb_watch, wandb_log_model)
 
-    if deepspeed_zero3:
-        deepspeed = deepspeed if deepspeed is not None else "./deepspeed_zero3_config.json"
-
     if saved_low_bit_model is not None:
         # Load the low bit optimized model if provide the saved path
-        if deepspeed_zero3:
-            import deepspeed as ds
-            with ds.zero.Init(config_dict_or_path=deepspeed):
-                model = AutoModelForCausalLM.load_low_bit(
-                    saved_low_bit_model,
-                    optimize_model=False,
-                    torch_dtype=torch.bfloat16,
-                    modules_to_not_convert=["lm_head"],
-                    trust_remote_code=True,
-                )
-        else:
-            model = AutoModelForCausalLM.load_low_bit(
-                saved_low_bit_model,
-                optimize_model=False,
-                torch_dtype=torch.bfloat16,
-                modules_to_not_convert=["lm_head"],
-                trust_remote_code=True,
-            )
+        model = AutoModelForCausalLM.load_low_bit(
+            saved_low_bit_model,
+            optimize_model=False,
+            torch_dtype=torch.bfloat16,
+            modules_to_not_convert=["lm_head"],
+            trust_remote_code=True,
+        )
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            base_model,
+            load_in_low_bit="bf16",
+            optimize_model=False,
+            torch_dtype=torch.bfloat16,
+            modules_to_not_convert=["lm_head"],
+            trust_remote_code=True,
+        )
+
+    if deepspeed_zero3:
+        deepspeed = deepspeed if deepspeed is not None else "./deepspeed_zero3_config.json"
     else:
-        if deepspeed_zero3:
-            import deepspeed as ds
-            with ds.zero.Init(config_dict_or_path=deepspeed):
-                model = AutoModelForCausalLM.from_pretrained(
-                    base_model,
-                    load_in_low_bit="bf16",
-                    optimize_model=False,
-                    torch_dtype=torch.bfloat16,
-                    modules_to_not_convert=["lm_head"],
-                    trust_remote_code=True,
-                )
-        else:
-            model = AutoModelForCausalLM.from_pretrained(
-                base_model,
-                load_in_low_bit="bf16",
-                optimize_model=False,
-                torch_dtype=torch.bfloat16,
-                modules_to_not_convert=["lm_head"],
-                trust_remote_code=True,
-            )
-    if not deepspeed_zero3:
         print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
         model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
         print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
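For the saved_low_bit_model branch, the load_low_bit / load_in_low_bit calls match the IPEX-LLM (formerly BigDL-LLM) transformers API. A hypothetical sketch of how such a low-bit checkpoint could be produced ahead of time follows; the import path, model id, and output directory are illustrative assumptions, not taken from this diff.

import torch
# Assumption: the ipex_llm package is installed (older releases used bigdl.llm.transformers).
from ipex_llm.transformers import AutoModelForCausalLM

# Convert the base model once, with the same low-bit settings used in the diff.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # hypothetical base model id
    load_in_low_bit="bf16",
    optimize_model=False,
    torch_dtype=torch.bfloat16,
    modules_to_not_convert=["lm_head"],
    trust_remote_code=True,
)

# Persist the converted weights; the directory can later be passed to train()
# as saved_low_bit_model so subsequent runs skip the conversion step.
model.save_low_bit("./llama-2-7b-low-bit")  # hypothetical output path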