Reintroduce truncation_mode in DPOTrainer #2551
base: main
Conversation
Hi, I have read the issue thread here and this PR. I agree that we can use:

    def tokenize_row(self, feature):
        """Tokenize a single row from the DPO-specific dataset.

        Note that this does not convert to tensors yet; it just handles the truncation
        in case prompt + chosen or prompt + rejected is too long. It first truncates
        the prompt; if the sequence is still too long, it truncates the chosen/rejected.
        """
        batch = {}
        prompt = feature["prompt"]
        chosen = feature["text_chosen"]
        rejected = feature["text_rejected"]

        # The current version says we typically add special tokens for encoder-decoder
        # models, but what if it's not enc-dec?
        if not self.is_encoder_decoder:
            # Check the input types first
            if not isinstance(prompt, str):
                raise ValueError(f"prompt should be in str form but got {type(prompt)}")
            prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
            prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

            if not isinstance(chosen, str):
                raise ValueError(f"chosen should be in str form but got {type(chosen)}")
            chosen_tokens = self.build_tokenized_answer(prompt, chosen)

            if not isinstance(rejected, str):
                raise ValueError(f"rejected should be in str form but got {type(rejected)}")
            rejected_tokens = self.build_tokenized_answer(prompt, rejected)

            # Add BOS token to the head of each prompt
            prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
            chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
            rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]

            # Add EOS token to the end of each answer
            chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
            rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)

            longer_response = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

            # If prompt + completion is too long, truncate the prompt
            if len(prompt_tokens["prompt_input_ids"]) + longer_response > self.args.max_length:
                if self.truncation_mode == "keep_start":
                    prompt_tokens["prompt_input_ids"] = prompt_tokens["prompt_input_ids"][: self.max_prompt_length]
                elif self.truncation_mode == "keep_end":
                    prompt_tokens["prompt_input_ids"] = prompt_tokens["prompt_input_ids"][-self.max_prompt_length :]
                else:
                    raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

            # If that's still too long, truncate the response
            for answer_tokens in [chosen_tokens, rejected_tokens]:
                if len(answer_tokens["prompt_input_ids"]) + longer_response > self.args.max_length:
                    answer_tokens["input_ids"] = answer_tokens["input_ids"][: self.args.max_length - self.max_prompt_length]

        # (remainder of the method, which populates `batch`, omitted here)
        return batch

Lemme know what you think, please. @qgallouedec @anakin87
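For reference, a minimal toy illustration (not TRL code; the names and numbers are made up for the example) of what "keep_start" vs. "keep_end" means for the prompt truncation above:

    prompt_ids = list(range(12))     # pretend the tokenized prompt is 12 tokens long
    max_prompt_length = 8

    # "keep_start": keep the first max_prompt_length tokens, dropping the end
    keep_start = prompt_ids[:max_prompt_length]    # [0, 1, 2, 3, 4, 5, 6, 7]

    # "keep_end": keep the last max_prompt_length tokens, dropping the start
    keep_end = prompt_ids[-max_prompt_length:]     # [4, 5, 6, 7, 8, 9, 10, 11]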
Where does this version come from?
This is anyway handled here: trl/trl/trainer/dpo_trainer.py Lines 1155 to 1158 in edabe0a
This version is what I came up with based on my research. And yes, it's getting handled where you mentioned, but they are in two different functions.
@qgallouedec feel free to review the proposed fix
That makes sense. Can you open another pull request for this? Wait for this one to be merged, though.
Sorry @anakin87, I forgot to press the submit review button a couple of days ago. Also, @shirinyamani came up with an idea that could make more sense: truncate the [prompt+completion] (either left or right) instead of just the prompt. Something like:

    # Truncate
    if self.args.max_length is not None:
        if self.args.truncation_mode == "keep_start":
            input_ids = input_ids[:, : self.args.max_length]
            attention_mask = attention_mask[:, : self.args.max_length]
            loss_mask = loss_mask[:, : self.args.max_length]
        elif self.args.truncation_mode == "keep_end":
            input_ids = input_ids[:, -self.args.max_length:]
            attention_mask = attention_mask[:, -self.args.max_length:]
            loss_mask = loss_mask[:, -self.args.max_length:]
        else:
            raise ValueError(f"Unknown truncation mode: {self.args.truncation_mode}")

Currently it is:

    # Truncate right
    if self.args.max_length is not None:
        input_ids = input_ids[:, : self.args.max_length]
        attention_mask = attention_mask[:, : self.args.max_length]
        loss_mask = loss_mask[:, : self.args.max_length]

What do you think?
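To make the suggestion concrete, here is a small self-contained sketch (not TRL code; the helper name and toy batch are invented for the example) of truncating the already-concatenated [prompt+completion] tensors on the side selected by the truncation mode:

    import torch

    def truncate_joint(input_ids, attention_mask, loss_mask, max_length, truncation_mode="keep_end"):
        """Truncate [prompt+completion] tensors to at most max_length tokens."""
        if max_length is None:
            return input_ids, attention_mask, loss_mask
        if truncation_mode == "keep_start":
            sl = slice(None, max_length)     # keep the first max_length tokens
        elif truncation_mode == "keep_end":
            sl = slice(-max_length, None)    # keep the last max_length tokens
        else:
            raise ValueError(f"Unknown truncation mode: {truncation_mode}")
        return input_ids[:, sl], attention_mask[:, sl], loss_mask[:, sl]

    # Toy batch: one sequence of 10 tokens, truncated to 6
    ids = torch.arange(10).unsqueeze(0)
    mask = torch.ones_like(ids)
    print(truncate_joint(ids, mask, mask, 6, "keep_start")[0])  # tensor([[0, 1, 2, 3, 4, 5]])
    print(truncate_joint(ids, mask, mask, 6, "keep_end")[0])    # tensor([[4, 5, 6, 7, 8, 9]])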
What does this PR do?
In #2538, we found that the truncation_mode attribute of DPOConfig was essentially ignored, and the prompt was always truncated by keeping the end. In this PR, I am taking this argument into consideration in tokenize_row and process_row, so that users can specify whether they want to keep the end or the start.

Fixes #2538
Technical details:
truncation_mode is passed as an argument to the tokenize_row static method. The alternative option would have been to not use a static method and access self.truncation_mode, but this would have been a breaking change.
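As a rough sketch of what passing the truncation mode into the static method can look like (parameter names and surrounding logic are illustrative only, not the exact TRL signature):

    def tokenize_row(features, tokenizer, max_prompt_length, truncation_mode="keep_end"):
        # Tokenize the prompt without special tokens, then truncate it on the
        # side selected by truncation_mode (illustrative only).
        prompt_input_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
        if max_prompt_length is not None:
            if truncation_mode == "keep_start":
                prompt_input_ids = prompt_input_ids[:max_prompt_length]
            elif truncation_mode == "keep_end":
                prompt_input_ids = prompt_input_ids[-max_prompt_length:]
            else:
                raise ValueError(f"Unknown truncation mode: {truncation_mode}")
        return {"prompt_input_ids": prompt_input_ids}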
Who can review?
@qgallouedec