[Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss #551
base: main
Conversation
Please add a unit test with …
Update: ran into a numerical instability issue, investigating.
Thanks for the effort! Let's test thoroughly on both accuracy (numerical stability) and speed (including the existing losses) before checking in. Considering we may fuse more and more losses, such as the existing z-loss and the newly added entropy loss, the API outputs have started to diverge, and the loss function is getting quite heavy, with multiple branches coupled together (label smoothing, target weights, etc.). We probably need to refactor a bit to make it cleaner to develop on later. cc @ByronHsu @shivam15s @Tcc0403 @shimizust
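One possible shape for that refactor, purely as a sketch (the class and field names here are hypothetical, not part of the current API): bundle the growing set of outputs in a single container so that adding a loss such as entropy_loss does not keep widening the positional return signature.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class LossOutputs:
    # Hypothetical container for the losses the fused path can emit.
    # New optional losses can be added without changing existing call sites.
    ce_loss: torch.Tensor
    z_loss: Optional[torch.Tensor] = None
    entropy_loss: Optional[torch.Tensor] = None
```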
@@ -140,6 +149,26 @@ def liger_cross_entropy_kernel(
    # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
    lse = m + tl.log(d)

    # 3.5 Calculate the entropy loss
Can probably put an equation in the PR description, and also a simple one in a comment here, to demonstrate how the entropy_loss is calculated (especially the reuse of m and d computed in the first-pass online softmax).
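For reference, one way to state the requested equation, written in terms of the running max m and running denominator d that the first online-softmax pass already produces (so lse = m + log d); this is the standard identity rather than a description of the exact kernel code:

```math
p_i = \frac{e^{X_i - m}}{d}, \qquad \log p_i = X_i - \mathrm{lse}, \qquad \mathrm{lse} = m + \log d
```

```math
H(p) = -\sum_i p_i \log p_i = -\sum_i p_i \,(X_i - \mathrm{lse}) = \mathrm{lse} - \sum_i p_i X_i
```

So the entropy only needs one extra accumulator for the sum of p_i X_i; m, d, and lse are reused from the cross-entropy computation.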
@@ -248,11 +299,16 @@ def liger_cross_entropy_kernel(
    loss = loss / n_non_ignore
    # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
    z_loss = z_loss / n_non_ignore
    # TODO: Implement weighted entropy loss. Currently, entropy loss is not scaled by weight.
    entropy_loss = entropy_loss / n_non_ignore
It seems you had already implemented the weight-provided case above with dX_entropy_block = dX_entropy_block / sum_non_ignore_weight. Did I misunderstand anything? If this is not the right equation for the weighted case, please use dX_entropy_block = dX_entropy_block / n_non_ignore above instead, and add a TODO comment for it.
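To make the two reduction conventions under discussion concrete, here is a minimal PyTorch sketch (the function and variable names are illustrative, not the kernel's): with a class-weight tensor, the per-token losses are divided by the sum of the non-ignored weights, matching torch.nn.CrossEntropyLoss(weight=..., reduction="mean"); without weights they are divided by the plain count of non-ignored tokens.

```python
import torch


def reduce_per_token_loss(per_token_loss, targets, ignore_index=-100, weight=None):
    # Sketch of the two mean-reduction conventions discussed above.
    mask = targets != ignore_index
    if weight is None:
        # Unweighted: divide by the number of non-ignored tokens (n_non_ignore).
        return per_token_loss[mask].sum() / mask.sum()
    # Weighted: divide by the sum of weights of the non-ignored targets
    # (sum_non_ignore_weight), as torch's weighted CE with mean reduction does.
    w = weight[targets[mask]]
    return (w * per_token_loss[mask]).sum() / w.sum()
```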
Thanks for catching this. I think this is a bug in my program. I just fixed it. But it seems the numerical problem is still there. Maybe we need to take a deeper look.
BTW, it seems CI has stopped running for this PR for some reason.
Summary
In RLHF workflows such as verl, the actor forward function usually produces two losses: the cross_entropy_loss (-log_probs) and the entropy_loss, the latter used to encourage the policy not to become over-deterministic. There is a real need for a kernel that produces both losses without materializing the huge logits tensor. Liger-Kernel's fused_linear_cross_entropy_loss already works well for the cross_entropy_loss, but it does not compute the second part of the loss, i.e., the entropy loss. This PR adds the entropy loss option to the existing FLCE loss, as one important step toward supporting verl.
- In cross_entropy.py::liger_cross_entropy_kernel, both the loss and its gradient with respect to the input are calculated and stored;
- In fused_linear_cross_entropy.py, …
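For context, a naive unfused PyTorch reference of the two losses described above; it materializes the full logits tensor, which is exactly what the fused kernel avoids, and the function name is only illustrative:

```python
import torch
import torch.nn.functional as F


def reference_ce_and_entropy(hidden, lm_head_weight, targets, ignore_index=-100):
    # Unfused reference: materializes the (N, V) logits tensor.
    logits = hidden @ lm_head_weight.T
    log_probs = F.log_softmax(logits.float(), dim=-1)

    # Cross-entropy part: mean of -log p(target) over non-ignored tokens.
    ce_loss = F.nll_loss(log_probs, targets, ignore_index=ignore_index)

    # Entropy part: H = -sum_i p_i * log p_i per token, averaged the same way.
    mask = targets != ignore_index
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
    entropy_loss = entropy[mask].sum() / mask.sum()
    return ce_loss, entropy_loss
```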
Testing Done
Made the existing unit tests pass; adding a new unit test is WIP.
- run make test to ensure correctness
- run make checkstyle to ensure code style
- run make test-convergence to ensure convergence
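One possible shape for the WIP unit test, as a sketch that does not touch the Liger API: it only checks that the lse-based entropy formulation (the one that reuses m and d) matches the direct definition, which is the identity the kernel relies on.

```python
import torch


def test_entropy_lse_formulation_matches_direct():
    torch.manual_seed(0)
    logits = 5 * torch.randn(8, 32000, dtype=torch.float32)

    # Direct definition: H = -sum_i p_i * log p_i.
    log_p = torch.log_softmax(logits, dim=-1)
    direct = -(log_p.exp() * log_p).sum(dim=-1)

    # lse-based form: H = lse - sum_i p_i * X_i, with m and d as in the kernel's
    # online softmax (m = running max, d = sum of exp(X_i - m)).
    m = logits.max(dim=-1, keepdim=True).values
    d = (logits - m).exp().sum(dim=-1)
    lse = m.squeeze(-1) + d.log()
    p = (logits - m).exp() / d.unsqueeze(-1)
    lse_based = lse - (p * logits).sum(dim=-1)

    torch.testing.assert_close(direct, lse_based, rtol=1e-5, atol=1e-5)
```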