nemo automodel sft squad data prep fix (#11994)
* use pad_token_id=0 when padding loss_mask; update SQuAD dataset

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* update LLM peft/sft scripts + FSDP2 strategy

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

* Update hf.py

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Update hf.py

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Update hf.py

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Update hf.py

Signed-off-by: Alexandros Koumparoulis <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Signed-off-by: Alexandros Koumparoulis <[email protected]>
Co-authored-by: akoumpa <[email protected]>
akoumpa and akoumpa authored Jan 30, 2025
1 parent 7ce7f9e commit ea5ed67
Showing 3 changed files with 71 additions and 106 deletions.
86 changes: 34 additions & 52 deletions examples/llm/peft/hf.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile

import fiddle as fdl
from lightning.pytorch.loggers import WandbLogger

@@ -22,35 +24,24 @@


def make_squad_hf_dataset(tokenizer):
EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN

def formatting_prompts_func(examples):
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{}
### Input:
{}
### Response:
{}"""
instruction = examples["context"]
input = examples["question"]
output = examples["answers"]['text']
if isinstance(output, list):
output = output[0]
text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
ans = tokenizer(text)
# 'input_ids' is a list, we want to remove EOS_TOKEN from input_ids and the first token from
# labels to align the two:
ans['labels'] = list(ans['input_ids'][1:])
ans['input_ids'] = ans['input_ids'][:-1]
ans['attention_mask'] = ans['attention_mask'][:-1]
return ans

tokenizer = getattr(tokenizer, 'tokenizer', tokenizer)
datamodule = llm.HFDatasetDataModule("rajpurkar/squad", split="train[:100]", pad_token_id=tokenizer.eos_token_id)
def formatting_prompts_func(example):
formatted_text = [
f"Context: {example['context']} Question: {example['question']} Answer:",
f" {example['answers']['text'][0].strip()}",
]
context_ids, answer_ids = list(map(tokenizer.text_to_ids, formatted_text))
if len(context_ids) > 0 and context_ids[0] != tokenizer.bos_id:
context_ids.insert(0, tokenizer.bos_id)
if len(answer_ids) > 0 and answer_ids[-1] != tokenizer.eos_id:
answer_ids.append(tokenizer.eos_id)

return dict(
labels=(context_ids + answer_ids)[1:],
input_ids=(context_ids + answer_ids)[:-1],
loss_mask=[0] * (len(context_ids) - 1) + [1] * len(answer_ids),
)

datamodule = llm.HFDatasetDataModule("rajpurkar/squad", split="train[:100]", pad_token_id=tokenizer.eos_id)
datamodule.map(
formatting_prompts_func,
batched=False,
@@ -61,17 +52,19 @@ def formatting_prompts_func(examples):


def main():
"""Example script to run PEFT with a HF transformers-instantiated model on squad."""
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='meta-llama/Llama-3.2-1B')
parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp'])
parser.add_argument('--devices', default=1)
parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
parser.add_argument('--model', type=str, default='meta-llama/Llama-3.2-1B')
parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp2'])
parser.add_argument('--devices', type=int, default=1)
parser.add_argument('--accelerator', type=str, default='gpu', choices=['gpu'])
parser.add_argument('--grad-clip', type=float, default=1.0)
parser.add_argument('--max-steps', type=int, default=100)
parser.add_argument('--wandb-project', type=str, default=None)
parser.add_argument('--use-torch-jit', action='store_true')
parser.add_argument('--ckpt-folder', type=str, default=None)
parser.add_argument('--ckpt-folder', type=str, default=tempfile.TemporaryDirectory().name)
args = parser.parse_args()

wandb = None
@@ -81,29 +74,18 @@ def main():
project=args.wandb_project,
name=f'{model}_dev{args.devices}_strat_{args.strategy}',
)
grad_clip = 0.5
if args.strategy == 'fsdp':
# See:
# https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81
grad_clip = None
use_dist_samp = False

import tempfile

if args.ckpt_folder is None:
args.ckpt_folder = tempfile.TemporaryDirectory().name
print("Temp directory created for base model: ", args.ckpt_folder)

tokenizer = llm.HFAutoModelForCausalLM.configure_tokenizer(args.model)

callbacks = []
if args.use_torch_jit:
jit_config = JitConfig(use_torch=True, torch_kwargs={'dynamic': True}, use_thunder=False)
callbacks = [JitTransform(jit_config)]

if args.strategy == 'fsdp2':
args.strategy = nl.FSDP2Strategy(data_parallel_size=args.devices, tensor_parallel_size=1)

llm.api.finetune(
model=llm.HFAutoModelForCausalLM(args.model),
data=make_squad_hf_dataset(tokenizer.tokenizer),
model=llm.HFAutoModelForCausalLM(model_name=args.model),
data=make_squad_hf_dataset(llm.HFAutoModelForCausalLM.configure_tokenizer(args.model)),
trainer=nl.Trainer(
devices=args.devices,
max_steps=args.max_steps,
Expand All @@ -113,8 +95,8 @@ def main():
limit_val_batches=0.0,
num_sanity_val_steps=0,
accumulate_grad_batches=10,
gradient_clip_val=grad_clip,
use_distributed_sampler=use_dist_samp,
gradient_clip_val=args.grad_clip,
use_distributed_sampler=False,
logger=wandb,
callbacks=callbacks,
precision="bf16",
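The new make_squad_hf_dataset formatting keeps labels shifted one position relative to input_ids and masks the prompt out of the loss. Below is a small illustrative walk-through with made-up token IDs (not the script's actual tokenizer output) showing how the three fields line up:

# Toy walk-through of the shift-and-mask logic above; token IDs are invented for
# illustration, the real script builds them with tokenizer.text_to_ids.
bos_id, eos_id = 1, 2
context_ids = [bos_id, 10, 11, 12]   # "Context: ... Question: ... Answer:"
answer_ids = [20, 21, eos_id]        # " <answer text>" plus the appended EOS

tokens = context_ids + answer_ids
example = dict(
    input_ids=tokens[:-1],   # model input: everything except the final token
    labels=tokens[1:],       # next-token targets: everything except the first token
    loss_mask=[0] * (len(context_ids) - 1) + [1] * len(answer_ids),
)
# input_ids: [1, 10, 11, 12, 20, 21]
# labels:    [10, 11, 12, 20, 21, 2]
# loss_mask: [0, 0, 0, 1, 1, 1]  -> loss only where the target token belongs to the answer
assert len(example["input_ids"]) == len(example["labels"]) == len(example["loss_mask"])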
46 changes: 17 additions & 29 deletions examples/llm/sft/hf.py
@@ -1,3 +1,4 @@
#!/usr/bin/python3
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,14 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import fiddle as fdl
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader

from nemo import lightning as nl
from nemo.collections import llm
from nemo.lightning.pytorch.accelerate.transformer_engine import is_te_accelerated
from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform


@@ -55,7 +57,6 @@ def squad(tokenizer) -> pl.LightningDataModule:
num_workers=0,
dataset_kwargs={
"sanity_check_dist_workers": False,
"pad_to_max_length": True,
"get_attention_mask_from_fusion": True,
},
)
@@ -66,13 +67,14 @@ def main():
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='meta-llama/Llama-3.2-1B')
parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp'])
parser.add_argument('--devices', default=1)
parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
parser.add_argument('--model-accelerator', default=None, choices=['te'])
parser.add_argument('--model', type=str, default='meta-llama/Llama-3.2-1B')
parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp2'])
parser.add_argument('--devices', type=int, default=1)
parser.add_argument('--accelerator', type=str, default='gpu', choices=['gpu'])
parser.add_argument('--grad-clip', type=float, default=1.0)
parser.add_argument('--model-accelerator', type=str, default=None, choices=['te'])
parser.add_argument('--max-steps', type=int, default=100)
parser.add_argument("--fp8-autocast", default=False, action='store_true')
parser.add_argument("--fp8-autocast", action='store_true')
parser.add_argument('--wandb-project', type=str, default=None)
parser.add_argument('--model-save-path', type=str, default=None)
parser.add_argument('--use-torch-jit', action='store_true')
@@ -85,32 +87,24 @@ def main():
project=args.wandb_project,
name=f'{model}_dev{args.devices}_strat_{args.strategy}',
)
grad_clip = 0.5
if args.strategy == 'fsdp':
# See: https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81
grad_clip = None
use_dist_samp = False

model_accelerator = None
if args.model_accelerator == "te":
from functools import partial
from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate

model_accelerator = partial(te_accelerate, fp8_autocast=args.fp8_autocast)

from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate

model = llm.HFAutoModelForCausalLM(model_name=args.model, model_accelerator=model_accelerator)
tokenizer = model.tokenizer

callbacks = []
if args.use_torch_jit:
jit_config = JitConfig(use_torch=True, torch_kwargs={'dynamic': False}, use_thunder=False)
callbacks = [JitTransform(jit_config)]

if args.strategy == 'fsdp2':
args.strategy = nl.FSDP2Strategy(data_parallel_size=args.devices, tensor_parallel_size=1)

llm.api.finetune(
model=model,
data=squad(tokenizer),
model=llm.HFAutoModelForCausalLM(model_name=args.model, model_accelerator=model_accelerator),
data=squad(llm.HFAutoModelForCausalLM.configure_tokenizer(args.model)),
trainer=nl.Trainer(
devices=args.devices,
max_steps=args.max_steps,
@@ -120,8 +114,8 @@
limit_val_batches=0.0,
num_sanity_val_steps=0,
accumulate_grad_batches=10,
gradient_clip_val=grad_clip,
use_distributed_sampler=use_dist_samp,
gradient_clip_val=args.grad_clip,
use_distributed_sampler=False,
logger=wandb,
callbacks=callbacks,
precision="bf16",
@@ -130,12 +124,6 @@
log=None,
)

if args.model_accelerator:
if args.model_accelerator == "te":
te_acc = is_te_accelerated(model.model)
assert te_acc, "Transformer Engine acceleration was unsuccessful"
print("TE Accelerated: ", te_acc)

if args.model_save_path is not None:
model.save_pretrained(args.model_save_path)

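Both example scripts now accept 'fsdp2' and swap the string for an FSDP2Strategy before building the Trainer, and gradient clipping is taken directly from --grad-clip instead of being forced to None as the old 'fsdp' path required. A minimal sketch of that resolution step (argument names mirror the scripts above):

from nemo import lightning as nl

def resolve_strategy(strategy: str, devices: int):
    # 'fsdp2' becomes a fully sharded data-parallel strategy over all devices;
    # 'auto' and 'ddp' are passed to the Trainer as plain strings.
    if strategy == 'fsdp2':
        return nl.FSDP2Strategy(data_parallel_size=devices, tensor_parallel_size=1)
    return strategy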
45 changes: 20 additions & 25 deletions nemo/collections/llm/gpt/data/hf_dataset.py
@@ -198,7 +198,7 @@ def pad_within_micro(batch, pad_token_id):
torch.LongTensor(
pad_within_micro(
extract_key_from_dicts(batch, key),
pad_token_id,
pad_token_id if key != 'loss_mask' else 0,
)
)
)
@@ -268,33 +268,28 @@ def map(self, function=None, split_names=None, **kwargs):
class SquadHFDataModule(HFDatasetDataModule):
def __init__(self, tokenizer, **kwargs):
super().__init__(**kwargs)
self.tokenizer = getattr(tokenizer, 'tokenizer', tokenizer)

def formatting_prompts_func(self, examples):
EOS_TOKEN = self.tokenizer.eos_token # Must add EOS_TOKEN
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{}
### Input:
{}
### Response:
{}"""
instruction = examples["context"]
input = examples["question"]
output = examples["answers"]['text']
if isinstance(output, list):
output = output[0]
text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
ans = self.tokenizer(text)
ans['labels'] = ans['input_ids']
return ans
self.tokenizer = tokenizer

def formatting_prompts_func(self, example):
formatted_text = [
f"Context: {example['context']} Question: {example['question']} Answer:",
f" {example['answers']['text'][0].strip()}",
]
context_ids, answer_ids = list(map(self.tokenizer.text_to_ids, formatted_text))
if len(context_ids) > 0 and context_ids[0] != self.tokenizer.bos_id:
context_ids.insert(0, self.tokenizer.bos_id)
if len(answer_ids) > 0 and answer_ids[-1] != self.tokenizer.eos_id:
answer_ids.append(self.tokenizer.eos_id)

return dict(
labels=(context_ids + answer_ids)[1:],
input_ids=(context_ids + answer_ids)[:-1],
loss_mask=[0] * (len(context_ids) - 1) + [1] * len(answer_ids),
)

def setup(self, stage):
super().setup(stage)
self.tokenizer = getattr(self.tokenizer, 'tokenizer', self.tokenizer)

self.map(
self.formatting_prompts_func,
batched=False,
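The one-line collate change above is what the commit message's first bullet refers to: within a micro-batch every key is padded with pad_token_id except loss_mask, which is padded with 0 so padded positions never contribute to the loss. A standalone sketch of the behaviour under that assumption (the helpers below are illustrative, not the module's actual collate):

def pad_within_micro(batch, pad_value):
    # Pad every sequence in the micro-batch to the length of the longest one.
    max_len = max(len(seq) for seq in batch)
    return [seq + [pad_value] * (max_len - len(seq)) for seq in batch]

def collate(examples, pad_token_id):
    keys = ("input_ids", "labels", "loss_mask")
    return {
        # loss_mask must be padded with 0, everything else with pad_token_id.
        key: pad_within_micro([ex[key] for ex in examples],
                              pad_token_id if key != "loss_mask" else 0)
        for key in keys
    }

batch = [
    {"input_ids": [1, 5, 6], "labels": [5, 6, 2], "loss_mask": [0, 1, 1]},
    {"input_ids": [1, 7], "labels": [7, 2], "loss_mask": [0, 1]},
]
# With pad_token_id=2 the short row's loss_mask pads to [0, 1, 0] rather than [0, 1, 2].
print(collate(batch, pad_token_id=2))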
