
Commit 67a1e05

Remove zero3 context manager from LoRA (#11346)
1 parent f6cd628 commit 67a1e05
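For context: deepspeed.zero.Init is DeepSpeed's ZeRO stage-3 construction hook. Modules built inside the context manager have their parameters partitioned across the data-parallel ranks as they are created, so no single process materializes the full model. Before this commit, the example wrapped IPEX-LLM's model loading in that context whenever deepspeed_zero3 was set. Condensed from the deleted lines in the diff below (the script's surrounding imports, torch and IPEX-LLM's AutoModelForCausalLM, are assumed), the removed pattern was:

    # Pattern removed by this commit: build the model inside DeepSpeed's
    # ZeRO-3 init context so its parameters are partitioned across ranks
    # at creation time. `deepspeed` is the script's config-path argument.
    import deepspeed as ds

    with ds.zero.Init(config_dict_or_path=deepspeed):
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_low_bit="bf16",
            optimize_model=False,
            torch_dtype=torch.bfloat16,
            modules_to_not_convert=["lm_head"],
            trust_remote_code=True,
        )

After the change, the model is always loaded without the context manager; the deepspeed_zero3 flag only fills in a default config path and skips the explicit per-rank XPU move.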

1 file changed, 19 insertions(+), 42 deletions(-)

--- a/python/llm/example/GPU/LLM-Finetuning/LoRA/alpaca_lora_finetuning.py
+++ b/python/llm/example/GPU/LLM-Finetuning/LoRA/alpaca_lora_finetuning.py
@@ -158,51 +158,28 @@ def train(
     # Check if parameter passed or if set within environ
     use_wandb = wandb_check(wandb_project, wandb_watch, wandb_log_model)
 
-    if deepspeed_zero3:
-        deepspeed = deepspeed if deepspeed is not None else "./deepspeed_zero3_config.json"
-
     if saved_low_bit_model is not None:
         # Load the low bit optimized model if provide the saved path
-        if deepspeed_zero3:
-            import deepspeed as ds
-            with ds.zero.Init(config_dict_or_path=deepspeed):
-                model = AutoModelForCausalLM.load_low_bit(
-                    saved_low_bit_model,
-                    optimize_model=False,
-                    torch_dtype=torch.bfloat16,
-                    modules_to_not_convert=["lm_head"],
-                    trust_remote_code=True,
-                )
-        else:
-            model = AutoModelForCausalLM.load_low_bit(
-                saved_low_bit_model,
-                optimize_model=False,
-                torch_dtype=torch.bfloat16,
-                modules_to_not_convert=["lm_head"],
-                trust_remote_code=True,
-            )
+        model = AutoModelForCausalLM.load_low_bit(
+            saved_low_bit_model,
+            optimize_model=False,
+            torch_dtype=torch.bfloat16,
+            modules_to_not_convert=["lm_head"],
+            trust_remote_code=True,
+        )
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            base_model,
+            load_in_low_bit="bf16",
+            optimize_model=False,
+            torch_dtype=torch.bfloat16,
+            modules_to_not_convert=["lm_head"],
+            trust_remote_code=True,
+        )
+
+    if deepspeed_zero3:
+        deepspeed = deepspeed if deepspeed is not None else "./deepspeed_zero3_config.json"
     else:
-        if deepspeed_zero3:
-            import deepspeed as ds
-            with ds.zero.Init(config_dict_or_path=deepspeed):
-                model = AutoModelForCausalLM.from_pretrained(
-                    base_model,
-                    load_in_low_bit="bf16",
-                    optimize_model=False,
-                    torch_dtype=torch.bfloat16,
-                    modules_to_not_convert=["lm_head"],
-                    trust_remote_code=True,
-                )
-        else:
-            model = AutoModelForCausalLM.from_pretrained(
-                base_model,
-                load_in_low_bit="bf16",
-                optimize_model=False,
-                torch_dtype=torch.bfloat16,
-                modules_to_not_convert=["lm_head"],
-                trust_remote_code=True,
-            )
-    if not deepspeed_zero3:
         print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
         model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
         print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")

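Usage note: the saved_low_bit_model branch expects a checkpoint produced by IPEX-LLM's low-bit save path. A minimal sketch of that one-time preparation step, assuming the ipex_llm.transformers import used by these examples and IPEX-LLM's save_low_bit method; the model id and output path below are placeholders, not part of this commit:

    import torch
    from ipex_llm.transformers import AutoModelForCausalLM

    # One-time conversion: load the base model in bf16 low-bit format...
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # placeholder base model
        load_in_low_bit="bf16",
        optimize_model=False,
        torch_dtype=torch.bfloat16,
        modules_to_not_convert=["lm_head"],
        trust_remote_code=True,
    )

    # ...and save it. Later runs can pass this directory as
    # saved_low_bit_model to go straight through load_low_bit.
    model.save_low_bit("./llama-2-7b-bf16")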