
Commit aa2d23d

Authored Jan 30, 2025
Add JSD Loss for Distillation (#425)
## Summary

> [!CAUTION]
> This PR depends on #417. Do not merge until #417 (later #432) is merged.

This is a pure torch-compiled, chunked fused linear JSD loss, aimed at knowledge distillation.

#### Jensen-Shannon Divergence Loss

This PR implements Jensen-Shannon Divergence (JSD) loss as the soft learning objective in a distillation setting (teacher & student). This component can be replaced with other losses (e.g., KL divergence) as `distillation_loss_fn`.

JSD is defined as the average of the KL divergences between each distribution and the mean distribution:

```math
\text{JSD}(P || Q) = \frac{1}{2} \text{KL}(P || M) + \frac{1}{2} \text{KL}(Q || M), \quad \text{where } M = \frac{1}{2}(P + Q)
```

Here, `P` and `Q` are the two probability distributions, and `M` is their average.

## Testing Done

The figures below are benchmark results with different `chunk_size`, which also significantly affects performance.

#### Hint: Users can tune `chunk_size` as suggested in the referenced [paper](https://arxiv.org/pdf/2306.13649) for the moment:

```math
2^{\lceil \log_2 \lceil \frac{BT}{V/H} \rceil \rceil}
```

(A small sketch of this heuristic follows right after this summary.)

#### Memory

1. `chunk_size` = 1
   ![distill_jsd_loss_memory_chunk_size_1](https://github.com/user-attachments/assets/e00b2044-e075-4e34-b302-3808f7216837)
2. `chunk_size` = 1024
   ![distill_jsd_loss_memory_chunk_size_1024](https://github.com/user-attachments/assets/abe9fe17-726c-4fd0-899f-5d0e563ceb05)

#### Speed (Elapsed Time)

1. `chunk_size` = 1
   ![distill_jsd_loss_speed_chunk_size_1](https://github.com/user-attachments/assets/e2da495e-ff20-4e63-b7df-d6e1837774c8)
2. `chunk_size` = 1024
   ![distill_jsd_loss_speed_chunk_size_1024](https://github.com/user-attachments/assets/c2767754-a984-4f11-b5a1-cb21e8117ef6)

- Hardware Type: NVIDIA H100 80GB HBM3 (SXM5)
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <[email protected]>
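For reference, a minimal sketch of the `chunk_size` heuristic quoted above; the helper name is ours, not part of this PR:

```python
import math


def suggested_chunk_size(BT: int, V: int, H: int) -> int:
    """Hypothetical helper: 2 ** ceil(log2(ceil(BT / (V / H))))."""
    return 2 ** math.ceil(math.log2(math.ceil(BT / (V / H))))


# With the benchmark shapes used below (H=4096, V=128256) and BT=8192,
# the heuristic suggests a chunk_size of 512.
print(suggested_chunk_size(BT=8192, V=128256, H=4096))  # 512
```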
1 parent b80bf95 commit aa2d23d

File tree

8 files changed: +778 -6 lines changed


benchmark/data/all_benchmark_data.csv

+24
@@ -745,3 +745,27 @@ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,5544.25390625,5544.253906
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9057.287109375,9057.287109375,9057.287109375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16087.353515625,16087.353515625,16087.353515625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30147.486328125,30147.486328125,30147.486328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
distill_jsd_loss,liger,forward,speed,ms,BT,B x T,1024,7.735536098480225,7.729177474975586,7.798131465911865,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
distill_jsd_loss,liger,forward,speed,ms,BT,B x T,2048,15.20411205291748,15.165056228637695,15.226079940795898,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
distill_jsd_loss,liger,forward,speed,ms,BT,B x T,4096,30.159456253051758,30.126911163330078,30.165311813354492,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
distill_jsd_loss,liger,forward,speed,ms,BT,B x T,8192,60.24163055419922,60.24163055419922,60.24163055419922,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
distill_jsd_loss,torch,forward,speed,ms,BT,B x T,1024,10.906111717224121,10.903244972229004,10.91296672821045,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
distill_jsd_loss,torch,forward,speed,ms,BT,B x T,2048,21.480207443237305,21.465139389038086,21.489286422729492,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
distill_jsd_loss,torch,forward,speed,ms,BT,B x T,4096,42.96339416503906,42.96237564086914,42.96440887451172,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
distill_jsd_loss,torch,forward,speed,ms,BT,B x T,8192,85.3946533203125,85.3946533203125,85.3946533203125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
distill_jsd_loss,liger,full,speed,ms,BT,B x T,1024,8.312895774841309,8.310400009155273,8.326751708984375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
distill_jsd_loss,liger,full,speed,ms,BT,B x T,2048,15.770208358764648,15.767775535583496,15.774784088134766,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
distill_jsd_loss,liger,full,speed,ms,BT,B x T,4096,30.922752380371094,30.920312881469727,30.927898406982422,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
distill_jsd_loss,liger,full,speed,ms,BT,B x T,8192,60.70627212524414,60.70627212524414,60.70627212524414,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
distill_jsd_loss,torch,full,speed,ms,BT,B x T,1024,28.72480010986328,28.718809127807617,28.728179931640625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
distill_jsd_loss,torch,full,speed,ms,BT,B x T,2048,54.281761169433594,54.281761169433594,54.281761169433594,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
distill_jsd_loss,torch,full,speed,ms,BT,B x T,4096,107.08905792236328,107.08905792236328,107.08905792236328,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
distill_jsd_loss,torch,full,speed,ms,BT,B x T,8192,213.1598663330078,213.1598663330078,213.1598663330078,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
distill_jsd_loss,liger,full,memory,MB,BT,B x T,1024,10913.541015625,10913.541015625,10913.541015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
distill_jsd_loss,liger,full,memory,MB,BT,B x T,2048,10941.548828125,10941.548828125,10941.548828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
distill_jsd_loss,liger,full,memory,MB,BT,B x T,4096,10997.564453125,10997.564453125,10997.564453125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
distill_jsd_loss,liger,full,memory,MB,BT,B x T,8192,11109.595703125,11109.595703125,11109.595703125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
distill_jsd_loss,torch,full,memory,MB,BT,B x T,1024,16174.0390625,16174.0390625,16174.0390625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
distill_jsd_loss,torch,full,memory,MB,BT,B x T,2048,23713.05078125,23713.05078125,23713.05078125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
distill_jsd_loss,torch,full,memory,MB,BT,B x T,4096,38791.07421875,38791.07421875,38791.07421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
distill_jsd_loss,torch,full,memory,MB,BT,B x T,8192,68947.1015625,68947.1015625,68947.1015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2

benchmark/scripts/benchmark_distill_jsd_loss.py

@@ -0,0 +1,261 @@

```python
import os
import sys

import torch
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
from liger_kernel.utils import infer_device

device = infer_device()

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


class TorchJSDLoss(torch.nn.Module):
    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
        bias: bool = False,
    ):
        from test.chunked_loss.test_jsd_loss import HFJSDLoss

        super().__init__()
        self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype)
        self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
        self.jsd_loss = HFJSDLoss(
            ignore_index=ignore_index,
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            temperature=temperature,
        ).get_batch_loss_metrics

    def forward(self, student, teacher, target):
        return self.jsd_loss(
            student,
            self.student_lin.weight,
            teacher,
            self.teacher_lin.weight,
            target,
        )


class LigerJSDLoss(torch.nn.Module):
    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
        bias: bool = False,
    ):
        super().__init__()
        self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype)
        self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
        self.weight_hard_loss = weight_hard_loss
        self.weight_soft_loss = weight_soft_loss
        self.ignore_index = ignore_index
        self.temperature = temperature
        self.jsd_loss = LigerFusedLinearJSDFunction.apply

    def forward(self, student, teacher, target):
        return self.jsd_loss(
            student,
            self.student_lin.weight,
            teacher,
            self.teacher_lin.weight,
            target,
            self.weight_hard_loss,
            self.weight_soft_loss,
        )


def bench_memory_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    BT = input.x
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    bias = input.extra_benchmark_config["bias"]
    weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
    weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
    ignore_index = input.extra_benchmark_config["ignore_index"]
    provider = input.kernel_provider

    torch_jsd_loss = TorchJSDLoss(
        H=H,
        V=V,
        dtype=dtype,
        ignore_index=ignore_index,
        bias=bias,
        weight_hard_loss=weight_hard_loss,
        weight_soft_loss=weight_soft_loss,
    ).to(device)
    liger_jsd_loss = LigerJSDLoss(
        H=H,
        V=V,
        dtype=dtype,
        ignore_index=ignore_index,
        bias=bias,
        weight_hard_loss=weight_hard_loss,
        weight_soft_loss=weight_soft_loss,
    ).to(device)

    _tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
    student_input1 = _tensor.detach().clone().requires_grad_(True)
    student_input2 = _tensor.detach().clone().requires_grad_(True)

    teacher_input = torch.rand(BT, H, device=device, dtype=dtype)

    target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)

    def fwd():
        if provider == "liger":
            return liger_jsd_loss(student_input1, teacher_input, target)
        elif provider == "torch":
            return torch_jsd_loss(student_input2, teacher_input, target)

    def full():
        y = fwd()
        y.backward()

    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


def bench_speed_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    BT = input.x
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    bias = input.extra_benchmark_config["bias"]
    weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
    weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
    ignore_index = input.extra_benchmark_config["ignore_index"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    torch_jsd_loss = TorchJSDLoss(
        H=H,
        V=V,
        dtype=dtype,
        ignore_index=ignore_index,
        bias=bias,
        weight_hard_loss=weight_hard_loss,
        weight_soft_loss=weight_soft_loss,
    ).to(device)
    liger_jsd_loss = LigerJSDLoss(
        H=H,
        V=V,
        dtype=dtype,
        ignore_index=ignore_index,
        bias=bias,
        weight_hard_loss=weight_hard_loss,
        weight_soft_loss=weight_soft_loss,
    ).to(device)

    _tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
    student_input1 = _tensor.detach().clone().requires_grad_(True)
    student_input2 = _tensor.detach().clone().requires_grad_(True)

    teacher_input = torch.rand(BT, H, device=device, dtype=dtype)

    target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)

    def fwd():
        if provider == "liger":
            return liger_jsd_loss(student_input1, teacher_input, target)
        elif provider == "torch":
            return torch_jsd_loss(student_input2, teacher_input, target)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=[student_input1, student_input2],
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            rep=100,
            quantiles=QUANTILES,
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "distill_jsd_loss",
        "x_name": "BT",
        "x_label": "B x T",
        "x_values": [2**i for i in range(10, 14)],
        "kernel_providers": ["liger", "torch"],
        "extra_benchmark_configs": [
            {
                "H": 4096,
                "V": 128256,
                "mode": "forward",
                "dtype": torch.bfloat16,
                "bias": False,
                "weight_hard_loss": 0.5,
                "weight_soft_loss": 0.5,
                "ignore_index": -100,
            }
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_jsd_loss,
        kernel_operation_modes=["forward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )

    run_benchmarks(
        bench_test_fn=bench_memory_jsd_loss,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
```

src/liger_kernel/chunked_loss/__init__.py

```diff
@@ -1,5 +1,6 @@
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss  # noqa: F401
+from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss  # noqa: F401
 from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss  # noqa: F401
 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss  # noqa: F401
```

src/liger_kernel/chunked_loss/functional.py

```diff
@@ -1,11 +1,13 @@
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
+from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
 from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction

 liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
 liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
+liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
 liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
 liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
 liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
```

src/liger_kernel/chunked_loss/fused_linear_distillation.py

+14 -5

```diff
@@ -17,6 +17,9 @@ def distillation_loss_fn(
         Args:
             student_logits (torch.Tensor): Raw (temperature-scaled) logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
             teacher_logits (torch.Tensor): Raw (temperature-scaled) logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+        Returns:
+            torch.Tensor: Sum of distillation losses for the chunk. The class will handle
+            converting this to mean loss by dividing by the full batch size * sequence length in _compute_loss.
         """
         raise NotImplementedError("Distillation loss function must be implemented.")

@@ -71,10 +74,11 @@
         weight_hard_loss=0.5,
         weight_soft_loss=0.5,
         compute_ce_loss=True,
+        temperature=1,
         **loss_kwargs,
     ):
         """
-        Compute the total loss for a chunk of input and target, while using an knowleedge distillation loss function.
+        Compute the total loss for a chunk of input and target, while using an knowledge distillation loss function.
         Args:
             distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
             student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
@@ -84,11 +88,12 @@
             target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
             student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
             teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
-            full_target (torch.Tensor): Full target tensor. Shape: (chunk_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size * sequence_length,).
             ignore_index (int): Index to ignore for loss computation.
             weight_hard_loss (float): Weight for hard loss.
             weight_soft_loss (float): Weight for soft loss.
             compute_ce_loss (bool): Whether to compute CE loss.
+            temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             loss_kwargs (dict): Additional arguments for the loss function.
         """
         (
@@ -107,6 +112,9 @@
             compute_ce_loss=compute_ce_loss,
         )

+        student_logits_chunk /= temperature
+        teacher_logits_chunk /= temperature
+
         hard_loss /= full_target.shape[0]

         soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
@@ -130,6 +138,7 @@
         ignore_index=-100,
         weight_hard_loss=0.5,
         weight_soft_loss=0.5,
+        beta=0.5,
         compute_ce_loss=True,
         temperature=1.0,
         compiled=True,
@@ -152,6 +161,7 @@
             ignore_index (int): Index to ignore for loss computation.
             weight_hard_loss (float): Weight for hard/task loss.
             weight_soft_loss (float): Weight for soft/distillation loss.
+            beta (float): Interpolation coefficient between 0 and 1 (default: 0.5).
             compute_ce_loss (bool): Whether to compute CE loss.
             temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             compiled (bool): Whether to use torch compile for chunk accumulation.
@@ -170,7 +180,9 @@
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
+            beta=beta,
             compute_ce_loss=compute_ce_loss,
+            temperature=temperature,
             **loss_kwargs,
         )

@@ -225,9 +237,6 @@ def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
         if compiled:
             accumulate_chunk = torch.compile(accumulate_chunk)

-        student_input /= temperature
-        teacher_input /= temperature
-
         num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
         _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
         _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
```
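As the summary notes, `distillation_loss_fn` is the pluggable piece of this base class. Purely as a hedged sketch (not part of this PR), a forward-KL variant could override the same hook; the class name here is hypothetical, and a usable autograd Function would still need `forward`/`backward` wrappers like the JSD file below:

```python
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase


class NaiveFusedLinearKLBase(LigerFusedLinearDistillationBase):
    @staticmethod
    def distillation_loss_fn(student_logits, teacher_logits):
        # Return the per-chunk *sum*, mirroring the JSD implementation; the base
        # class divides by the full batch_size * seq_len in _compute_loss.
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
        # KL(teacher || student): input is the student log-probs, target the teacher log-probs.
        return F.kl_div(
            student_log_probs,
            teacher_log_probs,
            reduction="sum",
            log_target=True,
        )
```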

src/liger_kernel/chunked_loss/jsd_loss.py

+154

@@ -0,0 +1,154 @@

```python
import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase


class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
    @staticmethod
    def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
        """
        Compute JSD loss (Jensen-Shannon Divergence Loss).
        Args:
            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
        Returns:
            torch.Tensor: Jensen-Shannon Divergence loss
        """
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

        # Compute probabilities (only required for mean calculation)
        mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
        log_mean_probs = mean_probs.log()

        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)

        # JSD is the weighted average of the KL divergences
        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
        return jsd_loss

    @staticmethod
    def forward(
        ctx,
        student_input: torch.Tensor,
        student_weight: torch.Tensor,
        teacher_input: torch.Tensor,
        teacher_weight: torch.Tensor,
        true_labels: torch.LongTensor,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
        compiled: bool = True,
    ):
        """
        Fused linear layer with JSD distillation loss.
        Args:
            student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, hidden_size_student)
            student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, hidden_size_student)
            teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, hidden_size_teacher)
            teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, hidden_size_teacher)
            true_labels (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
            weight_hard_loss (float): Weight for hard loss.
            weight_soft_loss (float): Weight for soft loss.
            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
            ignore_index (int): Index to ignore in loss computation
            temperature (float): Temperature for softening/sharpening distributions
            compiled (bool): Whether to use torch compile
        Returns:
            torch.Tensor: Computed loss
        """
        return LigerFusedLinearDistillationBase.forward(
            ctx=ctx,
            student_input=student_input,
            student_weight=student_weight,
            teacher_input=teacher_input,
            teacher_weight=teacher_weight,
            target=true_labels,
            loss_fn=LigerFusedLinearJSDFunction.distillation_loss_fn,
            chunk_size=1,
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            beta=beta,
            ignore_index=ignore_index,
            temperature=temperature,
            compiled=compiled,
        )

    @staticmethod
    def backward(ctx, grad_output):
        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:4]

        return (*grads, None, None, None, None, None, None, None)


class LigerFusedLinearJSDLoss(torch.nn.Module):
    """
    Fused linear layer with JSD distillation loss.
    """

    def __init__(
        self,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
        compiled: bool = True,
    ):
        """
        Args:
            weight_hard_loss (float): Weight for hard loss.
            weight_soft_loss (float): Weight for soft loss.
            ignore_index (int): Index to ignore in the loss
            temperature (float): Temperature for softening distributions
            compiled (bool): Whether to use torch compile
            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
        """
        super().__init__()
        assert temperature != 0, "Temperature cannot be 0."
        self.weight_hard_loss = weight_hard_loss
        self.weight_soft_loss = weight_soft_loss
        self.ignore_index = ignore_index
        self.temperature = temperature
        self.compiled = compiled
        self.beta = beta

    def forward(
        self,
        student_input: torch.Tensor,
        student_weight: torch.Tensor,
        teacher_input: torch.Tensor,
        teacher_weight: torch.Tensor,
        true_labels: torch.LongTensor,
    ) -> torch.Tensor:
        """
        Compute the JSD distillation loss.

        Args:
            student_input (torch.Tensor): Student input tensor
            student_weight (torch.Tensor): Student weight tensor
            teacher_input (torch.Tensor): Teacher input tensor
            teacher_weight (torch.Tensor): Teacher weight tensor
            true_labels (torch.LongTensor): Target labels tensor

        Returns:
            torch.Tensor: Computed loss
        """
        return LigerFusedLinearJSDFunction.apply(
            student_input,
            student_weight,
            teacher_input,
            teacher_weight,
            true_labels,
            self.weight_hard_loss,
            self.weight_soft_loss,
            self.beta,
            self.ignore_index,
            self.temperature,
            self.compiled,
        )
```
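For orientation, a minimal usage sketch of the module above, mirroring the call pattern in the tests; the sizes and tensors here are illustrative placeholders, not values from the PR:

```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss

# Assumed toy sizes: student hidden 512, teacher hidden 1024, shared vocab 32000.
BT, H_student, H_teacher, V = 8, 512, 1024, 32000

loss_fn = LigerFusedLinearJSDLoss(weight_hard_loss=0.5, weight_soft_loss=0.5, beta=0.5)

student_hidden = torch.randn(BT, H_student, requires_grad=True)
student_weight = torch.randn(V, H_student, requires_grad=True)
teacher_hidden = torch.randn(BT, H_teacher)
teacher_weight = torch.randn(V, H_teacher)
labels = torch.randint(0, V, (BT,))

loss = loss_fn(student_hidden, student_weight, teacher_hidden, teacher_weight, labels)
loss.backward()
```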

test/chunked_loss/test_jsd_loss.py

+318

@@ -0,0 +1,318 @@

```python
import pytest
import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss
from liger_kernel.chunked_loss.functional import liger_fused_linear_jsd
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
from liger_kernel.utils import infer_device
from test.utils import HFDistillationLoss
from test.utils import assert_verbose_allclose
from test.utils import set_seed

device = infer_device()

# set random seed globally
set_seed()


class HFJSDLoss(HFDistillationLoss):
    """
    Naive implementation of a distillation loss using Jensen-Shannon Divergence (JSD).
    """

    def __init__(
        self,
        temperature: float = 1.0,
        ignore_index: int = -100,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        beta: float = 0.5,
    ):
        super().__init__(
            ignore_index=ignore_index,
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            temperature=temperature,
        )
        self.beta = (beta,)

    def distillation_loss(self, student_logits, teacher_logits, beta=0.5):
        """
        Compute JSD loss (Jensen-Shannon Divergence Loss).
        Args:
            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
        Returns:
            torch.Tensor: Jensen-Shannon Divergence loss
        """
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

        # Compute probabilities (only required for mean calculation)
        mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
        log_mean_probs = mean_probs.log()

        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="batchmean", log_target=True)
        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="batchmean", log_target=True)

        # JSD is the weighted average of the KL divergences
        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
        return jsd_loss


class TorchLMHeadJSD(torch.nn.Module):
    """Ground truth implementation of the linear fused with torch based jsd loss.
    :param H: hidden size
    :param V: vocab size
    :param temperature: softmax temperature
    :param weight_hard_loss: weight_hard_loss
    :param weight_soft_loss: weight_soft_loss
    """

    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        device: torch.device,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
    ):
        super().__init__()
        # smaller student model weights
        self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device)
        self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
        self.jsd = HFJSDLoss(
            ignore_index=ignore_index,
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            temperature=temperature,
            beta=beta,
        ).get_batch_loss_metrics

    def forward(self, student_input, teacher_input, target):
        jsd_loss = self.jsd(
            student_input,
            self.student_lin.weight,
            teacher_input,
            self.teacher_lin.weight,
            target,
        )
        return jsd_loss


class LigerLMHeadJSD(torch.nn.Module):
    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        device: torch.device,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
    ):
        super().__init__()
        # smaller student model weights
        self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device)
        self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
        self.chunked_jsd = LigerFusedLinearJSDLoss(
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            ignore_index=ignore_index,
            temperature=temperature,
        )

    def forward(self, student_input, teacher_input, target):
        return self.chunked_jsd(
            student_input,
            self.student_lin.weight,
            teacher_input,
            self.teacher_lin.weight,
            target,
        )


#############################################################################
# Test the correctness of the fused linear JSD
#############################################################################


@pytest.mark.parametrize(
    "B, T, H, V",
    [
        (8, 128, 1024, 4096),
        (3, 47, 31, 123),  # random shape
    ],
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        (1.0, torch.bfloat16, 5e-2, 5e-1),
        (1.0, torch.float32, 1e-5, 5e-4),
    ],
)
@pytest.mark.parametrize(
    "temperature, weight_hard_loss, weight_soft_loss, beta",
    [
        (1.0, 0.5, 0.5, 0.5),
        (2.0, 0.5, 0.5, 0.5),
        (0.5, 0.5, 0.5, 0.5),
        (1.0, 0.0, 1.0, 0.5),
        (1.0, 1.0, 0.0, 0.5),
        (1.0, 0.5, 0.5, 0.3),
        (2.0, 0.5, 0.5, 0.7),
    ],
)
def test_correctness(
    B,
    T,
    H,
    V,
    scalar,
    dtype,
    atol,
    rtol,
    temperature,
    weight_hard_loss,
    weight_soft_loss,
    beta,
):
    torch_lm_head_jsd = TorchLMHeadJSD(
        H=H,
        V=V,
        dtype=dtype,
        device=device,
        temperature=temperature,
        weight_hard_loss=weight_hard_loss,
        weight_soft_loss=weight_soft_loss,
        beta=beta,
    )
    liger_lm_head_jsd = LigerLMHeadJSD(
        H=H,
        V=V,
        dtype=dtype,
        device=device,
        temperature=temperature,
        weight_hard_loss=weight_hard_loss,
        weight_soft_loss=weight_soft_loss,
        beta=beta,
    )

    torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand(
        V, H // 2, device=device, dtype=dtype
    )
    torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand(
        V, H, device=device, dtype=dtype
    )

    _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar
    student_input1 = _tensor.detach().clone().requires_grad_(True)
    student_input2 = _tensor.detach().clone().requires_grad_(True)

    teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    loss1 = torch_lm_head_jsd(student_input1, teacher_input, target)
    loss2 = liger_lm_head_jsd(student_input2, teacher_input, target)
    assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)

    loss1.backward()
    loss2.backward()

    assert_verbose_allclose(student_input1.grad, student_input2.grad, atol=atol, rtol=rtol)

    assert_verbose_allclose(
        torch_lm_head_jsd.student_lin.weight.grad,
        liger_lm_head_jsd.student_lin.weight.grad,
        atol=atol,
        rtol=rtol,
    )


@pytest.mark.parametrize(
    "B, T, H, V",
    [
        (2, 2, 8, 8),
        (9, 7, 41, 41),
    ],
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        (1.0, torch.bfloat16, 5e-2, 5e-2),
        (1.0, torch.float32, 1e-4, 5e-3),
    ],
)
@pytest.mark.parametrize(
    "temperature, weight_hard_loss, weight_soft_loss, beta, ignore_index",
    [(1.0, 0.5, 0.5, 0.5, -100), (2.0, 0.1, 0.9, 0.5, 42)],
)
def test_correctness_functional(
    B,
    T,
    H,
    V,
    scalar,
    dtype,
    weight_hard_loss,
    weight_soft_loss,
    beta,
    ignore_index,
    temperature,
    atol,
    rtol,
):
    _weight = torch.rand(V, H // 2, device=device, dtype=dtype)
    student_weight1 = _weight.detach().clone().requires_grad_(True)
    student_weight2 = _weight.detach().clone().requires_grad_(True)
    teacher_weight = torch.rand(V, H, device=device, dtype=dtype)

    _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar
    student_input1 = _tensor.detach().clone().requires_grad_(True)
    student_input2 = _tensor.detach().clone().requires_grad_(True)
    teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar

    label = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    output1 = liger_fused_linear_jsd(
        student_input1,
        student_weight1,
        teacher_input,
        teacher_weight,
        label,
        weight_hard_loss,
        weight_soft_loss,
        beta,
        ignore_index,
        temperature,
    )
    output2 = LigerFusedLinearJSDFunction.apply(
        student_input2,
        student_weight2,
        teacher_input,
        teacher_weight,
        label,
        weight_hard_loss,
        weight_soft_loss,
        beta,
        ignore_index,
        temperature,
    )

    assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol)

    output1.backward()
    output2.backward()

    assert_verbose_allclose(student_input1.grad, student_input2.grad, atol=atol, rtol=rtol)

    assert_verbose_allclose(student_weight1.grad, student_weight2.grad, atol=atol, rtol=rtol)
```

test/utils.py

+4 -1

```diff
@@ -689,7 +689,10 @@ def get_batch_loss_metrics(
             hard_loss,
         ) = forward_output

+        student_logits /= self.temperature
+        teacher_logits /= self.temperature
+
         soft_loss = self.distillation_loss(student_logits, teacher_logits)
         # full loss
-        loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss.mean()
+        loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss
         return loss
```
