65 | 65 | )
66 | 66 | from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
67 | 67 |
| 68 | +from transformers import Trainer
68 | 69 | from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
69 | 70 |
| 71 | +current_dir = os.path.dirname(os.path.realpath(__file__))
| 72 | +common_util_path = os.path.join(current_dir, '..', '..')
| 73 | +import sys
| 74 | +sys.path.append(common_util_path)
| 75 | +from common.utils import get_train_val_data, Prompter
| 76 | +
70 | 77 | ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
71 | 78 | TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
72 | 79 | app = typer.Typer(pretty_exceptions_show_locals=False)
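
The block above pulls in two Alpaca-style helpers from the shared common/utils.py two directories up (hence the sys.path tweak). Their real implementations live in that file; the sketch below is only an assumption of what they provide, in the spirit of the usual alpaca-lora recipe, so the call sites later in this diff are easier to read. The later remove_columns calls suggest the real helper also keeps the raw instruction/input/output columns and a position_ids field, which this sketch omits.

    # Hypothetical sketch of the helpers imported from common.utils.
    # Only the call signature is taken from this diff; the body is an
    # assumption (alpaca-lora style prompt rendering, tokenization, split).
    def get_train_val_data(data, tokenizer, prompter, train_on_inputs,
                           add_eos_token, cutoff_len, val_set_size, seed=42):
        """Tokenize an Alpaca-format dataset and split it into train/val."""

        def tokenize_sample(sample):
            # prompter.generate_prompt() renders instruction/input/output into
            # a single prompt string using the chosen template (here "alpaca").
            full_prompt = prompter.generate_prompt(
                sample["instruction"], sample.get("input"), sample["output"])
            tokens = tokenizer(full_prompt, truncation=True,
                               max_length=cutoff_len, padding=False)
            if add_eos_token and tokens["input_ids"][-1] != tokenizer.eos_token_id:
                tokens["input_ids"].append(tokenizer.eos_token_id)
                tokens["attention_mask"].append(1)
            tokens["labels"] = list(tokens["input_ids"])
            if not train_on_inputs:
                # Mask the prompt part so the loss only covers the response.
                prompt_only = prompter.generate_prompt(
                    sample["instruction"], sample.get("input"))
                prompt_len = len(tokenizer(prompt_only, truncation=True,
                                           max_length=cutoff_len)["input_ids"])
                tokens["labels"][:prompt_len] = [-100] * prompt_len
            return tokens

        split = data["train"].train_test_split(test_size=val_set_size, seed=seed)
        train_data = split["train"].shuffle(seed=seed).map(tokenize_sample)
        val_data = split["test"].map(tokenize_sample)
        return train_data, val_data
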
@@ -247,7 +254,7 @@ def _load_datasets(
247 | 254 |     return dataset_dct
248 | 255 |
249 | 256 |
250 | | -class DataManager(object):
| 257 | +class AdvertiseGenDataManager(object):
251 | 258 |     def __init__(self, data_dir: str, data_config: DataConfig):
252 | 259 |         self._num_proc = data_config.num_proc
253 | 260 |
@@ -283,6 +290,52 @@ def get_dataset(
283 | 290 |             num_proc=self._num_proc,
284 | 291 |         )
285 | 292 |
| 293 | +class AlpacaDataConfig(object):
| 294 | +    def __init__(self, tokenizer, prompter, train_on_inputs,
| 295 | +                 add_eos_token, cutoff_len, val_set_size, seed):
| 296 | +        self.tokenizer = tokenizer
| 297 | +        self.prompter = prompter
| 298 | +        self.train_on_inputs = train_on_inputs
| 299 | +        self.add_eos_token = add_eos_token
| 300 | +        self.cutoff_len = cutoff_len
| 301 | +        self.val_set_size = val_set_size
| 302 | +        self.seed = seed
| 303 | +
| 304 | +
| 305 | +class AlpacaDataManager(object):
| 306 | +    def __init__(self, data_dir: str, data_config: AlpacaDataConfig):
| 307 | +        if data_dir.endswith(".json") or data_dir.endswith(".jsonl"):
| 308 | +            data = load_dataset("json", data_files=data_dir)
| 309 | +        else:
| 310 | +            data = load_dataset(data_dir)
| 311 | +        self.train_data, self.val_data = get_train_val_data(
| 312 | +            data,
| 313 | +            data_config.tokenizer,
| 314 | +            data_config.prompter,
| 315 | +            data_config.train_on_inputs,
| 316 | +            data_config.add_eos_token,
| 317 | +            data_config.cutoff_len,
| 318 | +            data_config.val_set_size,
| 319 | +            seed=data_config.seed)
| 320 | +        self.train_data = self.train_data.remove_columns(
| 321 | +            ['output', 'input', 'instruction', 'attention_mask', 'position_ids'])
| 322 | +        self.val_data = self.val_data.remove_columns(
| 323 | +            ['output', 'input', 'instruction', 'attention_mask', 'position_ids'])
| 324 | +
| 325 | +    def get_dataset(
| 326 | +        self,
| 327 | +        split: NamedSplit,
| 328 | +        process_fn: Callable[[dict[str, Any]], dict[str, Any]],
| 329 | +        batched: bool = True,
| 330 | +        remove_orig_columns: bool = True,
| 331 | +    ) -> Optional[Dataset]:
| 332 | +        if split == Split.TRAIN:
| 333 | +            return self.train_data
| 334 | +        elif split == Split.VALIDATION:
| 335 | +            return self.val_data
| 336 | +        else:
| 337 | +            return None
| 338 | +
286 | 339 |
287 | 340 | def print_model_size(model: PreTrainedModel):
288 | 341 |     print("--> Model")
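
AlpacaDataManager deliberately keeps the same get_dataset() signature as the AdvertiseGen manager so the rest of main() is untouched, but it ignores process_fn, batched and remove_orig_columns because get_train_val_data already returns tokenized splits. A minimal usage sketch follows; the JSON path and model id are placeholders, and it assumes the classes above plus the common.utils imports are available.

    from datasets import Split
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    config = AlpacaDataConfig(tokenizer=tokenizer, prompter=Prompter("alpaca"),
                              train_on_inputs=True, add_eos_token=False,
                              cutoff_len=256, val_set_size=2000, seed=42)
    manager = AlpacaDataManager("./alpaca_data_cleaned.json", config)  # placeholder path

    train_ds = manager.get_dataset(Split.TRAIN, process_fn=None)       # already tokenized
    val_ds = manager.get_dataset(Split.VALIDATION, process_fn=None)
    print(train_ds.column_names)  # expected to be just ['input_ids', 'labels']
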
@@ -484,7 +537,17 @@ def main(
484 | 537 | ):
485 | 538 |     ft_config = FinetuningConfig.from_file(config_file)
486 | 539 |     tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
487 | | -    data_manager = DataManager(data_dir, ft_config.data_config)
| 540 | +    if tokenizer.pad_token is None:
| 541 | +        tokenizer.pad_token = tokenizer.eos_token
| 542 | +    if 'AdvertiseGen' in data_dir:
| 543 | +        data_manager = AdvertiseGenDataManager(data_dir, ft_config.data_config)
| 544 | +    elif 'alpaca' in data_dir:
| 545 | +        data_config = AlpacaDataConfig(tokenizer=tokenizer, prompter=Prompter("alpaca"),
| 546 | +                                       train_on_inputs=True, add_eos_token=False,
| 547 | +                                       cutoff_len=256, val_set_size=2000, seed=42)
| 548 | +        data_manager = AlpacaDataManager(data_dir, data_config)
| 549 | +    else:
| 550 | +        raise NotImplementedError("Unsupported dataset: only AdvertiseGen and Alpaca are currently supported")
488 | 551 |
489 | 552 |     train_dataset = data_manager.get_dataset(
490 | 553 |         Split.TRAIN,
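
Two things in this hunk are easy to miss. First, the pad_token fallback matters because padding a batch with a tokenizer that has no pad token raises an error in transformers, and reusing eos is the usual workaround. Second, routing is by substring, so the data path must literally contain 'AdvertiseGen' or 'alpaca'. A quick illustration of the padding point; gpt2 is used only as a stand-in for a tokenizer that ships without a pad token.

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    print(tok.pad_token)               # None: gpt2 has no pad token out of the box
    try:
        tok(["a", "bb"], padding=True)
    except ValueError as err:
        print(err)                     # "Asking to pad but the tokenizer does not have a padding token..."

    tok.pad_token = tok.eos_token      # same fallback as in the hunk above
    print(tok(["a", "bb"], padding=True)["input_ids"])
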
@@ -530,38 +593,47 @@ def main(
530 | 593 |     # turn model to fp32
531 | 594 |     _prepare_model_for_training(model, ft_config.training_args.use_cpu)
532 | 595 |
533 | | -    ft_config.training_args.generation_config.pad_token_id = (
534 | | -        tokenizer.pad_token_id
535 | | -    )
536 | | -    ft_config.training_args.generation_config.eos_token_id = [
537 | | -        tokenizer.eos_token_id,
538 | | -        tokenizer.get_command('<|user|>'),
539 | | -        tokenizer.get_command('<|observation|>'),
540 | | -    ]
| 596 | +    if 'AdvertiseGen' in data_dir:
| 597 | +        ft_config.training_args.generation_config.pad_token_id = (
| 598 | +            tokenizer.pad_token_id
| 599 | +        )
| 600 | +        ft_config.training_args.generation_config.eos_token_id = [
| 601 | +            tokenizer.eos_token_id,
| 602 | +            tokenizer.get_command('<|user|>'),
| 603 | +            tokenizer.get_command('<|observation|>'),
| 604 | +        ]
541 | 605 |     model.gradient_checkpointing_enable()
542 | 606 |     model.enable_input_require_grads()
543 | 607 |
544 | | -    use_tokenizer = True
545 | | -    if ft_config.peft_config is not None:
546 | | -        use_tokenizer = False if ft_config.peft_config.peft_type == "LORA" else True
| 608 | +    if 'AdvertiseGen' in data_dir:
| 609 | +        use_tokenizer = True
| 610 | +        if ft_config.peft_config is not None:
| 611 | +            use_tokenizer = False if ft_config.peft_config.peft_type == "LORA" else True
| 612 | +    else:
| 613 | +        use_tokenizer = False
547 | 614 |
548 | 615 |     # Add below L544-L546 to enable finetuning on 2 Intel Arc XPU cards on top of oneccl and deepspeed
549 | 616 |     if deepspeed_config_file != '':
550 | 617 |         ft_config.training_args.ddp_backend = "ccl"
551 | 618 |         ft_config.training_args.deepspeed = deepspeed_config_file
552 | 619 |
553 | | -    trainer = Seq2SeqTrainer(
| 620 | +    from transformers import Trainer, TrainingArguments, DataCollatorForSeq2Seq
| 621 | +
| 622 | +    BASE_TRAINER = Trainer if 'alpaca' in data_dir else Seq2SeqTrainer
| 623 | +
| 624 | +    trainer = BASE_TRAINER(
554 | 625 |         model=model,
555 | 626 |         args=ft_config.training_args,
556 | 627 |         data_collator=DataCollatorForSeq2Seq(
557 | 628 |             tokenizer=tokenizer,
558 | | -            padding='longest',
559 | 629 |             return_tensors='pt',
| 630 | +            padding=True if 'alpaca' in data_dir else 'longest',
| 631 | +            pad_to_multiple_of=8 if 'alpaca' in data_dir else None,
560 | 632 |         ),
561 | 633 |         train_dataset=train_dataset,
562 | 634 |         eval_dataset=val_dataset.select(list(range(50))),
563 | 635 |         tokenizer=tokenizer if use_tokenizer else None,  # LORA does not need tokenizer
564 | | -        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
| 636 | +        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer) if 'AdvertiseGen' in data_dir else None,
565 | 637 |     )
566 | 638 |
567 | 639 |     if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None:
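
The last hunk switches the alpaca branch to the plain Trainer and disables compute_metrics, keeping Seq2SeqTrainer and the generation_config tweaks only for AdvertiseGen. Note also that the function-local `from transformers import ... DataCollatorForSeq2Seq` shadows whatever collator the script defines at module level (the `as _DataCollatorForSeq2Seq` alias at the top suggests there is one), so the trainer here is built with the stock transformers collator. For the alpaca branch that collator is configured with padding=True and pad_to_multiple_of=8, i.e. dynamic per-batch padding rounded up to a multiple of 8, which tends to suit mixed-precision kernels. A small, self-contained illustration of those settings (gpt2 again only as a stand-in tokenizer):

    from transformers import AutoTokenizer, DataCollatorForSeq2Seq

    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token

    features = [
        {"input_ids": [1, 2, 3], "labels": [1, 2, 3]},
        {"input_ids": [4, 5, 6, 7, 8], "labels": [4, 5, 6, 7, 8]},
    ]
    collator = DataCollatorForSeq2Seq(tokenizer=tok, padding=True,
                                      pad_to_multiple_of=8, return_tensors="pt")
    batch = collator(features)
    print(batch["input_ids"].shape)  # torch.Size([2, 8]): longest sequence padded up to a multiple of 8
    print(batch["labels"][0])        # label padding uses -100, so it is ignored by the loss
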