feat(FLCE): add helion version of fused linear cross entropy #928
Conversation
The new test file avoids importing from test.utils because torchvision is not compatible with pytorch>=2.9.0 and triton>=3.5.0 (required by helion), which leads to an ImportError. Some test cases failed due to
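A minimal sketch of that workaround, assuming the test file only needs something like a seeding helper from test.utils; the helper name and body below are stand-ins for whatever the real test file re-implements:

```python
import random

import numpy as np
import torch


def set_seed(seed: int = 42) -> None:
    # Local re-implementation of a typical test.utils seeding helper, so the
    # test file has no import path that pulls in torchvision.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
```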
Autotuned with the following shapes on an H100 SXM5:

batch_size = 2
seq_len = 2048
hidden_size = 4096
vocab_size = 32000
dtype = torch.float32
reduction = "mean"

Result: a full autotune took more than 2 hours. The current implementation is extremely slow.
Benchmarked with 8x4096 input length (BT=32768), hidden_size=4096, vocab_size=32000.
Benchmarked configurations:
- Llama3 (H=4096, V=32000): BT=4096 and BT=32768
- Qwen3 (H=4096, V=151936): BT=32768
- Gemma3 (H=2304, V=262208): BT=8192
This helion implementation never materializes any logits in device memory. The forward pass works well, but the backward pass suffers in the dw and dx matmuls. Both matmuls are quite inefficient because their inner loops do not run over the reduction dimension, which means we have to perform an atomic_add or lock_add on each iteration. I'll write a version with partially materialized logits (similar to the current liger impl) that can perform more efficient matmuls for dw and dx.
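A torch-level schematic of that accumulation pattern, with illustrative tile sizes; it shows why dx (and, with BT as the parallel axis, dw) needs atomic updates once the tiles run in parallel:

```python
import torch

BT, H, V, TILE_V = 4096, 4096, 32000, 1024  # illustrative sizes only

x = torch.randn(BT, H, device="cuda")
w = torch.randn(V, H, device="cuda")

dx = torch.zeros(BT, H, device="cuda")
dw = torch.zeros(V, H, device="cuda")
for v0 in range(0, V, TILE_V):  # in the kernel, these tiles run in parallel
    v1 = min(v0 + TILE_V, V)
    # Stand-in for the recomputed gradient w.r.t. this logits tile; the
    # point is only its (BT, TILE_V) shape -- the full (BT, V) never exists.
    dlogits_tile = torch.randn(BT, v1 - v0, device="cuda")
    # dx reduces over V, so every V-tile updates the *same* dx rows. Run
    # sequentially this is a plain +=; run as parallel programs it becomes
    # an atomic_add / lock_add per iteration, which is the slow part.
    dx += dlogits_tile @ w[v0:v1]
    # dw's V-tiles write disjoint rows here, but its own reduction (over BT)
    # hits the same problem once BT is the parallel axis.
    dw[v0:v1] += dlogits_tile.t() @ x
```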
Llama (H=4096, V=32000, dtype=fp32), BT=32768: chunked backward (recompute each chunk's logits, then accumulate dx and dw) works fine.
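A torch-level sketch of that chunked backward, assuming mean reduction and ignoring the numerical-stability and dtype details a real kernel handles; the function name and chunk size are illustrative:

```python
import torch

def flce_backward_chunked(x, w, target, grad_output, chunk_bt=1024):
    # Recompute each chunk's logits, convert them to dlogits, and accumulate
    # dx / dw chunk by chunk, so at most (chunk_bt, V) logits are ever live.
    BT, H = x.shape
    V = w.shape[0]
    dx = torch.empty_like(x)
    dw = torch.zeros_like(w)
    for s in range(0, BT, chunk_bt):
        e = min(s + chunk_bt, BT)
        logits = x[s:e] @ w.t()                       # (chunk, V), recomputed
        dlogits = torch.softmax(logits.float(), dim=-1)
        rows = torch.arange(e - s, device=x.device)
        dlogits[rows, target[s:e]] -= 1.0             # softmax - one_hot
        dlogits *= grad_output / BT                   # scale for mean reduction
        dlogits = dlogits.to(x.dtype)
        # Both matmuls now reduce over their inner dimension, so each chunk
        # accumulates with plain adds instead of atomics.
        dx[s:e] = dlogits @ w                         # (chunk, H)
        dw += dlogits.t() @ x[s:e]                    # (V, H)
    return dx, dw
```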
Fusing the logits computation with the softmax doesn't seem like a good idea: the softmax needs each full row of logits, so the vocab dimension can't be split across programs, and at small shapes the remaining BT-level parallelism is too little to fill the GPU. I will make another version closer to the current liger impl.




Summary
TODO (might be follow-up PRs):
Testing Done
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence