Is it possible to fine-tune the model using a fine-tuning method such as LORA? #496
FHT-hub changed the title from "Is it possible to fine-tune the model using a fine-tuning method such as lora?" to "Is it possible to fine-tune the model using a fine-tuning method such as LORA?" on Aug 2, 2024.
Thanks for your answer. I see that the code you provided doesn't cover speech-to-speech translation, but I assume the method should work for speech-to-speech translation tasks as well?
------------------ Original message ------------------
From: "facebookresearch/seamless_communication" ***@***.***>;
Sent: Thursday, August 22, 2024, 2:04 AM
***@***.***>;
***@***.******@***.***>;
Subject: Re: [facebookresearch/seamless_communication] Is it possible to fine-tune the model using a fine-tuning method such as LORA? (Issue #496)
This can be done with the Hugging Face API. Here is a simple training script. I added DeepSpeed so I can offload the optimizer to the CPU and increase the training batch size.
```python
import argparse
import json
import os
from pathlib import Path

import audiofile
import pandas as pd
import torch
import wandb
from accelerate import Accelerator
from torch.utils.data import Dataset
from transformers import (
    SeamlessM4TForSpeechToText,
    SeamlessM4TForTextToText,
    SeamlessM4TProcessor,
    SeamlessM4TTokenizer,
    SeamlessM4Tv2ForSpeechToText,
    SeamlessM4Tv2ForTextToText,
    Trainer,
    TrainingArguments,
)

WANDB_PRJ_NAME = "master_thesis"
os.environ["WANDB_PROJECT"] = WANDB_PRJ_NAME  # name your W&B project
os.environ["WANDB_MODE"] = "offline"
accelerator = Accelerator(log_with="wandb")


class PandasDataset(Dataset):
    """Wraps a DataFrame; each item is one row (a pandas Series)."""

    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return self.dataframe.shape[0]

    def __getitem__(self, idx):
        return self.dataframe.iloc[idx, :]


def get_config() -> dict:
    parser = argparse.ArgumentParser(description="Example finetuning script for M4T models")
    parser.add_argument("--config", type=Path, required=True, help="config path")
    with open(parser.parse_args().config, "r") as config_file:
        return json.load(config_file)


def load_hf_model(model_name, mode):
    # "s2tt" selects the speech-to-text head; any other mode selects text-to-text.
    speech = mode.lower() == "s2tt"
    if model_name == "seamlessM4T_v2_large":
        seamless_cls = SeamlessM4Tv2ForSpeechToText if speech else SeamlessM4Tv2ForTextToText
    else:
        seamless_cls = SeamlessM4TForSpeechToText if speech else SeamlessM4TForTextToText

    checkpoints = {
        "seamlessM4T_medium": "facebook/hf-seamless-m4t-medium",
        "seamlessM4T_large": "facebook/hf-seamless-m4t-large",
        "seamlessM4T_v2_large": "facebook/seamless-m4t-v2-large",
    }
    if model_name not in checkpoints:
        raise ValueError(f"Invalid model name ({model_name}).")
    checkpoint = checkpoints[model_name]
    processor = SeamlessM4TProcessor.from_pretrained(checkpoint)
    tokenizer = SeamlessM4TTokenizer.from_pretrained(checkpoint)
    model = seamless_cls.from_pretrained(checkpoint)
    return model, tokenizer, processor


def main():
    config = get_config()
    torch.manual_seed(config["seed"])
    model, tokenizer, processor = load_hf_model(config["model_name"], config["mode"])
    model.train()

    train_dataset = PandasDataset(pd.read_csv(config["train_dataset"]))
    # Only the first 100 dev rows are used for periodic evaluation.
    eval_dataset = PandasDataset(pd.read_csv(config["eval_dataset"]).iloc[:100, :])
    # test_dataset = PandasDataset(pd.read_csv(config["test_dataset"]))
    # log_stuff(accelerator, config)

    training_args = TrainingArguments(
        output_dir=config["save_model_dir"],
        eval_strategy="steps",
        eval_steps=config["eval_steps"],
        per_device_train_batch_size=config["train_batch_size"],
        per_device_eval_batch_size=config["eval_batch_size"],
        data_seed=config["seed"],
        eval_on_start=True,
        adam_epsilon=1e-08,
        adam_beta1=0.9,
        adam_beta2=0.98,
        learning_rate=config["learning_rate"],
        warmup_steps=config["warmup_steps"],
        max_steps=config["max_steps"],  # a positive max_steps overrides num_train_epochs
        num_train_epochs=config["max_epochs"],
        bf16=config["bf16"],
        fp16=config["fp16"],
        report_to="none",
    )
    data_collator = SeamlessDataCollator(tokenizer, processor, config, accelerator)
    trainer = accelerator.prepare(
        Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
        )
    )
    trainer.train()


@accelerator.on_main_process
def log_stuff(accelerator, config):
    accelerator.init_trackers(project_name=WANDB_PRJ_NAME, config=config)
    with open(config["dataset_info"], "r") as info_file:
        metadata = json.load(info_file)
    wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)
    wandb_tracker.log_artifact(
        wandb.Artifact(
            "ds2t_only",
            type="dataset",
            description="Speech to text translation dataset",
            metadata=metadata,
        )
    )


class SeamlessDataCollator:
    def __init__(self, tokenizer, processor, config, accelerator):
        self.tokenizer = tokenizer
        self.processor = processor
        self.src_lang = config["src_lang"]
        self.tgt_lang = config["tgt_lang"]
        self.sample_rate = config["sample_rate"]
        if config["mode"].lower() == "s2tt":
            self.collator_fn = self._prepare_batch_s2t
        else:
            self.collator_fn = self._prepare_batch_t2t
        self.accelerator = accelerator

    def __call__(self, batch):
        # Turn a list of rows into a dict mapping column name -> list of values.
        batch_dict = {}
        for elm in batch:
            for k, v in elm.items():
                batch_dict.setdefault(k, []).append(v)
        return self.collator_fn(batch_dict)

    def _mask_padding(self, labels):
        # Padded label positions must be -100 so the cross-entropy loss ignores them.
        labels = labels.clone()
        labels[labels == self.tokenizer.pad_token_id] = -100
        return labels

    def _prepare_batch_s2t(self, batch):
        audio_array_list = []
        for audio_path in batch["src_audio_path"]:
            audio_array, sr = audiofile.read(audio_path)
            assert sr == self.sample_rate  # audio must already be at config["sample_rate"]
            audio_array_list.append(audio_array)
        audio_inputs = self.processor(
            audios=audio_array_list, sampling_rate=self.sample_rate, return_tensors="pt"
        )
        # Dummy source text: only the tokenized targets (labels) are used in S2TT.
        src_text = ["hello world" for _ in batch["tgt_text"]]
        text_inputs = self.tokenizer(
            text=src_text,
            text_target=batch["tgt_text"],
            src_lang=self.src_lang,
            tgt_lang=self.tgt_lang,
            return_tensors="pt",
        )
        return {
            "input_features": audio_inputs.input_features,
            "attention_mask": audio_inputs.attention_mask,
            "labels": self._mask_padding(text_inputs.labels),
            "tgt_lang": self.tgt_lang,
        }

    def _prepare_batch_t2t(self, batch):
        text_inputs = self.tokenizer(
            text=batch["src_text"],
            text_target=batch["tgt_text"],
            src_lang=self.src_lang,
            tgt_lang=self.tgt_lang,
            return_tensors="pt",
        )
        return {
            "input_ids": text_inputs.input_ids,
            "attention_mask": text_inputs.attention_mask,
            "labels": self._mask_padding(text_inputs.labels),
            "tgt_lang": self.tgt_lang,
        }


if __name__ == "__main__":
    main()
```
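The collator reads specific CSV columns: `src_audio_path` and `tgt_text` when `mode` is `s2tt`, and `src_text` and `tgt_text` otherwise. It also asserts that every audio file is already at `sample_rate` (16000 in the config). A small stdlib-only sketch for writing and checking such a manifest might look like this; the helper names and the `t2tt` label are hypothetical, not part of the posted script.

```python
import csv
from typing import Dict, List

# Columns SeamlessDataCollator reads, per mode ("t2tt" is an assumed label for text-to-text).
REQUIRED_COLUMNS: Dict[str, List[str]] = {
    "s2tt": ["src_audio_path", "tgt_text"],  # audio file path + reference translation
    "t2tt": ["src_text", "tgt_text"],        # source sentence + reference translation
}


def write_manifest(path: str, rows: List[Dict[str, str]]) -> None:
    """Write one dict per example to a CSV that pd.read_csv can load."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)


def read_header(path: str) -> List[str]:
    """Return the column names from the first line of a CSV file."""
    with open(path, newline="", encoding="utf-8") as f:
        return next(csv.reader(f))


def missing_columns(header: List[str], mode: str) -> List[str]:
    """Return the required columns for `mode` that are absent from `header`."""
    return [c for c in REQUIRED_COLUMNS[mode] if c not in header]
```

Running `missing_columns(read_header("/path/to/train.csv"), "s2tt")` before launching a multi-GPU job fails fast instead of crashing inside the first collated batch.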
training_config.json
```json
{
  "train_dataset": "/path/to/train.csv",
  "eval_dataset": "/path/to/dev.csv",
  "test_dataset": "/path/to/test.csv",
  "dataset_info": "/path/to/dataset_info.json",
  "model_name": "seamlessM4T_v2_large",
  "bf16": true,
  "fp16": false,
  "save_model_dir": "/path/to/save_dir/",
  "seed": 42,
  "train_batch_size": 8,
  "eval_batch_size": 8,
  "test_batch_size": 32,
  "patience": 10000,
  "max_epochs": 1,
  "learning_rate": 1e-6,
  "max_steps": 10000000,
  "warmup_steps": 100,
  "eval_steps": 1000,
  "log_steps": 1000,
  "max_src_tokens": 100000000,
  "mode": "s2tt",
  "freeze_layers": [],
  "device": "cuda",
  "sample_rate": 16000,
  "src_lang": "eng",
  "tgt_lang": "pes"
}
```
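The script only compares `mode.lower()` against `"s2tt"`, so any other value (for example a typo such as `"s2t"`) silently selects the text-to-text model and collator, which then looks for a `src_text` column. A hypothetical pre-flight check, run before loading the model, can catch this and missing keys. The key list below is read off the script; `t2tt` is only an assumed name for the text-to-text mode.

```python
import json
from typing import Any, Dict, List

# Keys the training script reads from training_config.json.
REQUIRED_KEYS = [
    "seed", "model_name", "mode", "train_dataset", "eval_dataset",
    "save_model_dir", "eval_steps", "train_batch_size", "eval_batch_size",
    "learning_rate", "warmup_steps", "max_steps", "max_epochs",
    "bf16", "fp16", "src_lang", "tgt_lang", "sample_rate",
]
# Model names accepted by load_hf_model().
KNOWN_MODELS = {"seamlessM4T_medium", "seamlessM4T_large", "seamlessM4T_v2_large"}
# Assumed mode names; the script itself only tests for "s2tt".
KNOWN_MODES = {"s2tt", "t2tt"}


def check_config(config: Dict[str, Any]) -> List[str]:
    """Return human-readable problems; an empty list means the config looks usable."""
    problems = [f"missing key: {k}" for k in REQUIRED_KEYS if k not in config]
    if config.get("model_name") not in KNOWN_MODELS:
        problems.append(f"unknown model_name: {config.get('model_name')!r}")
    if str(config.get("mode", "")).lower() not in KNOWN_MODES:
        problems.append(
            f"mode {config.get('mode')!r} is not 's2tt' or 't2tt'; "
            "the script would fall back to text-to-text"
        )
    if config.get("bf16") and config.get("fp16"):
        problems.append("bf16 and fp16 cannot both be enabled")
    return problems


def load_and_check(path: str) -> Dict[str, Any]:
    """Load a JSON config and raise ValueError listing every problem found."""
    with open(path, "r") as f:
        config = json.load(f)
    problems = check_config(config)
    if problems:
        raise ValueError("invalid training config:\n" + "\n".join(problems))
    return config
```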
deepspeed_config.json
```json
{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 0,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": 0,
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 0,
    "stage3_max_reuse_distance": 0,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "bf16": {
    "enabled": "auto"
  },
  "fp16": {
    "enabled": "auto",
    "auto_cast": false,
    "loss_scale": 0,
    "initial_scale_power": 32,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": "auto",
      "eps": "auto",
      "weight_decay": "auto"
    }
  },
  "gradient_accumulation_steps": "auto",
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}
```
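Why CPU optimizer offload frees so much GPU memory can be seen with the usual ZeRO accounting for mixed-precision Adam: 2 bytes per parameter for 16-bit weights, 2 for 16-bit gradients, and 12 for optimizer state (an fp32 master copy plus two fp32 moments). ZeRO stage 3 partitions all of these across GPUs. The sketch below is a rough estimate under those assumptions, ignoring activations and temporary buffers; the 1B parameter count in the test is illustrative, not the size of any Seamless checkpoint.

```python
def adam_optimizer_state_bytes(num_params: int) -> int:
    """fp32 master weights + Adam first and second moments: 4 + 4 + 4 = 12 bytes per parameter."""
    return 12 * num_params


def zero3_model_state_bytes_per_gpu(num_params: int, num_gpus: int, offload_optimizer: bool) -> int:
    """Approximate per-GPU bytes for model states under ZeRO stage 3 (activations excluded).

    16-bit weights (2 B/param) and gradients (2 B/param) are partitioned across GPUs.
    The 12 B/param optimizer state is partitioned too, or moved to CPU RAM when offloaded.
    """
    per_param = 2 + 2 + (0 if offload_optimizer else 12)
    return per_param * num_params // num_gpus
```

With offload, the host must hold the full optimizer state instead, so CPU RAM (and `pin_memory`) becomes the constraint.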
accelerate_config.yaml
```yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_config_file: deepspeed_config.json
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
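The `"auto"` batch-size fields in the DeepSpeed config are filled in by the Hugging Face Trainer: the micro-batch comes from `per_device_train_batch_size`, accumulation from `gradient_accumulation_steps` (default 1, since the script does not set it), and `train_batch_size` is their product times the number of processes. The arithmetic is sketched below; the helper names are hypothetical, and the step count is approximate because the distributed sampler rounds per process.

```python
import math


def deepspeed_train_batch_size(per_device_batch: int, grad_accum_steps: int, num_processes: int) -> int:
    """Global batch size DeepSpeed enforces: micro-batch per GPU x accumulation steps x GPUs."""
    return per_device_batch * grad_accum_steps * num_processes


def planned_optimizer_steps(num_examples: int, global_batch: int, num_epochs: int, max_steps: int) -> int:
    """Rough number of optimizer steps the HF Trainer will run.

    A positive max_steps overrides num_epochs; otherwise each epoch takes
    about ceil(num_examples / global_batch) steps.
    """
    if max_steps > 0:
        return max_steps
    return math.ceil(num_examples / global_batch) * num_epochs
```

With `train_batch_size` 8 and `num_processes: 2`, each optimizer step sees 16 examples, and because `max_steps` is 10000000, training keeps cycling through the data well past `max_epochs: 1`.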
Then you can add LoRA to the model with the following lines:
```python
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(...)
peft_model = get_peft_model(model, lora_config)
```
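To see what `get_peft_model` changes in each targeted linear layer, here is a toy pure-Python sketch of the LoRA update (not peft's implementation). The frozen weight `W` is kept and a trainable low-rank product `B A` is added, scaled by `alpha / r`. `B` starts at zero so the adapted model initially behaves exactly like the base model. Input dropout, which peft applies before `A`, is omitted for brevity.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Dense matrix-vector product for small list-of-lists matrices."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


class LoRALinear:
    """Frozen weight W plus a trainable low-rank update: y = W x + (alpha / r) * B (A x)."""

    def __init__(self, weight: Matrix, lora_a: Matrix, alpha: float) -> None:
        self.weight = weight        # frozen, shape (d_out, d_in)
        self.rank = len(lora_a)     # r = number of rows of A
        self.lora_a = lora_a        # trainable, shape (r, d_in)
        # B is zero-initialized, so the layer starts out computing exactly W x.
        self.lora_b = [[0.0] * self.rank for _ in range(len(weight))]  # trainable, shape (d_out, r)
        self.scaling = alpha / self.rank

    def forward(self, x: Vector) -> Vector:
        base = matvec(self.weight, x)
        delta = matvec(self.lora_b, matvec(self.lora_a, x))
        return [b + self.scaling * d for b, d in zip(base, delta)]


def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in linear layer (A and B, no bias)."""
    return rank * (d_in + d_out)
```

With `r=32` and `alpha=64` the scaling is 2, and a 1024 x 1024 projection (an illustrative size) gains 65,536 trainable parameters instead of the 1,048,576 updated by full fine-tuning.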
Finally, you can run the script with:
```bash
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file /path/to/accelerate_config.yaml \
    /path/to/hf_seamless_trainer.py --config /path/to/training_config.json
```
Here is a trainer file with LoRA added; use it to replace src/seamless_communication/cli/m4t/finetune/trainer.py. Pass a `lora_config` dictionary to `UnitYFinetune`, for example: `lora_config = {"r": 32, "alpha": 64, "dropout": 0.2, "keys": [".*_proj"]}`.
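The `keys` entry appears to be a list of regular expressions that select which modules get LoRA adapters. The attached trainer file is not reproduced here, so its exact matching rule is unknown. The sketch below assumes full-match semantics, the rule peft uses when `target_modules` is a single string. The module names in the test are illustrative, not taken from a real checkpoint.

```python
import re
from typing import Iterable, List


def select_lora_targets(module_names: Iterable[str], patterns: List[str]) -> List[str]:
    """Return the module names that fully match at least one regex in `patterns`."""
    compiled = [re.compile(p) for p in patterns]
    return [name for name in module_names if any(c.fullmatch(name) for c in compiled)]
```

Full matching matters here: with `re.search`, a pattern like `.*_proj` would also select names that merely contain `_proj` somewhere in the middle, such as a bias parameter of a module called `final_proj_extra`.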