Conversation

Tcc0403 (Collaborator) commented Nov 3, 2025

Summary

TODO (might be follow-up PRs):

  1. unit test
  2. autotune
  3. benchmark

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Tcc0403 (Collaborator, Author) commented Nov 4, 2025

The new test file avoids importing from test.utils because torchvision is not compatible with pytorch>=2.9.0 and triton>=3.5.0 (required by helion), which leads to an ImportError.

Some test cases failed due to:

  1. NaN in gradients of lm_head.weight (weird shapes -> padding issue? maybe zero out the out-of-bounds grad_logit tile; see the sketch after the log below).
  2. Numerical error with dtype=torch.float -> needs higher tolerance.
python3 -m pytest test/transformers/helion/test_fused_linear_cross_entropy.py --log-cli-level="WARNING"
FAILED test/transformers/helion/test_fused_linear_cross_entropy.py::test_fused_linear_cross_entropy_correctness[dtype0-0.01-0.01-sum-2-1024-4096-32000] - AssertionError: Tensor-likes are not close!
FAILED test/transformers/helion/test_fused_linear_cross_entropy.py::test_fused_linear_cross_entropy_correctness[dtype0-0.01-0.01-sum-3-423-1000-10000] - AssertionError: lm_head.weight of liger contains nan
FAILED test/transformers/helion/test_fused_linear_cross_entropy.py::test_fused_linear_cross_entropy_correctness[dtype0-0.01-0.01-mean-3-423-1000-10000] - AssertionError: lm_head.weight of liger contains nan
FAILED test/transformers/helion/test_fused_linear_cross_entropy.py::test_fused_linear_cross_entropy_correctness[dtype1-0.001-0.01-sum-2-1024-4096-32000] - AssertionError: Tensor-likes are not close!
FAILED test/transformers/helion/test_fused_linear_cross_entropy.py::test_fused_linear_cross_entropy_correctness[dtype1-0.001-0.01-sum-3-423-1000-10000] - AssertionError: lm_head.weight of liger contains nan
FAILED test/transformers/helion/test_fused_linear_cross_entropy.py::test_fused_linear_cross_entropy_correctness[dtype1-0.001-0.01-mean-3-423-1000-10000] - AssertionError: lm_head.weight of liger contains nan
============================================================ 6 failed, 6 passed in 30.64s =============================================================
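A hedged sketch of the suspected fix for failure (1); the function, tile, and offset names are hypothetical. When V is not a multiple of the vocab block size, the trailing tile holds out-of-bounds columns whose garbage values can turn into NaN once they feed the dW accumulation, so zero them first:

```python
import torch

# Hypothetical host-side model of masking an out-of-bounds grad_logit tile.
# grad_logit_tile: (BLOCK_BT, BLOCK_V) tile of d(loss)/d(logits)
# v_offset: starting vocab column of this tile; V: true vocab size
def zero_oob_grad_logit_tile(grad_logit_tile: torch.Tensor, v_offset: int, V: int) -> torch.Tensor:
    cols = v_offset + torch.arange(grad_logit_tile.shape[1], device=grad_logit_tile.device)
    # Keep valid vocab columns, zero everything past V so dW stays NaN-free.
    return torch.where((cols < V).unsqueeze(0), grad_logit_tile, torch.zeros_like(grad_logit_tile))

# For failure (2), the usual remedy is a looser float32 tolerance, e.g.:
# torch.testing.assert_close(actual, expected, rtol=1e-2, atol=1e-2)
```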

Tcc0403 (Collaborator, Author) commented Nov 4, 2025

Autotuned with the following shapes on H100 SXM5:

batch_size = 2
seq_len = 2048
hidden_size = 4096
vocab_size = 32000
dtype = torch.float32
reduction = "mean"

Here's the result:

[8430s] Generation 20 complete: error=23 timeout=2 ok=96 min=56.2426 mid=76.8076 max=184.9636 best=Config(block_sizes=[32, 32, 256], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer', 'tensor_descriptor', 'pointer', 'pointer', 'tensor_descriptor', 'pointer', 'tensor_descriptor'], load_eviction_policies=['first', '', 'first', 'last', '', '', 'last', 'first'], num_stages=5, num_warps=4, pid_type='flat', range_flattens=[None, True, False], range_multi_buffers=[None, True, False], range_num_stages=[0, 0, 0], range_unroll_factors=[0, 1, 1], range_warp_specializes=[])
[8430s] Autotuning complete in 8430.1s after searching 3711 configs.
One can hardcode the best config and skip autotuning with:
    @helion.kernel(config=helion.Config(block_sizes=[32, 32, 256], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer', 'tensor_descriptor', 'pointer', 'pointer', 'tensor_descriptor', 'pointer', 'tensor_descriptor'], load_eviction_policies=['first', '', 'first', 'last', '', '', 'last', 'first'], num_stages=5, num_warps=4, pid_type='flat', range_flattens=[None, True, False], range_multi_buffers=[None, True, False], range_num_stages=[0, 0, 0], range_unroll_factors=[0, 1, 1], range_warp_specializes=[]), static_shapes=True)

A full autotune took more than 2 hours; the current implementation is extremely slow.
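For reference, this is how the printed config can be hardcoded to skip the search on later runs. The kernel name and signature below are hypothetical stand-ins; the config values are exactly the ones the autotuner printed above:

```python
import torch
import helion

# Sketch only: hardcode the best autotuned config so the kernel compiles
# directly instead of re-running the multi-hour search.
@helion.kernel(
    config=helion.Config(
        block_sizes=[32, 32, 256],
        indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer', 'tensor_descriptor',
                  'pointer', 'pointer', 'tensor_descriptor', 'pointer', 'tensor_descriptor'],
        load_eviction_policies=['first', '', 'first', 'last', '', '', 'last', 'first'],
        num_stages=5,
        num_warps=4,
        pid_type='flat',
        range_flattens=[None, True, False],
        range_multi_buffers=[None, True, False],
        range_num_stages=[0, 0, 0],
        range_unroll_factors=[0, 1, 1],
        range_warp_specializes=[],
    ),
    static_shapes=True,
)
def fused_linear_cross_entropy_kernel(x: torch.Tensor, w: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    ...  # kernel body elided; name and signature are placeholders
```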

Tcc0403 (Collaborator, Author) commented Nov 4, 2025

Benchmarked with an 8x4096 input (BT=32768), hidden_size=4096, vocab_size=32000.

=================================================================
Benchmark Results
=================================================================
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion_flce_fwd      493.3454     0.06x          
torch_fwd            27.3347      1.00x (ref)    
=================================================================

=================================================================
Benchmark Results
=================================================================
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion_flce_fwd_bwd  506.0976     0.18x          
torch_fwd_bwd        92.5972      1.00x (ref)    
=================================================================

There are two hl.atomic_add() calls in the innermost loop for the two backprop matmuls. Replacing them with a lock-based add should improve performance a lot. Update: not really; atomic add is not the main bottleneck.
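For context, here are the reference semantics of the two backprop matmuls in plain PyTorch (shapes assumed: x is (BT, H), w is (V, H), grad_logits is (BT, V)). The Helion kernel computes these tile-by-tile, and because its inner loop does not run over the reduction dimension, each output tile is updated by many iterations, hence the hl.atomic_add calls:

```python
import torch

# Untiled reference of the two backward matmuls discussed above.
def flce_backward_reference(x: torch.Tensor, w: torch.Tensor, grad_logits: torch.Tensor):
    dx = grad_logits @ w      # (BT, H): reduces over V
    dw = grad_logits.T @ x    # (V, H):  reduces over BT
    return dx, dw
```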

Tcc0403 (Collaborator, Author) commented Nov 6, 2025

Llama3(H=4096, V=32000)

BT=4096

=================================================================
Benchmark Results
=================================================================
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion_flce_fwd      6.3785       0.54x          
torch_fwd            3.4434       1.00x (ref)    
cce_fwd              19.3814      0.18x          
triton_flce_fwd      17.1954      0.20x          
=================================================================
=================================================================
Benchmark Results
=================================================================
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion_flce_fwd_bwd  50.0497      0.22x          
torch_fwd_bwd        11.1041      1.00x (ref)    
cce_fwd_bwd          61.2685      0.18x          
triton_flce_fwd_bwd  17.8189      0.62x          
=================================================================

BT=32768

=================================================================
Benchmark Results
=================================================================
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion_flce_fwd      29.0221      1.02x          
torch_fwd            29.7404      1.00x (ref)    
cce_fwd              153.8144     0.19x          
triton_flce_fwd      80.8064      0.37x          
=================================================================
=================================================================
Benchmark Results
=================================================================
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion_flce_fwd_bwd  374.1972     0.22x          
torch_fwd_bwd        93.4857      0.88x          
cce_fwd_bwd          477.5481     0.17x          
triton_flce_fwd_bwd  82.4108      1.00x (ref)    
=================================================================

Qwen3(H=4096, V=151936)

BT=32768

=================================================================
Benchmark Results
=================================================================
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion_flce_fwd      71.3643      1.02x          
torch_fwd            73.1353      1.00x (ref)    
cce_fwd              413.2089     0.18x          
triton_flce_fwd      335.1601     0.22x          
=================================================================
=================================================================
Benchmark Results
=================================================================
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion_flce_fwd_bwd  882.6362     0.27x          
torch_fwd_bwd        234.5237     1.00x (ref)    
cce_fwd_bwd          775.0858     0.30x          
triton_flce_fwd_bwd  338.7038     0.69x          
=================================================================

Gemma3(H=2304, V=262208)

BT=8192

=================================================================
Benchmark Results
=================================================================
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion_flce_fwd      34.3520      1.10x          
torch_fwd            37.8505      1.00x (ref)    
cce_fwd              348.1931     0.11x          
triton_flce_fwd      338.0345     0.11x          
=================================================================
=================================================================
Benchmark Results
=================================================================
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion_flce_fwd_bwd  454.7143     0.27x          
torch_fwd_bwd        120.9745     1.00x (ref)    
cce_fwd_bwd          510.3440     0.24x          
triton_flce_fwd_bwd  340.8792     0.35x          
=================================================================

Tcc0403 (Collaborator, Author) commented Nov 6, 2025

This Helion implementation never materializes any logits in device memory. The forward pass works well, but the backward pass suffers in the dw and dx matmuls. Both are quite inefficient because the inner loops do not run over the reduction dimension, which means we have to perform an atomic_add or lock-based add on each iteration. I'll write a version with partially materialized logits (similar to the current Liger impl) that can perform more efficient matmuls for dw and dx; a sketch of that chunked approach follows.
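A minimal sketch of the planned chunked backward, assuming mean reduction; the function name, chunk size, and shapes are hypothetical. Logits are only ever materialized one BT-chunk at a time, and dx/dw use ordinary matmuls whose inner loop is the reduction dimension, so no atomic adds are needed:

```python
import torch

# x: (BT, H), w: (V, H), target: (BT,) class indices
def flce_backward_chunked(x: torch.Tensor, w: torch.Tensor, target: torch.Tensor, chunk: int = 4096):
    BT = x.shape[0]
    dx = torch.empty_like(x)
    dw = torch.zeros_like(w)
    for s in range(0, BT, chunk):
        xc, tc = x[s:s + chunk], target[s:s + chunk]
        logits = xc @ w.T                                  # recompute chunk logits
        p = torch.softmax(logits.float(), dim=-1)          # (chunk, V)
        p[torch.arange(tc.numel(), device=x.device), tc] -= 1.0
        p = (p / BT).to(x.dtype)                           # d(loss)/d(logits), mean reduction
        dx[s:s + chunk] = p @ w                            # reduces over V
        dw += p.T @ xc                                     # reduces over the chunk
    return dx, dw
```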

Tcc0403 (Collaborator, Author) commented Nov 8, 2025

Llama(H=4096, V=32000, dtype=fp32)

BT=32768

Chunked backward (recompute chunk logits, then accumulate dx and dw) works fine.

=================================================================
Benchmark Results
=================================================================
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion_fwd           52.8671      0.89x          
torch_fwd            47.1587      1.00x (ref)    
cce_fwd              191.8264     0.25x          
triton_flce_fwd      131.3960     0.36x          
=================================================================
=================================================================
Benchmark Results
=================================================================
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion_fwd_bwd_chunk 221.1143     0.64x          
torch_fwd_bwd        147.3726     0.96x          
cce_fwd_bwd          648.3762     0.22x          
triton_flce_fwd_bwd  141.1500     1.00x (ref)    
=================================================================

Tcc0403 (Collaborator, Author) commented Nov 11, 2025

Some benchmark results

There seems to be a constant overhead when running LigerLMHeadCEHelion(H=H, V=V, dtype=dtype, grad_in_forward=True). Need to fix it.
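A small timing sketch for isolating that overhead; `run_op` is a placeholder for the actual module call. If the gap versus the baseline stays roughly constant as BT grows, the cost is per-call (host-side/launch) overhead rather than the kernel itself:

```python
import torch

# Average per-call latency measured with CUDA events.
def time_op(run_op, iters: int = 20, warmup: int = 5) -> float:
    for _ in range(warmup):
        run_op()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        run_op()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call
```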

speed

[benchmark charts: full, forward, backward]

memory

[memory chart: full]

Tcc0403 (Collaborator, Author) commented Nov 11, 2025

Fusing the logits computation and the softmax doesn't seem like a good idea: it leaves less parallelism for small shapes. I will make another version closer to the current Liger impl.
