Option to clip logprobs rlhf.get_batch_log_probs
#2470
Comments
Hi @krammnic. Thanks for raising this interesting issue. Could you point to any empirical evidence for this issue for DPO/PPO?
@SalmanMohammadi Sure, I can collect some just by logging the logprobs. The SimPO authors introduced an SFT loss component to try to fix several issues, including this one, but it is not a full solution: https://github.com/princeton-nlp/SimPO. But the reasons are definitely pretty intuitive (that's why we sometimes try
A correction: we want to do three things (controlled by the user):
RLHF procedures with modern DPO-style functionals can lead to degenerate solutions, e.g., the EOS token being dropped during generation. For instance, consider the output of a Qwen2.5 model after a SimPO run with torchtune.
This is a pretty common problem, and the root cause is the behavior of the logarithm near 0 when computing the log-probs that are used to calculate the rewards (whose difference is what DPO optimizes) (https://arxiv.org/abs/2405.14734).
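For reference (standard background, not quoted from this issue), the DPO objective of Rafailov et al. makes the failure mode explicit, since the sequence log-prob is a sum of per-token log-probs:

$$
\mathcal{L}_{\text{DPO}} = -\log \sigma\!\left(\beta \left[\log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right]\right),
\qquad
\log \pi_\theta(y \mid x) = \sum_t \log \pi_\theta(y_t \mid x, y_{<t}).
$$

The rejected log-prob enters the margin linearly, so pushing the probability of a few tokens of $y_l$ toward zero (their log-probs toward $-\infty$) is a cheap way for the optimizer to grow the margin.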
In the case of DPO we sum the per-token log-probs; in the other methods we usually take their average. In both cases we are not protected from such outliers. To put it in simpler terms: if some rejected sequences contain tokens that make the discrimination easy (telling which response is chosen and which is rejected), the model will learn to push the log-probs of those tokens down toward $-\infty$ (i.e., their probability goes to zero), while on the harder sequences the objective may hardly be optimized at all, which is what eventually drives $P(\mathrm{EOS}) \rightarrow 0$. (This is an empirical observation for DPO, SimPO, ORPO-like, and similar functionals.)
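A tiny numerical illustration (hypothetical numbers, not taken from any run in this issue) of why neither the sum nor the average of per-token log-probs is protected from a single collapsing token:

```python
import torch

# Per-token log-probs of a rejected sequence; one "easy" discriminative token
# has already been pushed far toward -inf by the optimizer.
per_token_logprobs = torch.tensor([-1.2, -0.8, -2.1, -0.5, -30.0, -1.0, -0.9, -1.4])

print(per_token_logprobs.sum())   # ≈ -37.9: DPO-style sum, dominated by the outlier
print(per_token_logprobs.mean())  # ≈ -4.74: SimPO-style average, also dragged down
```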
Sometimes a smaller learning rate or a bigger $\beta$ (in the case of DPO) solves this problem, but it cannot be cured in all cases. The solution is simple: we need to add an option `clip_log_probs: True` to our DPO configs. If it is True, the log-probs are clipped; otherwise they are left as they are.
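To make the proposal concrete, here is a minimal sketch of what the option could look like in a `get_batch_log_probs`-style helper. This is not the actual torchtune implementation or signature; the `clip_log_probs` flag and the `min_log_prob` floor are hypothetical names and values, and label shifting for next-token prediction is omitted for brevity.

```python
import torch
import torch.nn.functional as F

def get_batch_log_probs_sketch(
    logits: torch.Tensor,        # [batch, seq_len, vocab]
    labels: torch.Tensor,        # [batch, seq_len]
    ignore_index: int = -100,
    clip_log_probs: bool = False,
    min_log_prob: float = -50.0,  # hypothetical floor; would likely be configurable
) -> torch.Tensor:
    """Sum per-token log-probs of the labels, optionally clipping each token's
    log-prob from below so a single collapsing token cannot dominate the sum."""
    mask = labels != ignore_index
    safe_labels = labels.masked_fill(~mask, 0)

    log_probs = F.log_softmax(logits, dim=-1)
    per_token = torch.gather(log_probs, dim=-1, index=safe_labels.unsqueeze(-1)).squeeze(-1)

    if clip_log_probs:
        # Floor the per-token log-probs so log(p) cannot run off toward -inf.
        per_token = torch.clamp(per_token, min=min_log_prob)

    return (per_token * mask).sum(dim=-1)
```

In a recipe config this would then be exposed as the `clip_log_probs: True` switch suggested above.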