Add TiledMLP Implementation #935
Conversation
@Tcc0403

Thank you for the contribution! I left some comments, mostly about testing. Feel free to ping me again once they are resolved!

Can you add a comparison of LigerMLP, LigerTiledMLP, normal MLP, and DeepSpeed's TiledMLP, so we can have more baseline numbers as reference? I'll run the benchmark on an H100 as well.
@Tcc0403

Test Environment
(benchmark tables omitted: GEGLU - Full Pass (Forward + Backward), Speed (ms) / Memory (MB); SwiGLU - Full Pass (Forward + Backward), Speed (ms) / Memory (MB))
@Tcc0403 These are the results I got on an RTX 4090. When you get a chance, could you try running it on an H100 as well?

I've added H100 benchmark data. Is bf16 still not able to pass the tests with (atol=1, rtol=1) on your 4090? If that's the case, I don't mind skipping bf16 as it is now. Overall LGTM! cc @sfc-gh-sbekman in case I am missing something.

Yes, even with (atol=1, rtol=1) on a 4090, bf16 still can't pass.

Very cool, thank you, @upskyy - no fused kernel, though? Wrt benchmarks - really this feature is most useful at very large seqlen, like 100K+, where you can save memory by an order of magnitude ;) You can easily run a much longer seqlen using Liger + DeepSpeed offload to CPU even on an 80GB card. The current implementation in DeepSpeed works with single GPU and DeepSpeed multi-GPU, so support for DDP and FSDP additionally needs to be added for it to cover more use cases. @winglian at Axolotl implemented a manual gradient accumulator https://github.com/axolotl-ai-cloud/axolotl/blob/dd78f2e0cc5cc6458daaad02cc29b649ff1046f5/src/axolotl/monkeypatch/tiled_mlp/base.py#L99 which supposedly works with DDP and FSDP - I just haven't tried it. The problem is that as soon as you run it, it tells DeepSpeed not to reduce the param across ranks until it's ready.

Thanks for the feedback! For kernel fusion, I've opened #937 listing some potential approaches for MLP that we can try and see how they perform in future PRs.

Super! Because otherwise it's mainly copying the original implementation in pure Python with small improvements; what makes me excited is Liger-Kernel writing an efficient kernel for this feature, like it did for many other features!
That's super cool, @upskyy - I wasn't aware of |
src/liger_kernel/ops/tiled_mlp.py
Outdated
# Check if mlp_module actually has a no_sync() method (i.e. it's a DDP/FSDP wrapper)
if hasattr(mlp_module, "no_sync"):
    sync_context = mlp_module.no_sync()
# If there is no no_sync() method, we can't control gradient synchronization
The "if no no_sync" comment is not clear to me - could you please write out in more detail what you suggest is happening here? And what does "we can't control" imply?
Should there be some exception following this comment, or will it fail in some weird way?
@stas00 I’m really glad to hear that incorporating it into the original code was helpful.
Both DistributedDataParallel and FullyShardedDataParallel provide the no_sync function, which suppresses gradient synchronization during intermediate backward passes and performs synchronization only on the final shard to improve efficiency.
Here is the code for the no_sync function.
Since modules other than DistributedDataParallel and FullyShardedDataParallel do not have the no_sync method, I added the comment "if no no_sync" to indicate this.
And I added the comment to indicate that, in rare edge cases where the module does not have the no_sync function, gradient synchronization will still occur — resulting in less efficient execution, but ultimately the same outcome.
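For illustration only, here is a minimal sketch of the fallback behavior described above; the helper name and the is_last_shard flag are assumptions for this sketch, not the PR's actual code.

```python
import contextlib

def _grad_sync_context(mlp_module, is_last_shard: bool):
    # DDP/FSDP wrappers expose no_sync(); suppress gradient sync on every
    # shard except the last one so reduction happens only once per step.
    if hasattr(mlp_module, "no_sync") and not is_last_shard:
        return mlp_module.no_sync()
    # Plain modules (or the last shard): nothing to suppress, so gradient
    # synchronization proceeds as usual.
    return contextlib.nullcontext()
```

Each shard's forward/backward would then run under with _grad_sync_context(mlp_module, is_last_shard):, which gives the "sync on every shard, less efficient but same outcome" path whenever no_sync is absent.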
Your explanation in the comment is very clear - can we use that in the comment in the PR? Or, if you want, I can make a suggestion.
So basically some modules might still fail if they can't support such grad sync, exactly like the original implementation will fail, and your comment just explains why that might happen. I was thinking we could catch that, but I can see that we have no idea how other modules are implemented, and they might still work without no_sync.
I checked: in DeepSpeed, no_sync specifically errors out on purpose because it has a custom gradient sharding implementation. I was hoping to use no_sync for DeepSpeed as well, but it won't work, so your code is perfect.
Your update looks good. Thank you.
I revised the comments and added explanations to the commit.
For reference, it seems that the no_sync function is also consistently used when training a model with DDP and gradient accumulation.
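For context, the standard DDP gradient-accumulation pattern that comment refers to looks roughly like this (a generic sketch; ddp_model, loader, optimizer, and accumulation_steps are assumed to be set up elsewhere):

```python
import contextlib

for step, batch in enumerate(loader):
    sync_now = (step + 1) % accumulation_steps == 0
    # Skip the all-reduce on intermediate micro-batches, sync on the last one.
    ctx = contextlib.nullcontext() if sync_now else ddp_model.no_sync()
    with ctx:
        loss = ddp_model(batch).sum() / accumulation_steps  # placeholder loss
        loss.backward()
    if sync_now:
        optimizer.step()
        optimizer.zero_grad()
```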
it seems that the no_sync function is also consistently used when training a model with DDP and gradient accumulation.
Aha, that's a good way of finding out how some custom module/framework deals with GAS!
Please feel free to let me know if there’s anything else I need to take care of for this PR.
I'm not part of this project, so @Tcc0403 will follow up if there are any other things to do I'm sure. As far as my opinion you did an amazing job, @upskyy!
In this edge case, gradient synchronization will occur on every shard (inefficient), but the final result remains correct.
Thank you! The TiledMLP structure looks good to me. We should be able to add more kernel implementations easily in the future. I think it can also be a great chance to discuss how we structure distributed/multi-gpu tests. I suggest adding a new directory test/distributed for them.

@Tcc0403 Let me know if there's anything else I should take care of. Thanks in advance!
Tcc0403
left a comment
The distributed test shows that the TiledMLP impl for DDP and FSDP is buggy.
Some investigation of the current status of DDP and FSDP integration.

I didn't test making TiledMLP a submodule and applying DDP/FSDP on the root module, such as a DecoderBlock. Maybe it won't trigger grad sync or reshard when looping over submodule fwd/bwd passes.

DDP
The gradient sync is only performed once, in the first backward pass. In the for loop over sequence shards, only the backward in the first iteration can trigger gradient synchronization. (collapsed logs: num_shards=1, num_shards=4)

FSDP
I tried different sharding strategies; it cannot perform gate_proj(x). (error details collapsed)

BTW, I found the doc says FSDP's no_sync should only be used on the root FSDP instance. Are there any side effects of using it in a child module as this PR does?
Update: it will raise the following (full error log collapsed):

RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function ...

Or you have to wrap the module with ...

For FSDP, it will result in the same error as the previous comment. The doc also states that ...
Besides axolotl's grad accum hook approach, FSDP2's explicit control of resharding/unsharding submodules is also a potential solution.
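For reference, a rough, untested sketch of what that FSDP2 route could look like, assuming the per-module FSDP2 APIs in recent PyTorch (fully_shard, unshard(), reshard(), set_requires_gradient_sync(), set_reshard_after_backward()); this is a potential approach, not code from this PR:

```python
from torch.distributed.fsdp import fully_shard  # FSDP2 entry point in recent PyTorch

def tiled_fwd_bwd(mlp, x_shards):
    # mlp is assumed to be an nn.Module already wrapped with fully_shard(mlp).
    mlp.unshard()                          # all-gather the parameters once
    mlp.set_reshard_after_backward(False)  # keep them gathered across shard iterations
    mlp.set_requires_gradient_sync(False)  # skip reduce-scatter on intermediate shards
    for i, shard in enumerate(x_shards):
        if i == len(x_shards) - 1:
            mlp.set_requires_gradient_sync(True)  # sync gradients only on the last shard
            mlp.set_reshard_after_backward(True)
        mlp(shard).sum().backward()        # placeholder loss for the sketch
    mlp.reshard()                          # free the unsharded parameters
```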
You're right. I guess using ...

Due to fundamental limitations of PyTorch FSDP, it is incompatible with custom autograd Functions:

RuntimeError: FSDP does not support custom autograd Functions in the forward pass in use_orig_params=True mode

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
I think DDP/FSDP support can be another PR along with tests, as this PR is becoming a little bloated. Currently there are too many workarounds and hacky checks just to make DDP/FSDP work. TiledMLP is quite simple and shouldn't be that complicated. Let's stash these changes and get this PR merged first.

BTW, I tried axolotl's gradient accumulator: it works with DDP. For FSDP, it still encounters param reshard/unshard issues after the first iteration.
Thanks for sharing the information. I agree with you. You can go ahead and delete the test/distributed folder, or make the changes yourself.

Yes, we don't need to consider DDP/FSDP in this PR. I'll open an issue to track that support.

I've made the changes and completed the commit.
Tcc0403
left a comment
Thank you!
Summary
TiledMLP
TiledMLP is a technique that splits the computation of a Multi-Layer Perceptron (MLP) into multiple small chunks (tiles) and processes them sequentially. This approach trades a small amount of speed for a large reduction in peak activation memory, which matters most at long sequence lengths.
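For illustration, a minimal sketch of the idea in plain PyTorch (not the implementation in this PR): the hidden states are split along the sequence dimension and the MLP is applied shard by shard, so only one shard's intermediate activations are alive at a time.

```python
import torch
import torch.nn as nn

def tiled_mlp_forward(mlp: nn.Module, hidden_states: torch.Tensor, num_shards: int = 4) -> torch.Tensor:
    # hidden_states: (batch, seq_len, hidden_size); split along the sequence dim.
    shards = torch.tensor_split(hidden_states, num_shards, dim=1)
    # Run the MLP shard by shard and stitch the outputs back together.
    return torch.cat([mlp(shard) for shard in shards], dim=1)
```

During training, the actual implementation additionally loops the backward pass over the shards so that peak activation memory stays bounded.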
Supported Models
LigerTiledSwiGLUMLP (with SwiGLU activation):
• ✅ Llama (1, 2, 3, 3.1, 3.2, 3.3, 4)
• ✅ Mistral
• ✅ Mixtral (MoE)
• ✅ Qwen2, Qwen3
• ✅ Phi3
• ✅ OLMo2
• ✅ GLM4
LigerTiledGEGLUMLP (with GEGLU activation):
• ✅ Gemma (1, 2, 3)
• ✅ Other models using GELU activation
Usage Examples
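As an illustrative sketch only: the import path, the config-based constructor, and the num_shards argument below are assumptions, not the confirmed API.

```python
import torch
from transformers.models.llama.configuration_llama import LlamaConfig

from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP  # module path assumed

config = LlamaConfig(hidden_size=4096, intermediate_size=11008, hidden_act="silu")
mlp = LigerTiledSwiGLUMLP(config=config, num_shards=4).to("cuda", torch.bfloat16)  # num_shards assumed

x = torch.randn(2, 32768, config.hidden_size, device="cuda", dtype=torch.bfloat16)
out = mlp(x)  # same numerics as a standard SwiGLU MLP, with lower peak activation memory
```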
How to Test
How to Run Benchmarks
Results of the Benchmarks
Key Takeaways
Trade-off in Speed
Detailed Analysis
GEGLU Results
SwiGLU Results
More dramatic memory savings in inference-only scenarios
Use Tiled MLP when:
• Sequence length ≥ 8K
• Sequence length ≥ 16K
• Inference only (Forward Pass)
Related Issues
#893
Testing Done
• make test to ensure correctness
• make checkstyle to ensure code style
• make test-convergence to ensure convergence