[Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss #551

Status: Open. Wants to merge 34 commits into base: main.

Commits (34):
05f0edb
run make checkstyle
hongpeng-guo Jan 30, 2025
6a26dbb
wip initial try test existing unitest
hongpeng-guo Jan 30, 2025
7dad560
ruff style check
hongpeng-guo Jan 30, 2025
1b13b2f
fix for cross_entropy
hongpeng-guo Jan 30, 2025
8a43d1e
fix checkstyle
hongpeng-guo Jan 30, 2025
82d9b55
wip fix flce
hongpeng-guo Jan 30, 2025
984e85f
fix bugs
hongpeng-guo Jan 30, 2025
eb90401
fix bugs
hongpeng-guo Jan 30, 2025
7684eed
fix
hongpeng-guo Jan 30, 2025
bed2d45
fix a unit test
hongpeng-guo Jan 30, 2025
a967e65
fix ce kernel, add unit test make it work
hongpeng-guo Feb 3, 2025
068b9be
fix style
hongpeng-guo Feb 3, 2025
32ac203
add unit test to flce
hongpeng-guo Feb 3, 2025
201f47e
revert the chanegs on unit tests
hongpeng-guo Feb 3, 2025
38c5d44
improve ce unit test
hongpeng-guo Feb 3, 2025
96c3192
improve ce unit test
hongpeng-guo Feb 3, 2025
af84880
handle comments partial
hongpeng-guo Feb 3, 2025
4307e37
Update src/liger_kernel/ops/cross_entropy.py
hongpeng-guo Feb 3, 2025
8d65866
fix typo
hongpeng-guo Feb 3, 2025
d43d0ee
Merge branch 'main' into hpguo/ruff_style_check
hongpeng-guo Feb 3, 2025
5f6253b
Merge branch 'hpguo/ruff_style_check' into hpguo/lce_add_entropy_loss
hongpeng-guo Feb 3, 2025
4c97042
fix bug in softcap and ce weight confusion
hongpeng-guo Feb 6, 2025
74d0f0e
fix bug in softcap and ce weight confusion
hongpeng-guo Feb 6, 2025
8005999
bisec unittes to test on ci
hongpeng-guo Feb 6, 2025
e341aea
refactor code
hongpeng-guo Feb 6, 2025
ced5709
revert changes to unit tests
hongpeng-guo Feb 6, 2025
c1d36e6
change a new way calculating entropy
hongpeng-guo Feb 6, 2025
b1053a3
make deriv stable
hongpeng-guo Feb 6, 2025
7af2fe3
bisect unitets
hongpeng-guo Feb 6, 2025
6162e88
fix wip
hongpeng-guo Feb 6, 2025
02fd778
try to make it numerical stable
hongpeng-guo Feb 6, 2025
7f53b59
wip another
hongpeng-guo Feb 6, 2025
62d2ca3
revert a unittest
hongpeng-guo Feb 6, 2025
0d6487c
update unittest
hongpeng-guo Feb 6, 2025
Files changed:
2 changes: 1 addition & 1 deletion examples/lightning/training.py
@@ -158,7 +158,7 @@ def formatting_func(self, example):
for i in range(len(example["question"])):
choices = ""
for j in range(len(example["choices"][i])):
choices += f"{j+1}. {example['choices'][i][j]}; "
choices += f"{j + 1}. {example['choices'][i][j]}; "
s = "Below is a question and multiple choice answers, choices separated by a semicolon. Please select the best answer for the question. "
s += f"{QUESTION}{example['question'][i]} "
s += f"{CHOICES}{choices} "
6 changes: 3 additions & 3 deletions examples/medusa/callback.py
@@ -352,9 +352,9 @@ def _get_effective_num_gpus():
else:
return world_size

assert (
world_size != 0
), "WORLD_SIZE should be set to a positive integer. For single GPU training, please explicitly set WORLD_SIZE=1."
assert world_size != 0, (
"WORLD_SIZE should be set to a positive integer. For single GPU training, please explicitly set WORLD_SIZE=1."
)

# TODO: add deepspeed support
return world_size
132 changes: 114 additions & 18 deletions src/liger_kernel/ops/cross_entropy.py
@@ -24,12 +24,14 @@
@triton.jit
def liger_cross_entropy_kernel(
X_ptr,
dX_entropy_ptr,
X_stride,
Y_ptr,
Y_stride,
weight_ptr,
loss_ptr,
z_loss_ptr,
entropy_loss_ptr,
loss_stride,
n_cols,
n_non_ignore,
@@ -41,6 +43,7 @@ def liger_cross_entropy_kernel(
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
softcap,
RETURN_Z_LOSS: tl.constexpr,
RETURN_ENTROPY_LOSS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_SOFTCAPPING: tl.constexpr,
@@ -51,6 +54,7 @@ def liger_cross_entropy_kernel(

Parameters:
X_ptr: Pointer to input tensor.
dX_entropy_ptr: Pointer to tensor to store the gradient of the input w.r.t the entropy loss
X_stride (int): The stride of the input tensor.
Y_ptr: Pointer to target tensor.
Y_stride (int): The stride of the target tensor.
@@ -68,6 +72,7 @@ def liger_cross_entropy_kernel(
reduction (str): The string for the reduction to apply
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
RETURN_ENTROPY_LOSS (int): The boolean value to decide whether storing entropy loss to entropy_loss_ptr or not. It must be 0 or 1.
BLOCK_SIZE (int): The block size for Triton operations.
HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
@@ -94,6 +99,9 @@ def liger_cross_entropy_kernel(
loss_ptr += program_id * loss_stride
if RETURN_Z_LOSS:
z_loss_ptr += program_id * loss_stride
if RETURN_ENTROPY_LOSS:
entropy_loss_ptr += program_id * loss_stride
dX_entropy_ptr += program_id * X_stride

if HAS_WEIGHT:
weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -104,6 +112,7 @@ def liger_cross_entropy_kernel(
# 3. [Online softmax] first pass: find max + sum
m = float("-inf") # m is the max value. use the notation from the paper
d = 0.0 # d is the sum. use the notation from the paper
entropy_loss = 0.0 # entropy loss
ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
if HAS_SOFTCAPPING:
ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -140,6 +149,27 @@ def liger_cross_entropy_kernel(
# = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
lse = m + tl.log(d)

# 3.5 Calculate the entropy loss
Collaborator review comment: can probably put an equation in the PR description, and also a simple one in a comment here, to demonstrate how the entropy_loss is calculated (especially the reuse of m and d computed in the first-pass online softmax).
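For reference, a minimal statement of the quantity computed below, reusing the running max m and the normalizer d from the first-pass online softmax (this is a sketch of the standard identity, not text taken from the PR):

p_i = \frac{e^{x_i - m}}{d}, \qquad \mathrm{lse} = m + \log d, \qquad H(x) = -\sum_i p_i \log p_i = \mathrm{lse} - \sum_i p_i\, x_i

so, on top of what the online softmax already tracks, the kernel only needs one extra accumulator, \sum_i p_i x_i (sum_p_x below).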

sum_p_x = 0.0 # sum of softmax(x_i) * x_i
if RETURN_ENTROPY_LOSS:
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_offsets,
mask=X_offsets < n_cols,
other=float("-inf"),
# Ensure float32 precision for softmax calculation
).cast(tl.float32)
if HAS_SOFTCAPPING:
intermediate = tanh(X_block / softcap)
X_block = softcap * intermediate

softmax_X = tl.exp(X_block - m) / d
# Cumulate the sum of softmax(x_i) * x_i
sum_p_x += tl.sum(tl.where(X_offsets < n_cols, softmax_X * X_block, 0.0))

entropy_loss = lse - sum_p_x
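A small PyTorch sanity check of the same identity, mirroring the kernel's two accumulators (illustrative only; it is not part of the PR or its test suite):

import torch

x = torch.randn(4096, dtype=torch.float64)  # one row of logits

# First pass: running max and normalizer, as in the online softmax.
m = x.max()
d = torch.exp(x - m).sum()
lse = m + torch.log(d)

# Second pass: accumulate the sum of softmax(x_i) * x_i.
softmax_x = torch.exp(x - m) / d
sum_p_x = (softmax_x * x).sum()

# lse - sum(p * x) equals the Shannon entropy -sum(p * log p).
assert torch.allclose(lse - sum_p_x, -(softmax_x * torch.log(softmax_x)).sum())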

# 4. [Online Softmax] Second pass: compute gradients
# For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
# dx_y = (softmax(x_y) - 1) / N
@@ -167,9 +197,26 @@ def liger_cross_entropy_kernel(
intermediate = tanh(X_block / softcap)
X_block = softcap * intermediate

# Calculate the softmax of the input
softmax_X = tl.exp(X_block - m) / d

# load the derivatives of the entropy loss
if RETURN_ENTROPY_LOSS:
dX_entropy_block = tl.load(
dX_entropy_ptr + X_offsets,
mask=X_offsets < n_cols,
other=0.0,
).cast(tl.float32)

# derivatives of the entropy loss term
dX_entropy_block = softmax_X * sum_p_x - softmax_X * X_block
# Note that the weight is only applied to ce loss, not for entropy loss.
if reduction == "mean":
dX_entropy_block = dX_entropy_block / n_non_ignore
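For reference, the dX_entropy_block expression above follows from differentiating H(x) = lse - \sum_i p_i x_i with the usual softmax Jacobian (a derivation sketch, not text from the PR):

\frac{\partial H}{\partial x_j} = p_j \sum_i p_i x_i - p_j x_j, \qquad p_j = \mathrm{softmax}(x)_j

which is exactly softmax_X * sum_p_x - softmax_X * X_block, scaled by 1 / n_non_ignore under mean reduction.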

if not HAS_WEIGHT:
# softmax(x_i)
X_block = tl.exp(X_block - m) / d
X_block = softmax_X
# derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
X_block += 2 * lse_square_scale * lse * X_block
# smoothing term
@@ -181,7 +228,6 @@ def liger_cross_entropy_kernel(
X_block = X_block / n_non_ignore
else:
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
softmax_X = tl.exp(X_block - m) / d
# derivative of original_loss
dloss_ori = (1 - label_smoothing) * softmax_X
# specially handle dx_y
@@ -204,8 +250,12 @@ def liger_cross_entropy_kernel(
# d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
if HAS_SOFTCAPPING:
X_block = X_block * (1 - intermediate * intermediate)
if RETURN_ENTROPY_LOSS:
dX_entropy_block = dX_entropy_block * (1 - intermediate * intermediate)

tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
if RETURN_ENTROPY_LOSS:
tl.store(dX_entropy_ptr + X_offsets, dX_entropy_block, mask=X_offsets < n_cols)

# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -248,11 +298,16 @@ def liger_cross_entropy_kernel(
loss = loss / n_non_ignore
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
z_loss = z_loss / n_non_ignore
# Note that the weight is only applied to ce loss, not for entropy loss.
entropy_loss = entropy_loss / n_non_ignore
Collaborator review comment: It seems you had already implemented the weight-provided case above as dX_entropy_block = dX_entropy_block / sum_non_ignore_weight. Did I misunderstand anything? If that is not the right equation for the weighted case, please use dX_entropy_block = dX_entropy_block / n_non_ignore above and also leave a TODO comment there.

Collaborator (author) reply: Thanks for catching this. I think this is a bug in my program; I just fixed it. But it seems the numerical problem is still there. Maybe we need to take a deeper look.

Collaborator (author) reply: BTW, it seems the CI has stopped running for this PR for some reason.
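To make the intended reduction concrete, here is a restatement under the assumption (per the comment above) that the class weight applies only to the CE term, with a weighted entropy mean left as a TODO:

\mathcal{L}_{\text{ce}}^{\text{mean}} = \frac{\sum_{t \in \text{kept}} w_{y_t}\, \ell_t}{\sum_{t \in \text{kept}} w_{y_t}}, \qquad H^{\text{mean}} = \frac{\sum_{t \in \text{kept}} H_t}{N_{\text{kept}}}

where \ell_t and H_t are the per-token cross entropy and entropy, N_{\text{kept}} is n_non_ignore, and the weighted denominator corresponds to sum_non_ignore_weight in the wrapper below.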


loss += z_loss

tl.store(loss_ptr, loss)
if RETURN_Z_LOSS:
tl.store(z_loss_ptr, z_loss)
if RETURN_ENTROPY_LOSS:
tl.store(entropy_loss_ptr, entropy_loss)


# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
Expand All @@ -271,27 +326,32 @@ def cross_entropy_forward(
reduction,
softcap,
return_z_loss,
return_entropy_loss,
):
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
assert isinstance(return_entropy_loss, bool), (
f"return_entropy_loss must be True or False. Got: {return_entropy_loss}"
)

BT, V = _input.shape
n_rows = BT

BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

# unreduced loss
dX_entropy_2d = torch.zeros(n_rows, V, dtype=_input.dtype, device=_input.device) if return_entropy_loss else None
loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None

entropy_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_entropy_loss else None
target_mask = target != ignore_index
n_non_ignore = target_mask.sum().item()
sum_non_ignore_weight = n_non_ignore
weight_sum = 0.0
if weight is not None:
assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
assert torch.is_floating_point(
weight
), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
assert torch.is_floating_point(weight), (
f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
)
sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
weight_sum = weight.sum().item()
# ensure weight is contiguous
@@ -307,12 +367,14 @@ def cross_entropy_forward(
# Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
liger_cross_entropy_kernel[(n_rows,)](
X_ptr=_input,
dX_entropy_ptr=dX_entropy_2d,
X_stride=_input.stride(-2),
Y_ptr=target,
Y_stride=target.stride(-1), # always 1
weight_ptr=weight, # dummy if None
loss_ptr=loss_1d,
z_loss_ptr=z_loss_1d,
entropy_loss_ptr=entropy_loss_1d,
loss_stride=loss_1d.stride(-1), # always 1
n_cols=V,
n_non_ignore=n_non_ignore,
@@ -324,6 +386,7 @@ def cross_entropy_forward(
reduction=reduction,
softcap=softcap,
RETURN_Z_LOSS=return_z_loss,
RETURN_ENTROPY_LOSS=return_entropy_loss,
BLOCK_SIZE=BLOCK_SIZE,
HAS_WEIGHT=True if weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
@@ -335,25 +398,27 @@ def cross_entropy_forward(
if reduction == "none":
loss = loss_1d
z_loss = z_loss_1d if return_z_loss else None
entropy_loss = entropy_loss_1d if return_entropy_loss else None
else:
loss = torch.sum(loss_1d)
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
entropy_loss = torch.sum(entropy_loss_1d) if return_entropy_loss else None

return loss, z_loss, _input
return loss, z_loss, entropy_loss, _input, dX_entropy_2d


def cross_entropy_backward(_input, grad_output):
def cross_entropy_backward(_input, dX_entropy_2d, grad_output, grad_output_entropy):
BT, V = _input.shape
n_rows = BT
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
pass

# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
else:
BT, V = _input.shape
n_rows = BT
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

element_mul_kernel[(n_rows,)](
_input,
_input.stride(-2),
@@ -362,6 +427,18 @@ def cross_entropy_backward(_input, grad_output):
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
)
# calculate the gradient of the input w.r.t the entropy loss
if dX_entropy_2d is not None:
if not torch.equal(grad_output_entropy, torch.tensor(1.0, device=grad_output_entropy.device)):
element_mul_kernel[(n_rows,)](
dX_entropy_2d,
dX_entropy_2d.stride(-2),
grad_output_entropy,
V,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
)
_input = dX_entropy_2d + _input

return _input
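Put differently, the tensor returned here is the combined chain-rule gradient (a restatement of the code above, with g_ce = grad_output and g_H = grad_output_entropy):

\frac{\partial \mathcal{L}}{\partial X} = g_{\text{ce}} \cdot \frac{\partial (\text{ce}+z)\text{-loss}}{\partial X} + g_H \cdot \frac{\partial H}{\partial X}

where each element_mul_kernel launch applies one scalar scaling (skipped when the incoming gradient is exactly 1.0), and the final addition folds the entropy gradient into the in-place CE gradient.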

@@ -384,6 +461,7 @@ def forward(
reduction: str = "mean",
softcap: Optional[float] = None,
return_z_loss: bool = False,
return_entropy_loss: bool = False,
):
"""
The forward pass of the Liger Cross Entropy loss.
@@ -403,7 +481,7 @@ def forward(
Returns:
tuple: A tuple with the computed losses with respect to loss and z loss. The elements are tensors or None.
"""
loss, z_loss, _input = cross_entropy_forward(
loss, z_loss, entropy_loss, _input, dX_entropy_2d = cross_entropy_forward(
_input,
target,
weight,
@@ -413,32 +491,49 @@ def forward(
reduction,
softcap,
return_z_loss,
return_entropy_loss,
)
# TODO: investigation
# If we don't detach the _input tensor, the memory will double
# Not sure why but seems that there will be a time both grad and value exist but in different location
ctx.save_for_backward(_input.detach())
if return_entropy_loss:
ctx.save_for_backward(_input.detach(), dX_entropy_2d.detach())
else:
ctx.save_for_backward(_input.detach())
ctx.return_z_loss = return_z_loss
ctx.return_entropy_loss = return_entropy_loss

return loss, z_loss
return loss, z_loss, entropy_loss

@staticmethod
def backward(ctx, grad_output, grad_ouput2):
def backward(ctx, grad_output, grad_ouput2, grad_ouput3):
"""
The backward pass of the Liger Cross Entropy loss.

Parameters:
ctx : The context object with saved tensors.
grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
grad_output2 (tensor): No use.
grad_output3 (tensor): The tensor containing the gradient of the loss with respect to the entropy loss.
Returns:
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
"""
if ctx.return_z_loss:
del grad_ouput2 # z_loss is only for logging

(_input,) = ctx.saved_tensors
_input = cross_entropy_backward(_input, grad_output)
if ctx.return_entropy_loss:
(_input, dX_entropy_2d) = ctx.saved_tensors
else:
(_input,), dX_entropy_2d = ctx.saved_tensors, None

print(grad_output, grad_ouput3)

_input = cross_entropy_backward(_input, dX_entropy_2d, grad_output, grad_ouput3)

# delete the tensors that are not used in remaining steps
del grad_ouput3
del dX_entropy_2d

return (
_input,
None,
@@ -449,4 +544,5 @@ def backward(ctx, grad_output, grad_ouput2):
None,
None,
None,
None,
)
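To round out the picture, a minimal usage sketch of the new flag (illustrative only: the positional argument order is assumed from the wrapper and docstring above, the entropy comparison is not taken from the PR's test suite, and a GPU is required since this is a Triton kernel):

import torch

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

V = 128
logits = torch.randn(16, V, device="cuda", requires_grad=True)
target = torch.randint(0, V, (16,), device="cuda")

# The fused op stores the gradient in its input buffer, so keep a copy for the reference.
p = torch.softmax(logits.detach().clone().float(), dim=-1)
ref_entropy = -(p * p.log()).sum(dim=-1).mean()

# Assumed argument order: (_input, target, weight, ignore_index, lse_square_scale,
# label_smoothing, reduction, softcap, return_z_loss, return_entropy_loss)
loss, z_loss, entropy_loss = LigerCrossEntropyFunction.apply(
    logits, target, None, -100, 0.0, 0.0, "mean", None, False, True
)
print(entropy_loss.item(), ref_entropy.item())  # should be close

# Both returned losses participate in autograd, e.g. an RL-style objective
# that subtracts a small entropy bonus.
(loss - 0.01 * entropy_loss).backward()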