
Option to clip logprobs rlhf.get_batch_log_probs #2470

Open
krammnic opened this issue Mar 9, 2025 · 4 comments

Comments

krammnic (Contributor) commented Mar 9, 2025

RLHF procedures with modern DPO-style objectives can lead to degenerate solutions, e.g., the EOS token being dropped during generation. For instance, consider the output of a Qwen2.5 model after SimPO training with torchtune:


Einstein's theory of relativity describes how gravity arises from the curvature of spacetime caused by mass and energy. Energy energy energy energy energy energy energy energy energy energy energy... (repeated 50 times)

This is a fairly common problem, and the root cause is the behavior of the logarithm near 0 when computing log-probs, which are used to compute the rewards whose difference is optimized in DPO (https://arxiv.org/abs/2405.14734).


In DPO we sum the per-token log-probs; in most other methods we average them. In both cases we are not protected from such outliers. To put it more simply: if some tokens in the rejected sequences make it easy to tell which response is chosen and which is rejected, the model will learn to push their log-probs toward $-\infty$ (i.e., their probability goes to zero), while in trickier sequences the log-probs may hardly be optimized at all, which is what drives $P(\text{EOS}) \rightarrow 0$. (This is an empirical observation for DPO, SimPO, ORPO-like, and similar objectives.)
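
To make the outlier effect concrete, here is a tiny standalone PyTorch snippet (not torchtune code, and the numbers are made up) showing how a single collapsed token dominates both the summed and the averaged sequence log-prob:

```python
import torch

# Per-token log-probs of a 6-token rejected sequence; one token's probability
# has collapsed toward zero, so its log-prob is a large-magnitude outlier.
token_logps = torch.tensor([-1.2, -0.8, -1.5, -0.9, -1.1, -40.0])

print(token_logps.sum().item())   # ~ -45.5: the sum is dominated by the outlier (DPO-style)
print(token_logps.mean().item())  # ~ -7.58: the average is also dragged far down (SimPO-style)
```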

Sometimes a smaller learning rate or a larger $\beta$ (in the case of DPO) solves the problem, but it cannot be cured in all cases. The proposed fix is simple: add a clip_log_probs: True option to our DPO configs. When it is True, log-probs are clipped; otherwise they are left untouched.
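
A minimal sketch of how the flag could be applied inside a get_batch_log_probs-style helper; the signature, the min_log_prob threshold, and its default are illustrative assumptions, not torchtune's actual API:

```python
import torch
import torch.nn.functional as F

def get_batch_log_probs(
    logits: torch.Tensor,          # (batch, seq_len, vocab)
    labels: torch.Tensor,          # (batch, seq_len), already shifted/masked
    clip_log_probs: bool = False,  # proposed config flag
    min_log_prob: float = -50.0,   # hypothetical clipping threshold
) -> torch.Tensor:
    """Return per-sequence summed log-probs, optionally clipping each token's log-prob."""
    per_token_logps = torch.gather(
        F.log_softmax(logits.float(), dim=-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    if clip_log_probs:
        # Prevent any single token from dragging the sequence log-prob toward -inf.
        per_token_logps = per_token_logps.clamp(min=min_log_prob)
    return per_token_logps.sum(-1)
```

With clip_log_probs=False nothing changes, so existing configs would keep their current behavior.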

krammnic (Contributor, Author) commented Mar 9, 2025

cc: @SalmanMohammadi @ebsmothers

SalmanMohammadi (Collaborator) commented

Hi @krammnic. Thanks for raising this interesting issue.

Could you point to any empirical evidence for this issue for DPO/PPO?

krammnic (Contributor, Author) commented Mar 9, 2025

@SalmanMohammadi Sure, I can collect some just by logging the log-probs. The SimPO authors introduced an SFT loss component to mitigate several issues, including this one, but it is not a full solution: https://github.com/princeton-nlp/SimPO

The underlying reasons are fairly intuitive, though (that's why we sometimes try increasing $\beta$ or adding a KL-divergence term).

krammnic (Contributor, Author) commented Mar 9, 2025

A small correction: we want to do three things, each controlled by the user (see the sketch below):

  1. Winsorize extreme values for the chosen log-probs
  2. Winsorize extreme values for the rejected log-probs
  3. Clip the minimum log-prob
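
A rough sketch of these three knobs, assuming we already have per-token log-prob tensors for the chosen and rejected sequences (all function names, flags, and thresholds here are hypothetical):

```python
from typing import Optional

import torch

def winsorize(per_token_logps: torch.Tensor, lower_q: float = 0.01) -> torch.Tensor:
    """Replace extremely low per-token log-probs with the lower_q quantile value."""
    lower = torch.quantile(per_token_logps, lower_q)
    return per_token_logps.clamp(min=lower.item())

def process_logps(
    chosen_logps: torch.Tensor,
    rejected_logps: torch.Tensor,
    winsorize_chosen: bool = False,        # 1. winsorize extremals for the chosen
    winsorize_rejected: bool = False,      # 2. winsorize extremals for the rejected
    min_log_prob: Optional[float] = None,  # 3. clip minimum logprob
):
    if winsorize_chosen:
        chosen_logps = winsorize(chosen_logps)
    if winsorize_rejected:
        rejected_logps = winsorize(rejected_logps)
    if min_log_prob is not None:
        chosen_logps = chosen_logps.clamp(min=min_log_prob)
        rejected_logps = rejected_logps.clamp(min=min_log_prob)
    return chosen_logps, rejected_logps
```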
