Add TiledMLP Implementation #935
Merged · +1,126 −0
Changes from 1 commit (of 11 commits, all by upskyy):

- 7b604a7 Add LigerTiledGEGLUMLP, LigerTiledSwiGLUMLP, test, benchmark codes
- ad4edcc Apply make checkstyle
- caa72f7 Fix pytest about TiledMLP
- aa6ffad Add a comparison of LigerMLP, LigerTiledMLP, normal MLP, deepspeed's …
- dc53e90 Add support for DDP and FSDP
- b7fb636 Add DDP and FSDP test codes
- 88895f6 Update comments, Case: module has no no_sync() method
- a273923 Add test/distributed directory
- 9ea1b1d Fix test case
- eb92706 Update DDP/FSDP wrapper
- 82b9bf7 Delete DDP/FSDP module in TiledMLP
The selected commit adds two new files. The first is a benchmark script (244 added lines) comparing the plain Liger MLPs against their tiled variants:

```python
import torch
import triton

from transformers.models.llama.configuration_llama import LlamaConfig
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from liger_kernel.transformers.tiled_mlp import LigerTiledGEGLUMLP
from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP
from liger_kernel.utils import infer_device

device = infer_device()


def bench_speed_tiled_mlp(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    seq_len = input.x
    bsz = input.extra_benchmark_config["bsz"]
    hidden_size = input.extra_benchmark_config["hidden_size"]
    intermediate_size = input.extra_benchmark_config["intermediate_size"]
    hidden_act = input.extra_benchmark_config["hidden_act"]
    dtype = input.extra_benchmark_config["dtype"]
    num_shards = input.extra_benchmark_config.get("num_shards", None)
    activation_type = input.extra_benchmark_config["activation_type"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    llama_config = LlamaConfig(
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        hidden_act=hidden_act,
    )

    x_shape = (bsz, seq_len, hidden_size)

    # initialize input
    x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)

    if activation_type == "geglu":
        if provider == "liger":
            layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype)
        elif provider == "liger_tiled":
            layer = LigerTiledGEGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype)
        else:
            raise ValueError(f"Invalid provider: {provider} for GEGLU")
    elif activation_type == "swiglu":
        if provider == "liger":
            layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype)
        elif provider == "liger_tiled":
            layer = LigerTiledSwiGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype)
        else:
            raise ValueError(f"Invalid provider: {provider} for SwiGLU")
    else:
        raise ValueError(f"Invalid activation_type: {activation_type}")

    def fwd():
        return layer(x)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            grad_to_none=[x],
            rep=10,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        do = torch.randn_like(x)
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(do, retain_graph=True),
            grad_to_none=[x],
            rep=10,
            quantiles=QUANTILES,
        )
    else:

        def full():
            y = fwd()
            y.backward(torch.randn_like(y), retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            grad_to_none=[x],
            rep=10,
            quantiles=QUANTILES,
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_tiled_mlp(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    seq_len = input.x
    bsz = input.extra_benchmark_config["bsz"]
    hidden_size = input.extra_benchmark_config["hidden_size"]
    intermediate_size = input.extra_benchmark_config["intermediate_size"]
    hidden_act = input.extra_benchmark_config["hidden_act"]
    dtype = input.extra_benchmark_config["dtype"]
    num_shards = input.extra_benchmark_config.get("num_shards", None)
    activation_type = input.extra_benchmark_config["activation_type"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    llama_config = LlamaConfig(
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        hidden_act=hidden_act,
    )

    x_shape = (bsz, seq_len, hidden_size)
    # initialize input
    x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)

    if activation_type == "geglu":
        if provider == "liger":
            layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype)
        elif provider == "liger_tiled":
            layer = LigerTiledGEGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype)
        else:
            raise ValueError(f"Invalid provider: {provider} for GEGLU")
    elif activation_type == "swiglu":
        if provider == "liger":
            layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype)
        elif provider == "liger_tiled":
            layer = LigerTiledSwiGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype)
        else:
            raise ValueError(f"Invalid provider: {provider} for SwiGLU")
    else:
        raise ValueError(f"Invalid activation_type: {activation_type}")

    def fwd():
        return layer(x)

    def full():
        y = fwd()
        y.backward(torch.randn_like(y), retain_graph=True)

    if mode == "forward":
        mem_50, mem_20, mem_80 = _test_memory(
            fwd,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        do = torch.randn_like(x)
        y = fwd()
        mem_50, mem_20, mem_80 = _test_memory(
            lambda: y.backward(do, retain_graph=True),
            quantiles=QUANTILES,
        )
    else:
        mem_50, mem_20, mem_80 = _test_memory(
            full,
            quantiles=QUANTILES,
        )

    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # Benchmark GEGLU variants
    common_configs_geglu = {
        "kernel_name": "tiled_geglu",
        "x_name": "T",
        "x_label": "sequence length",
        "x_values": [2**i for i in range(10, 15)],  # 1024 to 16384
        "kernel_providers": ["liger", "liger_tiled"],
        "extra_benchmark_configs": [
            {
                "bsz": 2,
                "hidden_size": 2048,
                "intermediate_size": 4096,
                "hidden_act": "gelu_pytorch_tanh",
                "activation_type": "geglu",
                "num_shards": 4,
                "dtype": torch.bfloat16,
            }
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_tiled_mlp,
        kernel_operation_modes=["full", "forward", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_geglu,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_tiled_mlp,
        kernel_operation_modes=["full", "forward", "backward"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_geglu,
    )

    # Benchmark SwiGLU variants
    common_configs_swiglu = {
        "kernel_name": "tiled_swiglu",
        "x_name": "T",
        "x_label": "sequence length",
        "x_values": [2**i for i in range(10, 15)],  # 1024 to 16384
        "kernel_providers": ["liger", "liger_tiled"],
        "extra_benchmark_configs": [
            {
                "bsz": 2,
                "hidden_size": 2048,
                "intermediate_size": 4096,
                "hidden_act": "silu",
                "activation_type": "swiglu",
                "num_shards": 4,
                "dtype": torch.bfloat16,
            }
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_tiled_mlp,
        kernel_operation_modes=["full", "forward", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_swiglu,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_tiled_mlp,
        kernel_operation_modes=["full", "forward", "backward"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_swiglu,
    )
```
The second new file (147 added lines) contains the core tiled-MLP autograd function and its helper:

```python
"""
Based on DeepSpeed's TiledMLP:
https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/sequence_parallel/ulysses_sp.py
"""

import math
from typing import Callable, List, Optional

import torch

from liger_kernel.ops.utils import ensure_contiguous


class LigerTiledMLPFunction(torch.autograd.Function):
    """
    Perform a tiled MLP computation to massively reduce memory usage needed to compute MLP
    when using very long sequence lengths.

    This module re-computes `forward` in the `backward`. So the `forward` occurs twice each iteration.
    And if you're using activation checkpointing it then occurs thrice.

    Args:
        fn: the function to call on sharded inputs (e.g., mlp.forward)
        mlp_module: the MLP nn.Module object
        x: the input to MLP.forward (hidden_states)
        shards: how many shards to use
        compute_params: a list of weights engaged in the compute (only needed when using DeepSpeed ZeRO)

    Returns:
        the computed hidden_states
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        fn: Callable,
        mlp_module: torch.nn.Module,
        x: torch.Tensor,
        shards: int,
        compute_params: Optional[List[torch.nn.Parameter]] = None,
    ) -> torch.Tensor:
        ctx.fn = fn
        ctx.mlp_module = mlp_module
        ctx.shards = shards
        ctx.compute_params = [p for p in compute_params if p.requires_grad] if compute_params else []
        ctx.save_for_backward(x)

        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
        x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
        with torch.no_grad():
            output_shards = [fn(mlp_module, x_shard) for x_shard in x_shards]
        output_unsharded = torch.cat(output_shards, dim=-2)

        return output_unsharded

    @staticmethod
    @ensure_contiguous
    def backward(ctx, *grads) -> tuple:
        fn = ctx.fn
        (x,) = ctx.saved_tensors
        mlp_module = ctx.mlp_module
        shards = ctx.shards
        compute_params = ctx.compute_params

        x_requires_grad = x.requires_grad
        x = x.detach()
        # detach() unsets x.requires_grad, so restore it
        x.requires_grad_(x_requires_grad)

        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
        hidden_size = x.shape[-1]
        x_shape_orig = x.shape

        # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1
        x = x.view(-1, hidden_size)
        incoming_grad = grads[0].view(-1, hidden_size)
        x_grad = torch.zeros_like(x)

        x_shards = list(torch.chunk(x, chunks=shards, dim=0))

        for i, x_shard in enumerate(x_shards):
            # Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run
            # XXX: DDP, FSDP will need something similar to make it work
            if compute_params:
                if i + 1 < shards:
                    for param in compute_params:
                        param.ds_grad_is_ready = False
                else:
                    # last shard, can add the grad
                    for param in compute_params:
                        param.ds_grad_is_ready = True

            x_shard.requires_grad_(x_requires_grad)

            # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step
            shard_step = x_shards[i].shape[0]
            shard_offset = i * x_shards[0].shape[0]

            x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
            incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
            with torch.enable_grad():
                output = fn(mlp_module, x_shard)
                torch.autograd.backward(output, incoming_grad_shard)

        # unflatten
        x_grad = x_grad.view(x_shape_orig)

        return (None, None, x_grad, None, None)


def apply_tiled_mlp(
    fn: Callable,
    mlp_module: torch.nn.Module,
    x: torch.Tensor,
    num_shards: Optional[int] = None,
    compute_params: Optional[List[torch.nn.Parameter]] = None,
) -> torch.Tensor:
    """
    Apply tiled MLP computation for memory efficiency.

    Args:
        fn: the function to call on sharded inputs (e.g., lambda module, x: module(x))
        mlp_module: the MLP nn.Module object
        x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
        num_shards: number of shards to use. If None, automatically calculated as ceil(seqlen / hidden_size)
        compute_params: list of parameters for DeepSpeed ZeRO optimization

    Returns:
        output tensor with the same shape as input
    """
    if num_shards is None:
        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size]
        hidden_size = x.shape[-1]
        seqlen = x.shape[-2]
        num_shards = math.ceil(seqlen / hidden_size)

    # Ensure num_shards is at least 1
    num_shards = max(1, num_shards)

    return LigerTiledMLPFunction.apply(
        fn,
        mlp_module,
        x,
        num_shards,
        compute_params,
    )
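```

For reference, here is a minimal usage sketch of `apply_tiled_mlp`. It is not from the PR: the toy MLP, the tensor shapes, and the import path `liger_kernel.ops.tiled_mlp` are assumptions for illustration; only the `fn(module, shard)` calling convention and the `ceil(seqlen / hidden_size)` default come from the code above.

```python
# Hypothetical usage sketch; ToyMLP and the import path are assumptions,
# only apply_tiled_mlp's signature comes from the file above.
import torch
import torch.nn as nn
import torch.nn.functional as F

from liger_kernel.ops.tiled_mlp import apply_tiled_mlp  # assumed module path for this new file


class ToyMLP(nn.Module):
    """Stand-in for a Llama-style MLP block."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.up = nn.Linear(hidden_size, intermediate_size)
        self.down = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(F.silu(self.up(x)))


mlp = ToyMLP(hidden_size=256, intermediate_size=512)
x = torch.randn(2, 4096, 256, requires_grad=True)

# fn receives (module, shard) and runs the module's forward on that shard.
# With num_shards=None the shard count defaults to ceil(seqlen / hidden_size),
# i.e. ceil(4096 / 256) = 16 shards here.
out = apply_tiled_mlp(fn=lambda module, shard: module(shard), mlp_module=mlp, x=x)

out.backward(torch.randn_like(out))
print(out.shape, x.grad.shape)  # both torch.Size([2, 4096, 256])
```

Passing `fn` separately from `mlp_module` keeps the autograd function agnostic to the concrete MLP class, which is presumably how the `LigerTiledGEGLUMLP` and `LigerTiledSwiGLUMLP` wrappers reuse it.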