
Conversation

@upskyy (Contributor) commented Nov 8, 2025

Summary

TiledMLP

TiledMLP is a technique that splits the computation of a Multi-Layer Perceptron (MLP) into multiple small chunks (tiles) and processes them sequentially. This approach:

  • Significantly reduces memory usage when handling very long sequences
  • Recomputes the forward pass during the backward pass, similar to activation checkpointing
  • Sacrifices some speed in exchange for improved memory efficiency
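For intuition, here is a minimal sketch of the idea (illustrative only, not the Liger implementation): chunk the sequence dimension and wrap each tile in activation checkpointing, so each tile's activations are recomputed during the backward pass.

import torch
from torch.utils.checkpoint import checkpoint

def tiled_mlp_forward(mlp, x, num_shards=4):
    # Split the (batch, seq, hidden) input along the sequence dimension and run
    # the MLP tile by tile; checkpoint() discards intermediate activations and
    # recomputes them during backward, which is where the memory saving comes from.
    outputs = [
        checkpoint(mlp, x_shard, use_reentrant=False)
        for x_shard in torch.chunk(x, num_shards, dim=1)
    ]
    return torch.cat(outputs, dim=1)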

Supported Models

LigerTiledSwiGLUMLP (with SwiGLU activation):
• ✅ Llama (1, 2, 3, 3.1, 3.2, 3.3, 4)
• ✅ Mistral
• ✅ Mixtral (MoE)
• ✅ Qwen2, Qwen3
• ✅ Phi3
• ✅ OLMo2
• ✅ GLM4

LigerTiledGEGLUMLP (with GEGLU activation):
• ✅ Gemma (1, 2, 3)
• ✅ Other models using GELU activation

Usage Examples

Example 1: LigerTiledSwiGLUMLP with an explicit shard count

from transformers.models.llama.configuration_llama import LlamaConfig
from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP

config = LlamaConfig(hidden_size=2048, intermediate_size=4096, hidden_act="silu")
mlp = LigerTiledSwiGLUMLP(config=config, num_shards=4)

Example 2: LigerTiledGEGLUMLP with automatic shard calculation

from transformers.models.gemma.configuration_gemma import GemmaConfig
from liger_kernel.transformers.tiled_mlp import LigerTiledGEGLUMLP

config = GemmaConfig(hidden_size=2048, intermediate_size=4096, hidden_act="gelu_pytorch_tanh")
mlp = LigerTiledGEGLUMLP(config=config)  # Automatically calculates the number of shards

Example 3: End-to-end forward and backward pass

import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP

config = LlamaConfig(
    hidden_size=4096,  # Must match the last dimension of the input
    intermediate_size=16384,
    hidden_act="silu",
)

tiled_mlp = LigerTiledSwiGLUMLP(config=config).cuda()

x = torch.randn(1, 10000, 4096, device="cuda", requires_grad=True)

output = tiled_mlp(x)
loss = output.sum()
loss.backward()
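To see the memory behavior reported in the benchmarks below, a rough peak-memory probe can be appended to the last example (a sketch using standard PyTorch CUDA statistics):

# Measure the peak memory of one forward + backward pass through the tiled MLP.
torch.cuda.reset_peak_memory_stats()
out = tiled_mlp(x)
out.sum().backward()
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB")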

How to Test

python -m pytest test/transformers/test_tiled_mlp.py -v
python -m pytest test/transformers/test_tiled_mlp.py::test_automatic_shard_calculation -v
python -m pytest test/transformers/test_tiled_mlp.py::test_tiled_mlp_with_2d_input -v
python -m pytest test/transformers/test_tiled_mlp.py::test_memory_efficiency -v

How to Run Benchmarks

python benchmark/scripts/benchmark_tiled_mlp.py

Results of the Benchmarks

Key Takeaways

Memory Savings Achieved

  • Up to 41% memory reduction at 16K sequence length
  • Greater memory efficiency as sequence length increases

Trade-off in Speed

  • Around 1.4–1.5× slower, which is expected due to forward recomputation

Detailed Analysis

  1. Full Pass (Forward + Backward)

GEGLU Results

| Sequence Length | Regular (ms) | Tiled (ms) | Time ↑ | Regular (MB) | Tiled (MB) | Memory ↓ |
|----------------:|-------------:|-----------:|-------:|-------------:|-----------:|---------:|
| 1024  | 2.26  | 3.35  | +48% | 232.25  | 186.25  | -20% |
| 2048  | 4.54  | 6.03  | +33% | 336.25  | 244.25  | -27% |
| 4096  | 8.80  | 11.49 | +31% | 544.25  | 360.25  | -34% |
| 8192  | 17.03 | 23.84 | +40% | 960.25  | 592.25  | -38% |
| 16384 | 33.60 | 47.47 | +41% | 1792.25 | 1056.25 | -41% |

SwiGLU Results

| Sequence Length | Regular (ms) | Tiled (ms) | Time ↑ | Regular (MB) | Tiled (MB) | Memory ↓ |
|----------------:|-------------:|-----------:|-------:|-------------:|-----------:|---------:|
| 1024  | 2.18  | 3.35  | +54% | 232.25  | 186.25  | -20% |
| 2048  | 4.42  | 6.05  | +37% | 336.25  | 244.25  | -27% |
| 4096  | 8.81  | 11.51 | +31% | 544.25  | 360.25  | -34% |
| 8192  | 17.01 | 23.95 | +41% | 960.25  | 592.25  | -38% |
| 16384 | 33.75 | 47.54 | +41% | 1792.25 | 1056.25 | -41% |

  2. Forward Pass Only

Memory savings are even more dramatic in inference-only scenarios:

| Sequence Length | Regular (MB) | Tiled (MB) | Memory ↓ |
|----------------:|-------------:|-----------:|---------:|
| 1024  | 128.25  | 92.25  | -28% |
| 2048  | 192.25  | 120.25 | -37% |
| 4096  | 320.25  | 176.25 | -45% |
| 8192  | 576.25  | 288.25 | -50% |
| 16384 | 1088.25 | 512.25 | -53% |

  3. Backward Pass Only

| Sequence Length | Regular (ms) | Tiled (ms) | Time ↑ |
|----------------:|-------------:|-----------:|-------:|
| 1024  | 1.50  | 2.61  | +74% |
| 2048  | 3.03  | 4.64  | +53% |
| 4096  | 6.05  | 8.68  | +43% |
| 8192  | 11.54 | 17.85 | +55% |
| 16384 | 22.72 | 35.47 | +56% |

Use Tiled MLP when:

Sequence length ≥ 8K

  • Memory savings: ≥38%
  • Speed penalty: ~40%

Sequence length ≥ 16K

  • Memory: 41% reduction
  • Speed: 1.4× slower (reasonable for memory gain)

Inference only (Forward Pass)

  • Memory savings: up to 53%
  • Speed overhead: just 10–12%

Related Issues

#893

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@upskyy (Contributor, Author) commented Nov 10, 2025

@Tcc0403
Hello, I worked on the TiledMLP and opened a PR for it.
Would you be able to take a look at it whenever you get a chance?

@Tcc0403 (Collaborator) commented Nov 10, 2025

Thank you for the contribution! I left some comments, mostly about testing. Feel free to ping me again once they are resolved!

@Tcc0403 (Collaborator) commented Nov 10, 2025

Can you add a comparison of LigerMLP, LigerTiledMLP, normal MLP, and DeepSpeed's TiledMLP, so we can have more baseline numbers as a reference? I'll run the benchmark on an H100 as well.

@upskyy (Contributor, Author) commented Nov 11, 2025

@Tcc0403
Thanks for the feedback! I’ll make the changes and ping you once it’s done!

@upskyy (Contributor, Author) commented Nov 11, 2025

Test Environment

  • GPU: NVIDIA GeForce RTX 4090
  • liger_kernel==0.6.3

GEGLU - Full Pass (Forward + Backward)

Speed (ms)

| Sequence Length | Normal MLP | LigerMLP | LigerTiledMLP | DeepSpeed TiledMLP |
|----------------:|-----------:|---------:|--------------:|-------------------:|
| 1024  | 2.34  | 2.17  | 3.35  | 3.42  |
| 2048  | 4.76  | 4.34  | 6.02  | 6.16  |
| 4096  | 9.42  | 8.65  | 11.50 | 11.93 |
| 8192  | 17.63 | 16.91 | 23.69 | 24.73 |
| 16384 | 35.07 | 33.63 | 47.48 | 49.46 |

Memory (MB)

| Sequence Length | Normal MLP | LigerMLP | LigerTiledMLP | DeepSpeed TiledMLP |
|----------------:|-----------:|---------:|--------------:|-------------------:|
| 1024  | 264  | 232  | 186  | 190  |
| 2048  | 400  | 336  | 244  | 252  |
| 4096  | 688  | 544  | 360  | 376  |
| 8192  | 1264 | 960  | 592  | 640  |
| 16384 | 2416 | 1792 | 1056 | 1168 |

SwiGLU - Full Pass (Forward + Backward)

Speed (ms)

| Sequence Length | Normal MLP | LigerMLP | LigerTiledMLP | DeepSpeed TiledMLP |
|----------------:|-----------:|---------:|--------------:|-------------------:|
| 1024  | 2.25  | 2.17  | 3.35  | 3.43  |
| 2048  | 4.59  | 4.37  | 6.02  | 6.16  |
| 4096  | 9.23  | 8.94  | 11.61 | 11.93 |
| 8192  | 17.87 | 17.08 | 23.86 | 24.82 |
| 16384 | 35.34 | 33.75 | 47.72 | 49.63 |

Memory (MB)

| Sequence Length | Normal MLP | LigerMLP | LigerTiledMLP | DeepSpeed TiledMLP |
|----------------:|-----------:|---------:|--------------:|-------------------:|
| 1024  | 264  | 232  | 186  | 190  |
| 2048  | 400  | 336  | 244  | 252  |
| 4096  | 688  | 544  | 360  | 376  |
| 8192  | 1264 | 960  | 592  | 640  |
| 16384 | 2416 | 1792 | 1056 | 1168 |

@upskyy (Contributor, Author) commented Nov 11, 2025

@Tcc0403
I’ve addressed all the comments and also added benchmark code comparing LigerMLP, LigerTiledMLP, a normal MLP, and DeepSpeed's TiledMLP.

These are the results I got on an RTX 4090. When you get a chance, could you try running it on an H100 as well?
Thanks in advance!

@Tcc0403 (Collaborator) commented Nov 11, 2025

H100 benchmark:

(Benchmark plots attached in the original comment: GeGLU runtime (full, forward, backward) and full-pass memory; SwiGLU runtime (full, forward, backward) and full-pass memory.)

@Tcc0403 (Collaborator) commented Nov 11, 2025

I've added H100 benchmark data. Is bf16 still not able to pass the tests with (atol=1, rtol=1) on your 4090? If that's the case, I don't mind skipping bf16 as it is now.

Overall lgtm! cc @sfc-gh-sbekman in case I am missing something.

@upskyy (Contributor, Author) commented Nov 11, 2025

Yes, even with (atol=1, rtol=1) on a 4090, bf16 still can’t pass.

@stas00 commented Nov 11, 2025

Very cool, thank you, @upskyy - no fused kernel, though?

Wrt benchmarks - this feature is really most useful at very large seqlen, like 100K+, where you can save memory by an order of magnitude ;) You can easily do a much longer seqlen using Liger + DeepSpeed offload to CPU, even on an 80GB card.

The current implementation in DeepSpeed works with single GPU and DeepSpeed multi-GPU. So support for DDP and FSDP additionally needs to be added for it to cover more use cases. @winglian at Axolotl implemented a manual gradient accumulator https://github.com/axolotl-ai-cloud/axolotl/blob/dd78f2e0cc5cc6458daaad02cc29b649ff1046f5/src/axolotl/monkeypatch/tiled_mlp/base.py#L99 which supposedly works with DDP and FSDP - I just haven't tried it.

The problem is that as soon as you run backward on a shard of hidden_states, DDP and FSDP may try to reduce the gradients across ranks, and when the 2nd slice comes in, they will break because they don't expect the same param to be reduced a 2nd time. So that's why there is special instrumentation here:

https://github.com/linkedin/Liger-Kernel/pull/935/files#diff-9a701bafb3f4b37cf4796ed821a678d64026e497e4e0698f9277a05e17333972R86-R93

it tells DeepSpeed not to reduce the param across ranks until it's ready.

@Tcc0403 (Collaborator) commented Nov 11, 2025

Thanks for the feedback! For kernel fusion, I've opened #937 listing some potential approaches for MLP we can try and see how they perform in future PRs.

@stas00 commented Nov 11, 2025

Super! Because otherwise it's mainly copying the original implementation in pure Python with small improvements. What makes me excited is Liger-Kernel writing an efficient kernel for this feature, like it did for many other features!

@upskyy (Contributor, Author) commented Nov 13, 2025

@stas00 @Tcc0403
I implemented it using PyTorch's no_sync() context manager instead of manually implementing a gradient accumulator like Axolotl.
I also added a corresponding pytest, and the tests pass.

@stas00 commented Nov 14, 2025

That's super cool, @upskyy - I wasn't aware of no_sync when I was searching for how to overcome this problem with DDP - learned a new practical thing. I love it and hope you don't mind me copying your version back to the original once the dust settles here.

# Check if mlp_module actually has no_sync() method (it's a DDP/FSDP wrapper)
if hasattr(mlp_module, "no_sync"):
    sync_context = mlp_module.no_sync()
# If no no_sync() method, we can't control gradient synchronization
@stas00 commented Nov 14, 2025

The "if no no_sync" comment is not clear to me - could you please write out in more detail what you suggest is happening here? And what does "we can't control" imply?

Should there be some exception following this comment, or will it fail in some weird way?

@upskyy (Contributor, Author) replied:

@stas00 I’m really glad to hear that incorporating it into the original code was helpful.

Both DistributedDataParallel and FullyShardedDataParallel provide the no_sync function, which suppresses gradient synchronization during intermediate backward passes and performs synchronization only on the final shard to improve efficiency.
Here is the code for the no_sync function.
Since modules other than DistributedDataParallel and FullyShardedDataParallel do not have the no_sync method, I added the comment "if no no_sync" to indicate this.

And I added the comment to indicate that, in rare edge cases where the module does not have the no_sync function, gradient synchronization will still occur, resulting in less efficient execution but ultimately the same outcome.
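For readers following the thread, a minimal sketch of this pattern around the per-shard recomputation (illustrative only; the function and variable names are hypothetical, not the PR's actual code):

import contextlib

def backward_over_shards(mlp_module, forward_fn, x_shards, grad_shards):
    # Recompute the forward per shard and run backward, deferring DDP/FSDP
    # gradient synchronization until the final shard.
    for i, (x_shard, grad_shard) in enumerate(zip(x_shards, grad_shards)):
        is_last = i == len(x_shards) - 1
        if hasattr(mlp_module, "no_sync") and not is_last:
            # DDP/FSDP wrapper: suppress gradient sync on intermediate shards
            sync_context = mlp_module.no_sync()
        else:
            # Plain module (or the final shard): no-op context, sync happens as usual
            sync_context = contextlib.nullcontext()
        with sync_context:
            output = forward_fn(mlp_module, x_shard)  # recompute forward for this shard
            output.backward(grad_shard)               # accumulate parameter gradients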

@stas00 replied:

Your explanation in this comment is very clear - can we use it in the code comment in the PR? Or, if you want, I can make a suggestion.

So basically some modules might still fail if they can't support such grad sync, exactly like the original implementation would fail. Your comment just explains why it might happen. I was wondering if we could catch that, but I can see that we have no idea how other modules are implemented, and they might still work w/o no_sync.

I checked: in DeepSpeed, no_sync errors out on purpose because it has a custom gradient sharding implementation. I was hoping to use no_sync for DeepSpeed as well, but it won't work, so your code is perfect.

Reviewer reply:

Your update looks good. thank you.

@upskyy (Contributor, Author) replied:

I revised the comments and added explanations to the commit.
For reference, it seems that the no_sync function is also consistently used when training a model with DDP and gradient accumulation.

pytorch/pytorch#143721

@upskyy (Contributor, Author) commented Nov 14, 2025

@Tcc0403 @stas00
Please feel free to let me know if there’s anything else I need to take care of for this PR.
Thank you in advance!

@stas00 replied:

> it seems that the no_sync function is also consistently used when training a model with DDP and gradient accumulation.

Aha, that's a good way of finding out how some custom module/framework deals with GAS!

> Please feel free to let me know if there's anything else I need to take care of for this PR.

I'm not part of this project, so @Tcc0403 will follow up if there are any other things to do, I'm sure. As far as my opinion goes, you did an amazing job, @upskyy!

The revised code comment:

# In this edge case, gradient synchronization will occur on every shard (inefficient),
# but the final result remains correct.
@Tcc0403 (Collaborator) commented Nov 14, 2025

Thank you! The TiledMLP structure looks good to me. We should be able to add more kernel implementations easily in the future.

I think this can also be a great chance to discuss how we structure distributed/multi-GPU tests. I suggest adding a new directory, test/distributed, and putting FSDP/DDP/other parallelism-related unit tests and convergence tests there, so we can request multiple GPUs in a separate Modal CI request. cc @shimizust @vaibhavjindal

@upskyy (Contributor, Author) commented Nov 15, 2025

@Tcc0403
I’ve moved the distributed/multi-GPU tests into the test/distributed directory. Since the focus of this PR is the TiledMLP implementation, I think it would be reasonable to handle any additional work in a separate PR.

Let me know if there’s anything else I should take care of. Thanks in advance!

@Tcc0403 (Collaborator) left a review comment:

The distributed tests show that the TiledMLP implementation for DDP and FSDP is buggy.

@Tcc0403 (Collaborator) commented Nov 17, 2025

Some investigation into the current status of the DDP and FSDP integration.

The current implementations for DDP and FSDP with num_shards > 1 are both incorrect. However, if no_sync() can work as intended, it should be able to solve the issues.

I didn't test making TiledMLP a submodule and applying DDP/FSDP to a root module such as a DecoderBlock. Maybe that won't trigger grad sync or resharding when looping over the submodule fwd/bwd passes.

DDP

Gradient sync is performed only once, in the first backward pass. In the for loop over sequence shards, only the backward of the first iteration triggers gradient synchronization.

num_shards=1

iter 0 starts
forward pass succeed
comm hook called (grad sync happened)
backward pass succeed
iter 0 ends

num_shards=4

iter 0 starts
forward pass succeed
comm hook called (grad sync happened)
backward pass succeed
iter 0 ends

# no grad sync after
iter 1 starts
forward pass succeed
backward pass succeed
iter 1 ends

iter 2 starts
forward pass succeed
backward pass succeed
iter 2 ends

iter 3 starts
forward pass succeed
backward pass succeed
iter 3 ends
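The "comm hook called" lines above were presumably produced by a logging communication hook; a minimal sketch of how such an observation can be made (the hook below is hypothetical, the DDP comm-hook API is standard PyTorch):

from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

def logging_allreduce_hook(state, bucket):
    # Print whenever DDP fires a gradient all-reduce, then run the default all-reduce.
    print("comm hook called (grad sync happened)")
    return default_hooks.allreduce_hook(state, bucket)

# ddp_model is assumed to be a torch.nn.parallel.DistributedDataParallel instance.
ddp_model.register_comm_hook(state=None, hook=logging_allreduce_hook)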

FSDP

For FSDP, the current implementation for num_shards > 1 is incorrect. In the for loop over sequence shards, the backward pass of the first iteration triggers a reshard of the parameters, and the forward+backward passes after it fail.

I tried different sharding strategies (FULL_SHARD and SHARD_GRAD_OP); both fail to compute the forward pass in iterations > 1.

(Setting sharding_strategy=NO_SHARD does work, so the problem really is the resharding after the first backward pass.)

Cannot perform gate_proj(x)
E       -- Process 0 terminated with the following error:
E       Traceback (most recent call last):
E         File "/Liger-Kernel/.venv/lib/python3.11/site-packages/torch/multiprocessing/spawn.py", line 95, in _wrap
E           fn(i, *args)
E         File "/Liger-Kernel/test/distributed/test_tiled_mlp_distributed.py", line 185, in run_fsdp_test
E           output.backward(grad_output)
E         File "/Liger-Kernel/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 625, in backward
E           torch.autograd.backward(
E         File "/Liger-Kernel/.venv/lib/python3.11/site-packages/torch/autograd/__init__.py", line 354, in backward
E           _engine_run_backward(
E         File "/Liger-Kernel/.venv/lib/python3.11/site-packages/torch/autograd/graph.py", line 841, in _engine_run_backward
E           return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
E                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E         File "/Liger-Kernel/.venv/lib/python3.11/site-packages/torch/autograd/function.py", line 315, in apply
E           return user_fn(self, *args)
E                  ^^^^^^^^^^^^^^^^^^^^
E         File "/Liger-Kernel/src/liger_kernel/ops/utils.py", line 40, in wrapper
E           return fn(ctx, *args, **kwargs)
E                  ^^^^^^^^^^^^^^^^^^^^^^^^
E         File "/Liger-Kernel/src/liger_kernel/ops/tiled_mlp.py", line 161, in backward
E           output = fn(mlp_module, x_shard)
E                    ^^^^^^^^^^^^^^^^^^^^^^^
E         File "/Liger-Kernel/src/liger_kernel/transformers/tiled_mlp.py", line 110, in _mlp_forward
E           gate = module.gate_proj(x)
E                  ^^^^^^^^^^^^^^^^^^^
E         File "/Liger-Kernel/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
E           return self._call_impl(*args, **kwargs)
E                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E         File "/Liger-Kernel/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
E           return forward_call(*args, **kwargs)
E                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E         File "/Liger-Kernel/.venv/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 134, in forward
E           return F.linear(input, self.weight, self.bias)
E                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E       RuntimeError: setStorage: sizes [64, 256], strides [1, 64], storage offset 0, and itemsize 4 requiring a storage size of 65536 are out of bounds for storage of size 0
start of iter 0
forward pass succeed
backward pass succeed
end of iter 0
start of iter 1

BTW, I found that the doc says FSDP's no_sync should only be used on the root FSDP instance. Are there any side effects if using it on a child module as this PR does?

> Within this context, gradients will be accumulated in module variables, which will later be synchronized in the first forward-backward pass after exiting the context. This should only be used on the root FSDP instance and will recursively apply to all children FSDP instances.

@Tcc0403 (Collaborator) commented Nov 17, 2025

> I didn't test making TiledMLP a submodule and applying DDP/FSDP to a root module such as a DecoderBlock. Maybe that won't trigger grad sync or resharding when looping over the submodule fwd/bwd passes.

Update: it will raise RuntimeError: ... 2) Reused parameters in multiple reentrant backward passes.

Full error log

RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the forward function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple checkpoint functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 6 with name mlp.mlp.down_proj.weight has been marked as ready twice. This means that multiple autograd engine hooks have fired for this particular parameter during this iteration.

Or you have to wrap the module with DDP(..., static_graph=True). In this case, grad sync is only called once, after the whole DDP module's backward computation, so there is no need to worry about multiple grad syncs in the TiledMLP module's backward pass.

For FSDP, it results in the same error as in the previous comment. The doc also states that:

> FSDP does not work with double backwards due to the way it registers backward hooks.

Besides Axolotl's grad accum hook approach, FSDP2's explicit control of resharding/unsharding submodules is also a potential solution.
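For reference, a minimal sketch of the static_graph wrapping mentioned above (model and local_rank are placeholders; assumes the process group is already initialized):

from torch.nn.parallel import DistributedDataParallel as DDP

# With static_graph=True, DDP performs gradient sync once, after the whole module's
# backward computation, so the repeated per-shard backwards inside TiledMLP do not
# hit the "marked as ready twice" error described above.
ddp_model = DDP(model, device_ids=[local_rank], static_graph=True)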

@upskyy (Contributor, Author) commented Nov 18, 2025

You’re right. I guess using no_sync() with DDP should resolve the issue. But FSDP seems to have stricter conditions, especially because it’s hard to access the external wrapper from within the nn.Module, making it difficult to determine whether it’s DDP or FSDP. That’s probably why they recommend using no_sync only on the root FSDP instance.

@upskyy (Contributor, Author) commented Nov 18, 2025

Due to fundamental limitations of PyTorch FSDP, it is incompatible with LigerTiledMLP:

  1. When use_orig_params=True:
     FSDP does not explicitly support custom autograd functions.
     RuntimeError: FSDP does not support custom autograd Functions in the forward pass in use_orig_params=True mode
  2. When use_orig_params=False:
     Parameters are managed as flattened tensors, which results in loss of grad_fn.
     RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

@Tcc0403 (Collaborator) commented Nov 18, 2025

I think DDP/FSDP support can be another PR along with its tests, as this PR is becoming a little bloated. Currently there are too many workarounds and hacky checks just to make DDP/FSDP work. TiledMLP is quite simple and shouldn't be that complicated. Let's stash these changes and get TiledMLP merged first, then add that support in future PRs step by step.

BTW, I tried Axolotl's gradient accumulator; it works with DDP. For FSDP, it still encounters param reshard/unshard issues after the first iteration.

@upskyy (Contributor, Author) commented Nov 18, 2025

Thanks for sharing the information. I agree with you. You can go ahead and delete the test/distributed folder or make the changes yourself.
Alternatively, should I assume it’s running on a single GPU without considering DDP/FSDP, revise the code accordingly, and commit it again?

@Tcc0403 (Collaborator) commented Nov 18, 2025

Yes, no need to consider DDP/FSDP in this PR. I'll open an issue to track that support.

@upskyy (Contributor, Author) commented Nov 18, 2025

I’ve made the changes and completed the commit.

@Tcc0403 (Collaborator) left a review comment:

Thank you!

@Tcc0403 mentioned this pull request Nov 18, 2025
@Tcc0403 merged commit 4c32ab6 into linkedin:main Nov 18, 2025 (3 of 7 checks passed)