Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
672c8c8
feat(FLCE): add helion version of fused linear cross entropy
Tcc0403 Nov 3, 2025
fc9a406
compute dx
Tcc0403 Nov 3, 2025
5c81648
add grad_x, grad_w computation
Tcc0403 Nov 3, 2025
8e9c13e
clean up
Tcc0403 Nov 3, 2025
368433b
format
Tcc0403 Nov 3, 2025
164e63a
Fix incorrect grad_w computation with reduction="mean"
Tcc0403 Nov 4, 2025
b50870c
Add unit test
Tcc0403 Nov 4, 2025
f890711
Improve n_non_ignore read efficiency and ERROR comments
Tcc0403 Nov 4, 2025
04f07fa
Set higher tolerance
Tcc0403 Nov 4, 2025
8781e8e
Add benchmark
Tcc0403 Nov 4, 2025
98de5b9
Unfuse forward/backward and use lock
Tcc0403 Nov 6, 2025
30b12eb
testing misc
Tcc0403 Nov 6, 2025
f402ff3
Add cut_cross_entropy comparison
Tcc0403 Nov 6, 2025
c156022
Add LigerFusedLinearCrossEntropy for comparison
Tcc0403 Nov 6, 2025
719344d
Add IMA error comment to liger flce
Tcc0403 Nov 6, 2025
81d0e98
Fix incorrect liger flce args positions
Tcc0403 Nov 6, 2025
fd05527
Remove lock functions wrappers
Tcc0403 Nov 6, 2025
34decd4
Update best configs for h100 with BT=2048, H=4096, V=32000
Tcc0403 Nov 6, 2025
6905390
Clean up handwriting test and let run_example() handle correctness test
Tcc0403 Nov 6, 2025
b508243
Add chunk version of flce backward
Tcc0403 Nov 8, 2025
d901dac
Add autotune misc
Tcc0403 Nov 8, 2025
5ecb417
Fix ignore_index
Tcc0403 Nov 8, 2025
74553dc
Fix backward ctx.savetensors
Tcc0403 Nov 8, 2025
d2b2372
clean up
Tcc0403 Nov 8, 2025
07deb98
Fix autotune function
Tcc0403 Nov 8, 2025
06bde20
Add h100 autotune configs
Tcc0403 Nov 8, 2025
4eb0a69
Fix reduction!="mean" and add benchmark script
Tcc0403 Nov 11, 2025
cafa211
Merge branch 'main' into tcc/helion-flce
lancerts Nov 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion benchmark/scripts/benchmark_fused_linear_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.ops.helion.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyHelion
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
from liger_kernel.utils import infer_device

Expand Down Expand Up @@ -45,6 +46,20 @@ def forward(self, x, y):
return self.ce_loss(self.lin.weight, x, y)


class LigerLMHeadCEHelion(torch.nn.Module):
    """Benchmark module pairing a bias-free LM head with the Helion fused linear cross entropy loss.

    The linear layer only serves as a container for the projection weight:
    ``forward`` hands the raw weight to the loss module instead of calling the
    layer itself.

    Args:
        H: ``in_features`` of the linear layer.
        V: ``out_features`` of the linear layer.
        dtype: dtype used to create the linear layer's weight.
        ignore_index: forwarded unchanged to ``LigerFusedLinearCrossEntropyHelion``.
        bwd_impl: forwarded unchanged to ``LigerFusedLinearCrossEntropyHelion``.
        grad_in_forward: forwarded unchanged to ``LigerFusedLinearCrossEntropyHelion``.
    """

    def __init__(
        self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100, bwd_impl="chunk", grad_in_forward=False
    ):
        super().__init__()
        # The reduction is pinned to "mean"; every other option comes from the caller.
        loss_options = {
            "ignore_index": ignore_index,
            "reduction": "mean",
            "bwd_impl": bwd_impl,
            "grad_in_forward": grad_in_forward,
        }
        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
        self.ce_loss = LigerFusedLinearCrossEntropyHelion(**loss_options)

    def forward(self, x, y):
        """Return the loss module's output for inputs ``x``, the LM head weight, and targets ``y``."""
        # Argument order for the Helion loss is (input, weight, target).
        head_weight = self.lin.weight
        return self.ce_loss(x, head_weight, y)


#############################################################################
# Test the memory consumption of the fused linear cross entropy loss
#############################################################################
Expand All @@ -64,6 +79,10 @@ def bench_memory_fused_linear_cross_entropy(
lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
elif provider == "liger-fp32-accum":
lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
elif provider == "liger-helion":
lm_head_ce = LigerLMHeadCEHelion(H=H, V=V, dtype=dtype, bwd_impl="chunk", grad_in_forward=False).to(device)
elif provider == "liger-helion-grad-in-fwd":
lm_head_ce = LigerLMHeadCEHelion(H=H, V=V, dtype=dtype, bwd_impl="chunk", grad_in_forward=True).to(device)
else:
lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)

Expand Down Expand Up @@ -106,6 +125,10 @@ def bench_speed_fused_linear_cross_entropy(
lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
elif provider == "liger-fp32-accum":
lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
elif provider == "liger-helion":
lm_head_ce = LigerLMHeadCEHelion(H=H, V=V, dtype=dtype, bwd_impl="chunk", grad_in_forward=False).to(device)
elif provider == "liger-helion-grad-in-fwd":
lm_head_ce = LigerLMHeadCEHelion(H=H, V=V, dtype=dtype, bwd_impl="chunk", grad_in_forward=True).to(device)
else:
lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)

Expand Down Expand Up @@ -163,7 +186,7 @@ def full():
"x_name": "BT",
"x_label": "B x T",
"x_values": [2**i for i in range(12, 16)],
"kernel_providers": ["liger", "liger-fp32-accum", "huggingface"],
"kernel_providers": ["liger", "liger-fp32-accum", "huggingface", "liger-helion", "liger-helion-grad-in-fwd"],
"extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}],
"overwrite": args.overwrite,
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
{
"block_sizes": [
64,
64,
256
],
"range_unroll_factors": [
0,
1,
1
],
"range_num_stages": [
0,
3,
4
],
"range_multi_buffers": [
null,
false,
null
],
"range_flattens": [
null,
true,
true
],
"load_eviction_policies": [
"last",
"last",
"",
""
],
"num_warps": 4,
"num_stages": 8,
"indexing": [
"tensor_descriptor",
"pointer",
"tensor_descriptor",
"tensor_descriptor",
"pointer",
"pointer"
],
"pid_type": "flat",
"range_warp_specializes": []
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
{
"block_sizes": [
64,
32,
256
],
"range_unroll_factors": [
0,
1,
1
],
"range_num_stages": [
0,
3,
4
],
"range_multi_buffers": [
null,
true,
null
],
"range_flattens": [
null,
null,
true
],
"load_eviction_policies": [
"last",
"last",
"",
""
],
"num_warps": 4,
"num_stages": 6,
"indexing": [
"tensor_descriptor",
"tensor_descriptor",
"pointer",
"tensor_descriptor",
"pointer",
"tensor_descriptor"
],
"pid_type": "flat",
"range_warp_specializes": []
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
{
"block_sizes": [
64,
32,
256
],
"loop_orders": [
[
0,
1
]
],
"l2_groupings": [
32
],
"range_unroll_factors": [
0,
1
],
"range_num_stages": [
4,
2
],
"range_multi_buffers": [
true,
null
],
"range_flattens": [
true,
true
],
"load_eviction_policies": [
"last",
"last",
"first",
"first"
],
"num_warps": 8,
"num_stages": 1,
"indexing": [
"tensor_descriptor",
"tensor_descriptor",
"tensor_descriptor",
"tensor_descriptor",
"pointer",
"tensor_descriptor"
],
"pid_type": "persistent_interleaved",
"range_warp_specializes": []
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
{
"block_sizes": [
64,
32,
512
],
"range_unroll_factors": [
3,
0,
0
],
"range_num_stages": [
4,
0,
3
],
"range_multi_buffers": [
true,
true,
false
],
"range_flattens": [
true,
true,
false
],
"load_eviction_policies": [
"last",
"first",
"",
"last",
"",
"first"
],
"num_warps": 8,
"num_stages": 7,
"indexing": [
"pointer",
"tensor_descriptor",
"pointer",
"pointer",
"pointer",
"pointer",
"pointer",
"pointer"
],
"pid_type": "persistent_blocked",
"range_warp_specializes": []
}
Loading
Loading