
Reintroduce truncation_mode in DPOTrainer #2551

Open

anakin87 wants to merge 2 commits into main

Conversation

@anakin87 (Contributor) commented Jan 8, 2025

What does this PR do?

In #2538, we found that the truncation_mode attribute of DPOConfig was essentially ignored, and the prompt was always truncated by keeping the end.

In this PR, I take this argument into account in tokenize_row and process_row, so that users can specify whether they want to keep the start or the end of the prompt.

Fixes #2538

Technical details:

  • I decided to introduce a new parameter in the tokenize_row static method. The alternative would have been to drop the static method and access self.truncation_mode, but that would have been a breaking change (a rough sketch of this approach follows the list).
  • Added 2 new tests
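
As a standalone illustration of the logic this new parameter controls (the helper name and signature below are assumptions for the sketch, not the actual trl API):

def truncate_prompt(prompt_input_ids: list[int], max_prompt_length: int, truncation_mode: str = "keep_end") -> list[int]:
    # Hypothetical helper mirroring the truncation_mode logic; not the actual trl function.
    if truncation_mode == "keep_end":
        return prompt_input_ids[-max_prompt_length:]  # drop tokens from the start of the prompt
    if truncation_mode == "keep_start":
        return prompt_input_ids[:max_prompt_length]  # drop tokens from the end of the prompt
    raise ValueError(f"Unknown truncation mode: {truncation_mode}")

print(truncate_prompt(list(range(10)), 4, "keep_end"))    # [6, 7, 8, 9]
print(truncate_prompt(list(range(10)), 4, "keep_start"))  # [0, 1, 2, 3]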

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

@qgallouedec

@shirinyamani commented Jan 9, 2025

Hi, I have read the issue thread and this PR. I agree that we can use truncation_mode in the tokenize_row function, and I reviewed your addition. I also wanted to share my thoughts on it. In the version below, the additions are: if the sequence is still too long after truncating the prompt, we further truncate the response; and what if it's not an encoder-decoder architecture?

def tokenize_row(self, feature):
    """Tokenize a single row from the DPO specific dataset.
    NOTE that this does not convert to tensors yet; rather just handles the truncation
    in case the prompt + chosen or prompt + rejected responses is/are too long. It first
    truncates the prompt; if we're still too long, we truncate the chosen/rejected.
    """
    batch = {}
    prompt = feature["prompt"]
    chosen = feature["text_chosen"]  
    rejected = feature["text_rejected"]  

    # the current version says we typically add special tokens for encoder-decoder
    # models, but what if it's not an encoder-decoder?
    if not self.is_encoder_decoder:
        # check the input type first
        if not isinstance(prompt, str):
            raise ValueError(f"prompt should be in str form but got {type(prompt)}")
        prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
        prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

        if not isinstance(chosen, str):
            raise ValueError(f"chosen should be in str form but got {type(chosen)}")
        chosen_tokens = self.build_tokenized_answer(prompt, chosen)

        if not isinstance(rejected, str):
            raise ValueError(f"rejected should be in str form but got {type(rejected)}")
        rejected_tokens = self.build_tokenized_answer(prompt, rejected)

        # add BOS token to head of prompt
        prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
        chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
        rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]

        # add EOS token to end of answer
        chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
        rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)

        longer_response = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

        # if prompt+completion is too long, truncate the prompt
        if len(prompt_tokens["prompt_input_ids"]) + longer_response > self.args.max_length:
            if self.truncation_mode == "keep_start":
                prompt_tokens["prompt_input_ids"] = prompt_tokens["prompt_input_ids"][: self.max_prompt_length]
            elif self.truncation_mode == "keep_end":
                prompt_tokens["prompt_input_ids"] = prompt_tokens["prompt_input_ids"][-self.max_prompt_length :]
            else:
                raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

        # if that's still too long, truncate the response so that the
        # truncated prompt plus the response fits within max_length
        for answer_tokens in [chosen_tokens, rejected_tokens]:
            if len(prompt_tokens["prompt_input_ids"]) + len(answer_tokens["input_ids"]) > self.args.max_length:
                answer_tokens["input_ids"] = answer_tokens["input_ids"][: self.args.max_length - self.max_prompt_length]

    # (assembly of the tokenized prompt/chosen/rejected fields into `batch` is omitted here)
    return batch

Lemme know what you think, please. @qgallouedec @anakin87

@qgallouedec (Member)

Where does this version of tokenize_row come from, @shirinyamani? It seems quite different from the current version in main.

if after truncating the prompt it's still too long

This is anyway handled here:

trl/trl/trainer/dpo_trainer.py

Lines 1155 to 1158 in edabe0a

if self.args.max_length is not None:
    input_ids = input_ids[:, : self.args.max_length]
    attention_mask = attention_mask[:, : self.args.max_length]
    loss_mask = loss_mask[:, : self.args.max_length]

@shirinyamani
This version is what I came up with based on my research. And yes, it's handled where you mentioned, but that logic lives in two different functions: concatenated_forward and tokenize_row. I wanted to have everything related to truncating the prompt/response all in one function which would be tokenize_row for simplicity and clarity purposes, but I agree it's going to be a different version than what we have in the repo! @qgallouedec


@anakin87 (Contributor, Author)

@qgallouedec feel free to review the proposed fix

@qgallouedec (Member)

all in one function which would be tokenize_row for simplicity and clarity purposes

That makes sense. Can you open another pull request for this? Wait for this one to be merged first, though.

@qgallouedec (Member)

Sorry @anakin87, I forgot to press the submit review button a couple of days ago.

Also, @shirinyamani came up with an idea that could make more sense: truncate the [prompt+completion] (either left or right) instead of just the prompt. Something like:

# Truncate
if self.args.max_length is not None:
    if self.args.truncation_mode == "keep_start":
        input_ids = input_ids[:, : self.args.max_length]
        attention_mask = attention_mask[:, : self.args.max_length]
        loss_mask = loss_mask[:, : self.args.max_length]
    elif self.args.truncation_mode == "keep_end":
        input_ids = input_ids[:, -self.args.max_length:]
        attention_mask = attention_mask[:, -self.args.max_length:]
        loss_mask = loss_mask[:, -self.args.max_length:]
    else:
        raise ValueError(f"Unknown truncation mode: {self.args.truncation_mode}")

(Currently:

# Truncate right
if self.args.max_length is not None:
    input_ids = input_ids[:, : self.args.max_length]
    attention_mask = attention_mask[:, : self.args.max_length]
    loss_mask = loss_mask[:, : self.args.max_length]

)
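
For a concrete sense of the difference, here is a toy, self-contained check (plain PyTorch on made-up shapes and masks; not trl code) of what the two modes would do to a batched [prompt+completion]:

import torch

max_length = 4
input_ids = torch.arange(12).reshape(2, 6)      # toy batch: 2 sequences of 6 tokens
attention_mask = torch.ones_like(input_ids)
loss_mask = torch.tensor([[0, 0, 1, 1, 1, 1],
                          [0, 0, 0, 1, 1, 1]])  # 1s mark completion tokens

# "keep_end" keeps the last max_length columns, preserving the completion;
# "keep_start" would keep the first max_length columns instead.
input_ids = input_ids[:, -max_length:]
attention_mask = attention_mask[:, -max_length:]
loss_mask = loss_mask[:, -max_length:]

print(input_ids)  # columns 2..5 survive for each row
print(loss_mask)  # completion tokens are mostly preserved under keep_end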

What do you think?
