
Commit 365adad

Support LoRA ChatGLM with Alpaca Dataset (#11580)
* Support LoRA ChatGLM with Alpaca Dataset
* refine
* fix
* add 2-card alpaca
1 parent 99c2274 commit 365adad

6 files changed: +162 −20 lines changed

python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/README.md

+22 −4
@@ -31,20 +31,30 @@ source /opt/intel/oneapi/setvars.sh
 
 ### 3. LoRA Fine-Tune on ChatGLM3-6B
 
-First, download the dataset: we use `AdvertiseGen` to finetune ChatGLM3-6B in the following, and please now get it from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), and unzip it in the current directory. Then, process the dataset with the below script:
+First, as for the dataset, you have two options:
+
+1. `AdvertiseGen`: get it from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), unzip it in the current directory, and then process the dataset with the script below:
 
 ```bash
 python process_advertise_gen_dataset.py
 ```
 
 Then, './AdvertiseGen' will be converted to './AdvertiseGen_fix'. Now, we have prepared the dataset, and are going to start LoRA fine-tuning on ChatGLM3-6B.
 
+2. `Alpaca`: we also support [yahma/alpaca-cleaned](https://huggingface.co/datasets/yahma/alpaca-cleaned), which contains generated instructions and demonstrations. It does not require preprocessing; directly run the fine-tuning scripts below.
+
 #### 3.1. Fine-Tune with a Single Arc Card
 
-Start the fine-tuning by:
+1. For `AdvertiseGen`, start the fine-tuning by:
+
+```bash
+bash lora_finetuning_chatglm3_6b_on_advertise_gen_with_1_arc_card.sh
+```
+
+2. For `Alpaca`, start the fine-tuning by:
 
 ```bash
-bash lora_finetuning_on_chatglm3_6b_with_1_arc_card.sh
+bash lora_finetuning_chatglm3_6b_on_alpaca_with_1_arc_card.sh
 ```
 
 Then, you will get output as below:

@@ -145,6 +155,14 @@ Training completed. Do not forget to share your model on huggingface.co/models =
 
 Start the data-parallel fine-tuning on 2 Intel Arc XPU cards by:
 
+1. `AdvertiseGen` dataset:
+
+```bash
+bash lora_finetuning_chatglm3_6b_on_advertise_gen_with_2_arc_cards.sh
+```
+
+2. `Alpaca` dataset:
+
 ```bash
-bash lora_finetuning_on_chatglm3_6b_with_2_arc_cards.sh
+bash lora_finetuning_chatglm3_6b_on_alpaca_with_2_arc_cards.sh
 ```
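Note: the `Alpaca` option needs no preprocessing because [yahma/alpaca-cleaned](https://huggingface.co/datasets/yahma/alpaca-cleaned) already ships as a Hugging Face dataset with `instruction`, `input` and `output` fields, the same columns the new `AlpacaDataManager` in the diff below drops after tokenization. A minimal sketch of what gets loaded, added for orientation only and not part of the commit:

```python
# Illustrative sketch only (not in the commit): peek at the Alpaca dataset
# that the new scripts hand to lora_finetune_chatglm.py.
from datasets import load_dataset

data = load_dataset("yahma/alpaca-cleaned")  # cleaned Alpaca instruction data
print(data["train"].column_names)            # includes 'instruction', 'input', 'output'
print(data["train"][0])                      # one raw instruction/response record
```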

python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetune_chatglm.py

+88 −16
@@ -65,8 +65,15 @@
 )
 from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
 
+from transformers import Trainer
 from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
 
+current_dir = os.path.dirname(os.path.realpath(__file__))
+common_util_path = os.path.join(current_dir, '..', '..')
+import sys
+sys.path.append(common_util_path)
+from common.utils import get_train_val_data, Prompter
+
 ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
 TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
 app = typer.Typer(pretty_exceptions_show_locals=False)
@@ -247,7 +254,7 @@ def _load_datasets(
     return dataset_dct
 
 
-class DataManager(object):
+class AdvertiseGenDataManager(object):
     def __init__(self, data_dir: str, data_config: DataConfig):
         self._num_proc = data_config.num_proc
 
@@ -283,6 +290,52 @@ def get_dataset(
             num_proc=self._num_proc,
         )
 
+class AlpacaDataConfig(object):
+    def __init__(self, tokenizer, prompter, train_on_inputs,
+                 add_eos_token, cutoff_len, val_set_size, seed):
+        self.tokenizer = tokenizer
+        self.prompter = prompter
+        self.train_on_inputs = train_on_inputs
+        self.add_eos_token = add_eos_token
+        self.cutoff_len = cutoff_len
+        self.val_set_size = val_set_size
+        self.seed = seed
+
+
+class AlpacaDataManager(object):
+    def __init__(self, data_dir: str, data_config: AlpacaDataConfig):
+        if data_dir.endswith(".json") or data_dir.endswith(".jsonl"):
+            data = load_dataset("json", data_files=data_dir)
+        else:
+            data = load_dataset(data_dir)
+        self.train_data, self.val_data = get_train_val_data(
+            data,
+            data_config.tokenizer,
+            data_config.prompter,
+            data_config.train_on_inputs,
+            data_config.add_eos_token,
+            data_config.cutoff_len,
+            data_config.val_set_size,
+            seed=data_config.seed)
+        self.train_data = self.train_data.remove_columns(
+            ['output', 'input', 'instruction', 'attention_mask', 'position_ids'])
+        self.val_data = self.val_data.remove_columns(
+            ['output', 'input', 'instruction', 'attention_mask', 'position_ids'])
+
+    def get_dataset(
+        self,
+        split: NamedSplit,
+        process_fn: Callable[[dict[str, Any]], dict[str, Any]],
+        batched: bool = True,
+        remove_orig_columns: bool = True,
+    ) -> Optional[Dataset]:
+        if split == Split.TRAIN:
+            return self.train_data
+        elif split == Split.VALIDATION:
+            return self.val_data
+        else:
+            return None
+
 
 def print_model_size(model: PreTrainedModel):
     print("--> Model")
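Note: for orientation, here is a minimal sketch (not part of the commit) of how the two new classes are wired together by hand. The parameter values mirror the ones `main()` passes in the next hunk; it assumes the example's dependencies (ipex-llm, peft, typer) are installed, that the classes are importable from `lora_finetune_chatglm.py`, and that the shared `Prompter` helper resolves via the `sys.path` setup added at the top of the file.

```python
# Illustrative sketch only: manual wiring of AlpacaDataConfig + AlpacaDataManager,
# assumed to run from the chatglm_finetune example directory.
from datasets import Split
from transformers import AutoTokenizer

from lora_finetune_chatglm import AlpacaDataConfig, AlpacaDataManager
from common.utils import Prompter  # resolvable once the module above has appended sys.path

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
if tokenizer.pad_token is None:  # same guard main() adds in the next hunk
    tokenizer.pad_token = tokenizer.eos_token

config = AlpacaDataConfig(tokenizer=tokenizer, prompter=Prompter("alpaca"),
                          train_on_inputs=True, add_eos_token=False,
                          cutoff_len=256, val_set_size=2000, seed=42)

# A Hugging Face dataset name or a local .json/.jsonl file both work for data_dir.
manager = AlpacaDataManager("yahma/alpaca-cleaned", config)
train_data = manager.get_dataset(Split.TRAIN, process_fn=None)
print(train_data.column_names)  # after remove_columns, the tokenized 'input_ids' and 'labels' remain
```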
@@ -484,7 +537,17 @@ def main(
 ):
     ft_config = FinetuningConfig.from_file(config_file)
     tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
-    data_manager = DataManager(data_dir, ft_config.data_config)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    if 'AdvertiseGen' in data_dir:
+        data_manager = AdvertiseGenDataManager(data_dir, ft_config.data_config)
+    elif 'alpaca' in data_dir:
+        data_config = AlpacaDataConfig(tokenizer=tokenizer, prompter=Prompter("alpaca"),
+                                       train_on_inputs=True, add_eos_token=False,
+                                       cutoff_len=256, val_set_size=2000, seed=42)
+        data_manager = AlpacaDataManager(data_dir, data_config)
+    else:
+        raise NotImplementedError("Wrong dataset: currently only AdvertiseGen and Alpaca are supported")
 
     train_dataset = data_manager.get_dataset(
         Split.TRAIN,
@@ -530,38 +593,47 @@ def main(
     # turn model to fp32
     _prepare_model_for_training(model, ft_config.training_args.use_cpu)
 
-    ft_config.training_args.generation_config.pad_token_id = (
-        tokenizer.pad_token_id
-    )
-    ft_config.training_args.generation_config.eos_token_id = [
-        tokenizer.eos_token_id,
-        tokenizer.get_command('<|user|>'),
-        tokenizer.get_command('<|observation|>'),
-    ]
+    if 'AdvertiseGen' in data_dir:
+        ft_config.training_args.generation_config.pad_token_id = (
+            tokenizer.pad_token_id
+        )
+        ft_config.training_args.generation_config.eos_token_id = [
+            tokenizer.eos_token_id,
+            tokenizer.get_command('<|user|>'),
+            tokenizer.get_command('<|observation|>'),
+        ]
     model.gradient_checkpointing_enable()
     model.enable_input_require_grads()
 
-    use_tokenizer = True
-    if ft_config.peft_config is not None:
-        use_tokenizer = False if ft_config.peft_config.peft_type == "LORA" else True
+    if 'AdvertiseGen' in data_dir:
+        use_tokenizer = True
+        if ft_config.peft_config is not None:
+            use_tokenizer = False if ft_config.peft_config.peft_type == "LORA" else True
+    else:
+        use_tokenizer = False
 
     # Add below L544-L546 to enable finetuning on 2 Intel Arc XPU cards on top of oneccl and deepspeed
     if deepspeed_config_file != '':
         ft_config.training_args.ddp_backend = "ccl"
         ft_config.training_args.deepspeed = deepspeed_config_file
 
-    trainer = Seq2SeqTrainer(
+    from transformers import Trainer, TrainingArguments, DataCollatorForSeq2Seq
+
+    BASE_TRAINER = Trainer if 'alpaca' in data_dir else Seq2SeqTrainer
+
+    trainer = BASE_TRAINER(
         model=model,
         args=ft_config.training_args,
         data_collator=DataCollatorForSeq2Seq(
             tokenizer=tokenizer,
-            padding='longest',
             return_tensors='pt',
+            padding=True if 'alpaca' in data_dir else 'longest',
+            pad_to_multiple_of=8 if 'alpaca' in data_dir else None,
         ),
         train_dataset=train_dataset,
         eval_dataset=val_dataset.select(list(range(50))),
         tokenizer=tokenizer if use_tokenizer else None,  # LORA does not need tokenizer
-        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
+        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer) if 'AdvertiseGen' in data_dir else None,
     )
 
     if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None:
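Note: the collator change above is easy to misread. For `Alpaca` the batch is padded dynamically to the longest sample and then rounded up to a multiple of 8 (shapes that are generally friendlier to accelerator kernels), while `AdvertiseGen` keeps the original `padding='longest'`. A small standalone sketch of the difference, illustrative only; `bert-base-uncased` is just a stand-in tokenizer for the demonstration:

```python
# Illustrative sketch only: how the two DataCollatorForSeq2Seq configurations pad a batch.
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # stand-in tokenizer

def make_features():
    # two toy samples of lengths 3 and 11 (token ids are arbitrary)
    return [{"input_ids": [101, 102, 103], "labels": [7, 8, 9]},
            {"input_ids": list(range(1, 12)), "labels": list(range(1, 12))}]

# Alpaca path: pad to the longest sample, then round up to a multiple of 8.
alpaca_collator = DataCollatorForSeq2Seq(tokenizer, padding=True,
                                         pad_to_multiple_of=8, return_tensors='pt')
print(alpaca_collator(make_features())["input_ids"].shape)   # torch.Size([2, 16])

# AdvertiseGen path: pad to the longest sample only.
advgen_collator = DataCollatorForSeq2Seq(tokenizer, padding='longest', return_tensors='pt')
print(advgen_collator(make_features())["input_ids"].shape)   # torch.Size([2, 11])
```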
python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_chatglm3_6b_on_alpaca_with_1_arc_card.sh

+23 −0
@@ -0,0 +1,23 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+export BIGDL_CHECK_DUPLICATE_IMPORT=0
+
+# You can also set the remote model repository to a local model path
+python lora_finetune_chatglm.py \
+    yahma/alpaca-cleaned \
+    THUDM/chatglm3-6b \
+    ./lora_config.yaml
python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_chatglm3_6b_on_alpaca_with_2_arc_cards.sh

+29 −0
@@ -0,0 +1,29 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+export MASTER_ADDR=127.0.0.1
+export OMP_NUM_THREADS=6
+export FI_PROVIDER=tcp
+export CCL_ATL_TRANSPORT=ofi
+export BIGDL_CHECK_DUPLICATE_IMPORT=0
+
+# You can also set the remote model repository to a local model path
+mpirun -n 2 \
+    python lora_finetune_chatglm.py \
+    yahma/alpaca-cleaned \
+    THUDM/chatglm3-6b \
+    ./lora_config.yaml \
+    ./deepspeed_config.json
