new dataset commonsense_qa #13

Open · wants to merge 1 commit into base
10 changes: 6 additions & 4 deletions configs/experiment/trl_train/sft_pause_tiny_llama_pause.yaml
@@ -12,13 +12,14 @@ rl_algorithm:
           device_map: cuda:0
 
       tokenizer:
-        pretrained_model_name_or_path: /dlabscratch1/wendler/code/nanotron/hf_models/pause-30000
+        pretrained_model_name_or_path: /dlabscratch1/wendler/code/nanotron/hf_models/tinyllamapause
 
-run_name: "sft_pause_gsm8k_tiny_llama"
+run_name: "sft_pause_commonsense_qa_tiny_llama"
 
 data:
   debug_n: null
-  path: "/dlabscratch1/baldwin/pause2/PauseToken/data/gsm8k_json/gsm8k_variable_random_pauses" # TODO: change path of DS if you want to use another one
+  path: "/dlabscratch1/wendler/code/PauseToken/data/gsm8k_pause_injected"
+  #path: "/dlabscratch1/baldwin/pause2/PauseToken/data/gsm8k_json/gsm8k_variable_random_pauses" # TODO: change path of DS if you want to use another one
 
 trainer:
   data_collator:
@@ -41,14 +42,15 @@ trainer:
     eval_steps: 300
     load_best_model_at_end: true
     save_total_limit: 10
-    num_train_epochs: 1.0
+    num_train_epochs: 5.0
     per_device_train_batch_size: 8
     per_device_eval_batch_size: 8
     save_steps: 300
     report_to: "wandb"
     learning_rate: 5e-05
     adam_beta2: 0.999
     warmup_ratio: 0.05
+    weight_decay: 0.01
 
 save_before_train: true
 test: true
63 changes: 63 additions & 0 deletions configs/experiment/trl_train/sft_pause_tiny_llama_pause_qa.yaml
@@ -0,0 +1,63 @@
+# @package _global_
+defaults:
+  - override /rl_algorithm/policy/model/language_model: pause_tiny_llama
+  - override /rl_algorithm/policy/model/[email protected]_config: null
+  - override /metrics: commonsense_qa
+
+rl_algorithm:
+  policy:
+    model:
+      language_model:
+        language_model:
+          device_map: cuda:0
+
+      tokenizer:
+        pretrained_model_name_or_path: /dlabscratch1/wendler/code/nanotron/hf_models/tinyllamapause
+
+run_name: "sft_pause50_commonsense_qa_tiny_llama"
+
+data:
+  debug_n: null
+  path: "/dlabscratch1/wendler/code/PauseToken/data/commonsense_qa_injected_50"
+  #path: "/dlabscratch1/baldwin/pause2/PauseToken/data/gsm8k_json/gsm8k_variable_random_pauses" # TODO: change path of DS if you want to use another one
+
+trainer:
+  data_collator:
+    _target_: trl.DataCollatorForCompletionOnlyLM
+    response_template:
+      _target_: src.utils.hydra_custom_resolvers.get_module_attr
+      module_and_attr: src.utils.constants.ANSWER_TEMPLATE
+
+  max_seq_length: 600
+  formatting_func:
+    _target_: functools.partial
+    _args_:
+      - ${get_method:src.utils.trainer_utils.sft_formating_function}
+    eos_token: ${get_obj_attr:${rl_algorithm.policy.model.tokenizer},[eos_token]}
+
+  args:
+    do_eval: true
+    evaluation_strategy: "steps"
+    save_strategy: "steps"
+    eval_steps: 300
+    load_best_model_at_end: true
+    save_total_limit: 10
+    num_train_epochs: 5.0
+    per_device_train_batch_size: 8
+    per_device_eval_batch_size: 8
+    save_steps: 300
+    report_to: "wandb"
+    learning_rate: 5e-05
+    adam_beta2: 0.999
+    warmup_ratio: 0.05
+    weight_decay: 0.01
+
+save_before_train: true
+test: true
+test_batch_size: 8
+test_formatting_func:
+  _target_: functools.partial
+  _args_:
+    - ${get_method:src.utils.trainer_utils.inference_formatting_function}
+  eos_token: ${get_obj_attr:${rl_algorithm.policy.model.tokenizer},[eos_token]}
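For reference, a minimal launch sketch for this experiment config, assuming the repository's Hydra entry point is src/trl_train.py (mentioned in configs/metrics/commonsense_qa.yaml below) and that experiments are selected by their path under configs/experiment:

python src/trl_train.py experiment=trl_train/sft_pause_tiny_llama_pause_qa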
5 changes: 3 additions & 2 deletions configs/experiment/trl_train/sft_tiny_llama.yaml
@@ -8,7 +8,7 @@ run_name: "sft_gsm8k_tiny_llama"
 
 data:
   debug_n: null
-  path: "/dlabscratch1/baldwin/pause2/PauseToken/data/gsm8k_json/gsm8k"
+  path: "/dlabscratch1/wendler/code/PauseToken/data/gsm8k/"
 
 trainer:
   data_collator:
@@ -32,14 +32,15 @@ trainer:
     eval_steps: 300
     load_best_model_at_end: true
     save_total_limit: 10
-    num_train_epochs: 1.0
+    num_train_epochs: 5.0
     per_device_train_batch_size: 8
     per_device_eval_batch_size: 8
     save_steps: 300
     report_to: "wandb"
     learning_rate: 5e-05
     adam_beta2: 0.999
     warmup_ratio: 0.05
+    weight_decay: 0.01
 
 save_before_train: false
 merge_peft_after_train: false
53 changes: 53 additions & 0 deletions configs/experiment/trl_train/sft_tiny_llama_qa.yaml
@@ -0,0 +1,53 @@
+# @package _global_
+defaults:
+  - override /rl_algorithm/policy/model/language_model: tiny_llama
+  - override /rl_algorithm/policy/model/[email protected]_config: null
+  - override /metrics: commonsense_qa
+
+run_name: "sft_commonsense_qa_tiny_llama"
+
+data:
+  debug_n: null
+  path: "/dlabscratch1/wendler/code/PauseToken/data/commonsense_qa/"
+
+trainer:
+  data_collator:
+    _target_: trl.DataCollatorForCompletionOnlyLM
+    response_template:
+      _target_: src.utils.hydra_custom_resolvers.get_module_attr
+      module_and_attr: src.utils.constants.ANSWER_TEMPLATE
+
+  max_seq_length: 600
+
+  formatting_func:
+    _target_: functools.partial
+    _args_:
+      - ${get_method:src.utils.trainer_utils.sft_formating_function}
+    eos_token: ${get_obj_attr:${rl_algorithm.policy.model.tokenizer},[eos_token]}
+
+  args:
+    do_eval: true
+    evaluation_strategy: "steps"
+    save_strategy: "steps"
+    eval_steps: 300
+    load_best_model_at_end: true
+    save_total_limit: 10
+    num_train_epochs: 5.0
+    per_device_train_batch_size: 8
+    per_device_eval_batch_size: 8
+    save_steps: 300
+    report_to: "wandb"
+    learning_rate: 5e-05
+    adam_beta2: 0.999
+    warmup_ratio: 0.05
+    weight_decay: 0.01
+
+save_before_train: false
+merge_peft_after_train: false
+test: true
+test_batch_size: 8
+test_formatting_func:
+  _target_: functools.partial
+  _args_:
+    - ${get_method:src.utils.trainer_utils.inference_formatting_function}
+  eos_token: ${get_obj_attr:${rl_algorithm.policy.model.tokenizer},[eos_token]}
11 changes: 11 additions & 0 deletions configs/metrics/commonsense_qa.yaml
@@ -0,0 +1,11 @@
+defaults:
+  - default
+
+test:
+  accuracy:
+    _target_: src.metrics.commonsense_qa_metrics.is_correct
+
+# At the moment, val is only supported for src/train.py (not src/trl_train.py)
+val:
+  accuracy:
+    _target_: src.metrics.commonsense_qa_metrics.is_correct
14 changes: 11 additions & 3 deletions scripts/data_generation/gsm8k_pause_injector.py
@@ -133,7 +133,8 @@ def inject_pauses(
     pause_augm_col_name = "pause_augmented_answer",
     tokenizer = None,
     variable_number_of_pauses = False,
-    n_leading_pauses = 0
+    n_leading_pauses = 0,
+    n_pauses_end_of_question = 0
 ):
     """ function used in map to inject pauses in a sample
 
@@ -144,7 +145,7 @@
     :param pause_token: The pause token to be injected
     :type pause_token: str
     """
-
+    sample["question"] = add_pause(sample["question"], len(sample["question"]), n_pauses_end_of_question, pause_token)
     input_string = sample["answer"]
     sample[pause_augm_col_name] = inject_pause_to_str(input_string, n_pauses_per_patterns, pause_token,n_random_pauses, tokenizer,variable_number_of_pauses, n_leading_pauses)
     return sample
@@ -192,6 +193,13 @@ def parse_args():
         type=int,
         help="The number of pauses to be injected at the beginning of the response."
     )
+
+    parser.add_argument(
+        "--n_pauses_end_of_question",
+        default=0,
+        type=int,
+        help="The number of pauses to be injected at the end of the question."
+    )
 
     parser.add_argument(
         "--variable_number_of_pauses",
@@ -281,7 +289,7 @@ def parse_args():
 for it in range(args.n_generated_samples_per_datapoint):
     augmented_ds.append(
         dataset.map(
-            lambda sample: inject_pauses(sample,args.n_pauses_per_patterns,args.n_random_pauses ,args.pause_token, args.pause_augm_col_name, tokenizer,args.variable_number_of_pauses, args.n_leading_pauses),load_from_cache_file=False
+            lambda sample: inject_pauses(sample,args.n_pauses_per_patterns,args.n_random_pauses ,args.pause_token, args.pause_augm_col_name, tokenizer,args.variable_number_of_pauses, args.n_leading_pauses, args.n_pauses_end_of_question),load_from_cache_file=False
         )
     )
 if args.n_generated_samples_per_datapoint == 1:
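For illustration, a hypothetical invocation of the updated injector exercising the new option (only --n_leading_pauses and --n_pauses_end_of_question are confirmed by this diff; dataset and pause-token arguments are elided):

python scripts/data_generation/gsm8k_pause_injector.py --n_pauses_end_of_question 50

Since --n_pauses_end_of_question defaults to 0, runs that omit the flag keep the previous behavior.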
27 changes: 27 additions & 0 deletions scripts/download/commonsense_qa.py
@@ -0,0 +1,27 @@
+import pandas as pd
+import os
+
+splits = {'train': 'data/train-00000-of-00001.parquet', 'validation': 'data/validation-00000-of-00001.parquet', 'test': 'data/test-00000-of-00001.parquet'}
+df_train = pd.read_parquet("hf://datasets/tau/commonsense_qa/" + splits["train"])
+df_valid = pd.read_parquet("hf://datasets/tau/commonsense_qa/" + splits["validation"])
+df_test = pd.read_parquet("hf://datasets/tau/commonsense_qa/" + splits["test"])
+
+def convert_df_to_qstr_astr(d):
+    #print(d)
+    qstr = "Answer the following question\n"
+    qstr += d['question']
+    qstr += "\nChoices:\n"
+    for l, t in zip(d['choices']['label'], d['choices']['text']):
+        qstr += f"{l}: {t}\n"
+    astr = d['answerKey']
+    return pd.Series({"question" : qstr, "answer": astr})
+
+rewritten_train = df_train.apply(convert_df_to_qstr_astr, axis=1)
+rewritten_valid = df_valid.apply(convert_df_to_qstr_astr, axis=1)
+rewritten_test = df_test.apply(convert_df_to_qstr_astr, axis=1)
+
+# export as jsonl
+os.makedirs("./data/commonsense_qa", exist_ok=True)
+rewritten_train.to_json("./data/commonsense_qa/train.json", orient='records', lines=True)
+rewritten_valid.to_json("./data/commonsense_qa/test.json", orient='records', lines=True)
+#rewritten_test.to_json("../data/commonsense_qa/test.json", orient='records', lines=True)
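Each exported JSONL record pairs the formatted prompt with the gold answer letter; an illustrative line (placeholders, not actual dataset content):

{"question": "Answer the following question\n<question text>\nChoices:\nA: <choice>\nB: <choice>\nC: <choice>\nD: <choice>\nE: <choice>\n", "answer": "C"}

Note that the validation split is written to test.json, presumably because the official CommonsenseQA test split ships without public answer keys, so the "test" metrics below actually run on validation data.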
21 changes: 21 additions & 0 deletions src/metrics/commonsense_qa_metrics.py
@@ -0,0 +1,21 @@
+from src.utils.trainer_utils import extract_answer, INVALID_ANS
+
+def is_correct(model_completion: str, gt_example: str) -> bool:
+    """ Check if the model completion is correct given the ground truth example. Completions must follow the CommonsenseQA "Answer:" template used by these configs.
+
+    :param model_completion: Model completion
+    :type model_completion: str
+    :param gt_example: Ground truth example; its first character is the gold answer letter
+    :type gt_example: str
+    """
+    gt_answer = gt_example[0]
+    if gt_answer not in ["A", "B", "C", "D", "E"]:
+        raise ValueError(f"Ground truth answer is invalid and doesn't follow the CommonsenseQA format, your ground truth answer is {gt_example}")
+    try:
+        # Compare the character immediately following "Answer:" with the gold letter
+        print("eval:", model_completion.split("Answer:")[1][0], gt_answer, model_completion.split("Answer:")[1][0] == gt_answer)
+        return model_completion.split("Answer:")[1][0] == gt_answer
+    except IndexError:
+        # No "Answer:" marker in the completion counts as incorrect
+        return False
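A quick sanity check of the matching logic (a sketch, not part of the PR; note it compares the character immediately after "Answer:", so the resolved ANSWER_TEMPLATE must place the letter flush against the colon):

from src.metrics.commonsense_qa_metrics import is_correct

assert is_correct("... Answer:B", "B")          # letter right after "Answer:" matches gold
assert not is_correct("... Answer:A", "B")      # wrong letter
assert not is_correct("no answer given", "B")   # no "Answer:" marker -> IndexError -> False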