
[Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss #551

Open
wants to merge 34 commits into base: main

Conversation

hongpeng-guo
Collaborator

@hongpeng-guo hongpeng-guo commented Jan 30, 2025

Summary

In RLHF workflows such as verl, the actor forward pass usually produces both a cross_entropy_loss (-log_probs) and an entropy_loss; the latter is used to discourage the policy from becoming overly deterministic.

There is a real need for a kernel that generates both losses without materializing the huge logits tensor. Liger-Kernel's fused_linear_cross_entropy_loss already generates the cross_entropy_loss well, but it is missing the second part of the loss, i.e., the entropy loss.

This PR adds an entropy loss option to the existing FLCE loss and is an important step toward supporting verl.
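
For context, here is a minimal unfused PyTorch sketch of the two quantities (the function and argument names are illustrative, not the kernel's API); the fused kernel computes the same values without ever materializing `logits`:

```python
import torch
import torch.nn.functional as F

def reference_ce_and_entropy(hidden, lm_head_weight, target):
    # Unfused reference: materializes the full (BT, V) logits tensor,
    # which is exactly what fused_linear_cross_entropy_loss avoids.
    logits = hidden @ lm_head_weight.T
    log_probs = F.log_softmax(logits.float(), dim=-1)
    ce_loss = F.nll_loss(log_probs, target)               # -log_prob of the target token
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1)  # H(p) per token
    return ce_loss, entropy.mean()
```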

  1. Add the entropy calculation to the second pass of the online softmax in cross_entropy.py::liger_cross_entropy_kernel; both the loss and its gradient with respect to the input are calculated and stored (see the sketch after this list);
  2. Propagate the changes to the relevant modules in fused_linear_cross_entropy.py;
  3. Propagate the relevant changes to the other functional modules in the PyTorch interface.
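
For intuition, a minimal Python sketch of step 1 (plain scalar loops stand in for Triton blocks; the function name is illustrative and not part of the kernel):

```python
import math

def entropy_from_online_softmax(X):
    # Pass 1: online softmax statistics, exactly what the kernel already tracks:
    # m is the running max, d the running sum of exp(x - m).
    m, d = float("-inf"), 0.0
    for x in X:
        m_new = max(m, x)
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    lse = m + math.log(d)  # log-sum-exp, already computed for the CE loss

    # Pass 2: reuse m and d through lse. With p_i = exp(x_i - lse),
    # H = -sum_i p_i * log(p_i) = lse - sum_i p_i * x_i.
    return lse - sum(math.exp(x - lse) * x for x in X)
```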

Testing Done

Made the existing unit tests pass; new unit tests are WIP.

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
@hongpeng-guo hongpeng-guo marked this pull request as draft January 30, 2025 04:38
@hongpeng-guo hongpeng-guo changed the title [Feature] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss [WIP][Feature][Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss Jan 30, 2025
@hongpeng-guo hongpeng-guo changed the title [WIP][Feature][Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss [WIP][Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss Jan 30, 2025
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
@hongpeng-guo hongpeng-guo requested a review from ByronHsu January 30, 2025 09:56
@Tcc0403
Collaborator

Tcc0403 commented Jan 30, 2025

Please add a unit test for return_entropy_loss. You can write a new PyTorch reference implementation like CrossEntropyWithZLoss, or add the return_entropy_loss functionality on top of it.
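
One possible shape for such a reference (a sketch only; the class name, constructor arguments, and reduction behavior below are assumptions, not the repo's actual test API):

```python
import torch
import torch.nn as nn

class TorchCrossEntropyWithEntropy(nn.Module):
    """Hypothetical PyTorch reference for testing return_entropy_loss."""

    def __init__(self, ignore_index=-100):
        super().__init__()
        self.ignore_index = ignore_index
        self.ce = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, logits, target):
        ce_loss = self.ce(logits, target)
        log_probs = torch.log_softmax(logits.float(), dim=-1)
        entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
        # Average entropy over non-ignored positions only.
        entropy_loss = entropy[target != self.ignore_index].mean()
        return ce_loss, entropy_loss
```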

Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
@hongpeng-guo hongpeng-guo changed the base branch from main to hpguo/ruff_style_check February 3, 2025 04:41
@hongpeng-guo hongpeng-guo marked this pull request as ready for review February 3, 2025 05:17
@hongpeng-guo hongpeng-guo changed the title [WIP][Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss [Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss Feb 3, 2025
@hongpeng-guo
Collaborator Author

Update: hit a numerical instability issue; investigating.

Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Collaborator

@qingquansong qingquansong left a comment


Thanks for the efforts! Let's test thoroughly for both accuracy (numerical stability) and speed (including the existing paths) before checking in. Since we may fuse more and more losses, such as the existing z-loss and the newly added entropy loss, the API outputs are starting to diverge and the loss kernel is getting heavy, with multiple branches coupled together (label smoothing, target weights, etc.). We probably need to refactor a bit to make it cleaner to develop on later. cc @ByronHsu @shivam15s @Tcc0403 @shimizust

src/liger_kernel/ops/cross_entropy.py (outdated review thread, resolved)
@@ -140,6 +149,26 @@ def liger_cross_entropy_kernel(
# = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
lse = m + tl.log(d)

# 3.5 Calculate the entropy loss
Collaborator


We can probably put an equation in the PR description, and also a simple one in a comment here, to demonstrate how the entropy_loss is calculated (especially the reuse of m and d computed in the first-pass online softmax).
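
For example, one way to write it (notation assumed here, reusing the kernel's m and d):

```math
\mathrm{lse} = m + \log d, \qquad p_i = e^{x_i - \mathrm{lse}}, \qquad
H = -\sum_i p_i \log p_i = \mathrm{lse} - \sum_i p_i\, x_i
```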

src/liger_kernel/ops/cross_entropy.py (two more outdated review threads, resolved)
@@ -248,11 +299,16 @@ def liger_cross_entropy_kernel(
loss = loss / n_non_ignore
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
z_loss = z_loss / n_non_ignore
# TODO: Implement weighted entropy loss. Currently, entropy loss is not scaled by weight.
entropy_loss = entropy_loss / n_non_ignore
Collaborator


It seems you already implemented the weighted case above with dX_entropy_block = dX_entropy_block / sum_non_ignore_weight. Did I misunderstand anything? If that is not the right equation for the weighted case, please use dX_entropy_block = dX_entropy_block / n_non_ignore above and add a TODO comment there as well.

Collaborator Author


Thanks for catching this. I think this is a bug in my program. I just fixed it. But it seems the numerical problem is still there. Maybe we need to take a deeper look.

Collaborator Author


BTW, it seems CI has stopped running for this PR for some reason.

hongpeng-guo and others added 3 commits February 3, 2025 09:48
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
@hongpeng-guo hongpeng-guo changed the base branch from hpguo/ruff_style_check to main February 3, 2025 09:54
@hongpeng-guo hongpeng-guo changed the base branch from main to hpguo/ruff_style_check February 3, 2025 09:55
@hongpeng-guo hongpeng-guo changed the base branch from hpguo/ruff_style_check to main February 5, 2025 22:15
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>