[Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss #551
@@ -24,12 +24,14 @@
 @triton.jit
 def liger_cross_entropy_kernel(
     X_ptr,
+    dX_entropy_ptr,
     X_stride,
     Y_ptr,
     Y_stride,
     weight_ptr,
     loss_ptr,
     z_loss_ptr,
+    entropy_loss_ptr,
     loss_stride,
     n_cols,
     n_non_ignore,
@@ -41,6 +43,7 @@ def liger_cross_entropy_kernel(
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
+    RETURN_ENTROPY_LOSS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
@@ -51,6 +54,7 @@ def liger_cross_entropy_kernel(

     Parameters:
     X_ptr: Pointer to input tensor.
+    dX_entropy_ptr: Pointer to tensor to store the gradient of the input w.r.t the entropy loss
     X_stride (int): The stride of the input tensor.
     Y_ptr: Pointer to target tensor.
     Y_stride (int): The stride of the target tensor.
@@ -68,6 +72,7 @@ def liger_cross_entropy_kernel(
     reduction (str): The string for the reduction to apply
     softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
     RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
+    RETURN_ENTROPY_LOSS (int): The boolean value to decide whether storing entropy loss to entropy_loss_ptr or not. It must be 0 or 1.
     BLOCK_SIZE (int): The block size for Triton operations.
     HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
     HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
@@ -94,6 +99,9 @@ def liger_cross_entropy_kernel(
     loss_ptr += program_id * loss_stride
     if RETURN_Z_LOSS:
         z_loss_ptr += program_id * loss_stride
+    if RETURN_ENTROPY_LOSS:
+        entropy_loss_ptr += program_id * loss_stride
+        dX_entropy_ptr += program_id * X_stride

     if HAS_WEIGHT:
         weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -104,6 +112,7 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
+    entropy_loss = 0.0  # entropy loss
     ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -140,6 +149,27 @@ def liger_cross_entropy_kernel(
     # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
     lse = m + tl.log(d)

+    # 3.5 Calculate the entropy loss
+    sum_p_x = 0.0  # sum of softmax(x_i) * x_i
+    if RETURN_ENTROPY_LOSS:
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            X_block = tl.load(
+                X_ptr + X_offsets,
+                mask=X_offsets < n_cols,
+                other=float("-inf"),
+                # Ensure float32 precision for softmax calculation
+            ).cast(tl.float32)
+            if HAS_SOFTCAPPING:
+                intermediate = tanh(X_block / softcap)
+                X_block = softcap * intermediate
+
+            softmax_X = tl.exp(X_block - m) / d
+            # Cumulate the sum of softmax(x_i) * x_i
+            sum_p_x += tl.sum(tl.where(X_offsets < n_cols, softmax_X * X_block, 0.0))
+
+        entropy_loss = lse - sum_p_x
+
     # 4. [Online Softmax] Second pass: compute gradients
     # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
     # dx_y = (softmax(x_y) - 1) / N
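For readers checking the math in this new first pass: it reuses the running max `m` and denominator `d` that the kernel already computed for the cross-entropy term, so the entropy comes out as `lse - sum_p_x` without ever materializing log-probabilities. Below is a minimal, unfused PyTorch sketch of the same quantity (not part of the PR; it ignores softcapping and the ignore_index path):

```python
import torch


def reference_entropy(logits: torch.Tensor) -> torch.Tensor:
    # Per-row Shannon entropy of softmax(logits), written with the same
    # intermediates the kernel reuses: m (row max) and d (sum of exp(x - m)).
    m = logits.max(dim=-1, keepdim=True).values
    d = torch.exp(logits - m).sum(dim=-1, keepdim=True)
    lse = (m + d.log()).squeeze(-1)      # matches `lse = m + tl.log(d)`
    p = torch.exp(logits - m) / d        # softmax probabilities
    sum_p_x = (p * logits).sum(dim=-1)   # the kernel's `sum_p_x` accumulator
    return lse - sum_p_x                 # H(p) = lse - sum_i p_i * x_i


# Sanity check against the textbook definition -sum_i p_i * log(p_i).
logits = torch.randn(4, 128, dtype=torch.float32)
p = torch.softmax(logits, dim=-1)
assert torch.allclose(reference_entropy(logits), -(p * p.log()).sum(dim=-1), atol=1e-5)
```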
@@ -167,9 +197,26 @@ def liger_cross_entropy_kernel(
             intermediate = tanh(X_block / softcap)
             X_block = softcap * intermediate

+        # Calculate the softmax of the input
+        softmax_X = tl.exp(X_block - m) / d
+
+        # load the derivatives of the entropy loss
+        if RETURN_ENTROPY_LOSS:
+            dX_entropy_block = tl.load(
+                dX_entropy_ptr + X_offsets,
+                mask=X_offsets < n_cols,
+                other=0.0,
+            ).cast(tl.float32)
+
+            # derivatives of the entropy loss term
+            dX_entropy_block = softmax_X * sum_p_x - softmax_X * X_block
+            # Note that the weight is only applied to ce loss, not for entropy loss.
+            if reduction == "mean":
+                dX_entropy_block = dX_entropy_block / n_non_ignore
+
         if not HAS_WEIGHT:
             # softmax(x_i)
-            X_block = tl.exp(X_block - m) / d
+            X_block = softmax_X
             # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
             X_block += 2 * lse_square_scale * lse * X_block
             # smoothing term
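A quick derivation of the new `dX_entropy_block` expression (my sketch, not from the PR; the softcapping chain rule and the `mean` normalization are applied separately, as in the diff). With $p_j = e^{x_j - m} / d$ and $H = \mathrm{lse}(x) - \sum_j p_j x_j$, using $\partial\,\mathrm{lse}/\partial x_k = p_k$ and $\partial p_j/\partial x_k = p_j(\delta_{jk} - p_k)$:

$$
\frac{\partial H}{\partial x_k}
= p_k - \Big(p_k + p_k x_k - p_k \sum_j p_j x_j\Big)
= p_k \sum_j p_j x_j - p_k x_k,
$$

which is exactly `softmax_X * sum_p_x - softmax_X * X_block`.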
@@ -181,7 +228,6 @@ def liger_cross_entropy_kernel(
                 X_block = X_block / n_non_ignore
         else:
             weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
-            softmax_X = tl.exp(X_block - m) / d
             # derivative of original_loss
             dloss_ori = (1 - label_smoothing) * softmax_X
             # specially handle dx_y
@@ -204,8 +250,12 @@ def liger_cross_entropy_kernel(
         # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
         if HAS_SOFTCAPPING:
             X_block = X_block * (1 - intermediate * intermediate)
+            if RETURN_ENTROPY_LOSS:
+                dX_entropy_block = dX_entropy_block * (1 - intermediate * intermediate)

         tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+        if RETURN_ENTROPY_LOSS:
+            tl.store(dX_entropy_ptr + X_offsets, dX_entropy_block, mask=X_offsets < n_cols)

     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
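Both stored gradients go through the same softcapping correction because, with softcapping enabled, the loss sees $c \tanh(x / c)$ rather than $x$; the shared factor is

$$
\frac{d}{dx}\Big[c \tanh\!\Big(\frac{x}{c}\Big)\Big] = 1 - \tanh^2\!\Big(\frac{x}{c}\Big),
$$

i.e. the `(1 - intermediate * intermediate)` term reused from the existing cross-entropy path.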
@@ -248,11 +298,16 @@ def liger_cross_entropy_kernel(
             loss = loss / n_non_ignore
         # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
         z_loss = z_loss / n_non_ignore
+        # Note that the weight is only applied to ce loss, not for entropy loss.
+        entropy_loss = entropy_loss / n_non_ignore

     loss += z_loss

     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS:
         tl.store(z_loss_ptr, z_loss)
+    if RETURN_ENTROPY_LOSS:
+        tl.store(entropy_loss_ptr, entropy_loss)


 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19

Review thread on `entropy_loss = entropy_loss / n_non_ignore`:

> It seems you had done the implementation of the weight provided case already above?

> Thanks for catching this. I think this is a bug in my program. I just fixed it. But it seems the numerical problem is still there. Maybe we need to take a deeper look.

> BTW, it seems the CI stops running for this PR for some reason.
@@ -271,27 +326,32 @@ def cross_entropy_forward(
     reduction,
     softcap,
     return_z_loss,
+    return_entropy_loss,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_entropy_loss, bool), (
+        f"return_entropy_loss must be True or False. Got: {return_entropy_loss}"
+    )

     BT, V = _input.shape
     n_rows = BT

     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

     # unreduced loss
+    dX_entropy_2d = torch.zeros(n_rows, V, dtype=_input.dtype, device=_input.device) if return_entropy_loss else None
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
     z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None

+    entropy_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_entropy_loss else None
     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
     sum_non_ignore_weight = n_non_ignore
     weight_sum = 0.0
     if weight is not None:
         assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
-        assert torch.is_floating_point(
-            weight
-        ), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
+        assert torch.is_floating_point(weight), (
+            f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
+        )
         sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
         weight_sum = weight.sum().item()
         # ensure weight is contiguous
@@ -307,12 +367,14 @@ def cross_entropy_forward(
     # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
     liger_cross_entropy_kernel[(n_rows,)](
         X_ptr=_input,
+        dX_entropy_ptr=dX_entropy_2d,
         X_stride=_input.stride(-2),
         Y_ptr=target,
         Y_stride=target.stride(-1),  # always 1
         weight_ptr=weight,  # dummy if None
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
+        entropy_loss_ptr=entropy_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
         n_cols=V,
         n_non_ignore=n_non_ignore,
@@ -324,6 +386,7 @@ def cross_entropy_forward(
         reduction=reduction,
         softcap=softcap,
         RETURN_Z_LOSS=return_z_loss,
+        RETURN_ENTROPY_LOSS=return_entropy_loss,
         BLOCK_SIZE=BLOCK_SIZE,
         HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
@@ -335,25 +398,27 @@ def cross_entropy_forward(
     if reduction == "none":
         loss = loss_1d
         z_loss = z_loss_1d if return_z_loss else None
+        entropy_loss = entropy_loss_1d if return_entropy_loss else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        entropy_loss = torch.sum(entropy_loss_1d) if return_entropy_loss else None

-    return loss, z_loss, _input
+    return loss, z_loss, entropy_loss, _input, dX_entropy_2d


-def cross_entropy_backward(_input, grad_output):
+def cross_entropy_backward(_input, dX_entropy_2d, grad_output, grad_output_entropy):
+    BT, V = _input.shape
+    n_rows = BT
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
         pass

     # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
     # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
     else:
-        BT, V = _input.shape
-        n_rows = BT
-        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-
         element_mul_kernel[(n_rows,)](
             _input,
             _input.stride(-2),
@@ -362,6 +427,18 @@ def cross_entropy_backward(_input, grad_output):
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )
+    # calculate the gradient of the input w.r.t the entropy loss
+    if dX_entropy_2d is not None:
+        if not torch.equal(grad_output_entropy, torch.tensor(1.0, device=grad_output_entropy.device)):
+            element_mul_kernel[(n_rows,)](
+                dX_entropy_2d,
+                dX_entropy_2d.stride(-2),
+                grad_output_entropy,
+                V,
+                BLOCK_SIZE=BLOCK_SIZE,
+                num_warps=32 if not is_hip() else 16,
+            )
+        _input = dX_entropy_2d + _input

     return _input
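In plain PyTorch terms, the updated backward just sums two scaled partial gradients. A sketch (not part of the PR; `dX_ce` stands for the cross-entropy/z-loss gradient that the forward kernel stored into `_input` in place, and the other names follow the diff):

```python
import torch


def combined_input_grad(
    dX_ce: torch.Tensor,                 # grad of (ce + z) loss w.r.t. the logits
    dX_entropy: torch.Tensor,            # grad of the entropy loss w.r.t. the logits
    grad_output: torch.Tensor,           # upstream grad for the main loss
    grad_output_entropy: torch.Tensor,   # upstream grad for the entropy loss
) -> torch.Tensor:
    # element_mul_kernel performs these scalings in place (and skips them when the
    # scalar is exactly 1.0); the final add matches `_input = dX_entropy_2d + _input`.
    return grad_output * dX_ce + grad_output_entropy * dX_entropy
```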
@@ -384,6 +461,7 @@ def forward(
         reduction: str = "mean",
         softcap: Optional[float] = None,
         return_z_loss: bool = False,
+        return_entropy_loss: bool = False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -403,7 +481,7 @@ def forward(
         Returns:
             tuple: A tuple with the computed losses with respect to loss and z loss. The elements are tensors or None.
         """
-        loss, z_loss, _input = cross_entropy_forward(
+        loss, z_loss, entropy_loss, _input, dX_entropy_2d = cross_entropy_forward(
             _input,
             target,
             weight,
@@ -413,32 +491,49 @@ def forward(
             reduction,
             softcap,
             return_z_loss,
+            return_entropy_loss,
         )
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
-        ctx.save_for_backward(_input.detach())
+        if return_entropy_loss:
+            ctx.save_for_backward(_input.detach(), dX_entropy_2d.detach())
+        else:
+            ctx.save_for_backward(_input.detach())
         ctx.return_z_loss = return_z_loss
+        ctx.return_entropy_loss = return_entropy_loss

-        return loss, z_loss
+        return loss, z_loss, entropy_loss

     @staticmethod
-    def backward(ctx, grad_output, grad_ouput2):
+    def backward(ctx, grad_output, grad_ouput2, grad_ouput3):
         """
         The backward pass of the Liger Cross Entropy loss.

         Parameters:
         ctx : The context object with saved tensors.
         grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
         grad_output2 (tensor): No use.
+        grad_output3 (tensor): The tensor containing the gradient of the loss with respect to the entropy loss.
         Returns:
         tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
         if ctx.return_z_loss:
             del grad_ouput2  # z_loss is only for logging

-        (_input,) = ctx.saved_tensors
-        _input = cross_entropy_backward(_input, grad_output)
+        if ctx.return_entropy_loss:
+            (_input, dX_entropy_2d) = ctx.saved_tensors
+        else:
+            (_input,), dX_entropy_2d = ctx.saved_tensors, None
+
+        print(grad_output, grad_ouput3)
+
+        _input = cross_entropy_backward(_input, dX_entropy_2d, grad_output, grad_ouput3)
+
+        # delete the tensors that are not used in remaining steps
+        del grad_ouput3
+        del dX_entropy_2d

         return (
             _input,
             None,

@@ -449,4 +544,5 @@ def backward(ctx, grad_output, grad_ouput2):
             None,
             None,
             None,
+            None,
         )
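Not part of the PR: one possible unfused reference that a test for `return_entropy_loss=True` could compare against, assuming mean reduction, no class weights, no label smoothing, no softcap, and the kernel's convention of dividing the summed per-row entropy by the number of non-ignored tokens:

```python
import torch
import torch.nn.functional as F


def reference_ce_and_entropy(logits: torch.Tensor, target: torch.Tensor, ignore_index: int = -100):
    mask = target != ignore_index
    ce = F.cross_entropy(logits, target, ignore_index=ignore_index, reduction="mean")
    logp = F.log_softmax(logits.float(), dim=-1)
    per_row_entropy = -(logp.exp() * logp).sum(dim=-1)   # H per token
    entropy = per_row_entropy[mask].sum() / mask.sum()   # mean over non-ignored tokens
    return ce, entropy
```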
Review comment:

> Can probably put an equation in the PR description, and also a simple one in a comment here, to demonstrate how the `entropy_loss` is calculated (especially the reuse of `m` and `d` computed in the first-pass online softmax).
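For what it's worth, one form the requested equation could take, in the kernel's own notation (`m` and `d` are the running max and denominator from the first online-softmax pass):

$$
p_i = \frac{e^{x_i - m}}{d}, \qquad \mathrm{lse} = m + \log d,
$$

$$
H = -\sum_i p_i \log p_i = -\sum_i p_i \big(x_i - \mathrm{lse}\big) = \mathrm{lse} - \sum_i p_i\, x_i ,
$$

so the extra pass only has to accumulate `sum_p_x` (the sum of `softmax(x_i) * x_i`), while `m`, `d`, and `lse` are reused as-is.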