120 changes: 120 additions & 0 deletions benchmark/data/all_benchmark_data.csv

Large diffs are not rendered by default.

244 changes: 244 additions & 0 deletions benchmark/scripts/benchmark_tiled_mlp.py
@@ -0,0 +1,244 @@
import torch
import triton

from transformers.models.llama.configuration_llama import LlamaConfig
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from liger_kernel.transformers.tiled_mlp import LigerTiledGEGLUMLP
from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP
from liger_kernel.utils import infer_device

device = infer_device()


def bench_speed_tiled_mlp(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
seq_len = input.x
bsz = input.extra_benchmark_config["bsz"]
hidden_size = input.extra_benchmark_config["hidden_size"]
intermediate_size = input.extra_benchmark_config["intermediate_size"]
hidden_act = input.extra_benchmark_config["hidden_act"]
dtype = input.extra_benchmark_config["dtype"]
num_shards = input.extra_benchmark_config.get("num_shards", None)
activation_type = input.extra_benchmark_config["activation_type"]
provider = input.kernel_provider
mode = input.kernel_operation_mode

llama_config = LlamaConfig(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
)

x_shape = (bsz, seq_len, hidden_size)

# initialize input
x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)

if activation_type == "geglu":
if provider == "liger":
layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype)
elif provider == "liger_tiled":
layer = LigerTiledGEGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype)
else:
raise ValueError(f"Invalid provider: {provider} for GEGLU")
elif activation_type == "swiglu":
if provider == "liger":
layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype)
elif provider == "liger_tiled":
layer = LigerTiledSwiGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype)
else:
raise ValueError(f"Invalid provider: {provider} for SwiGLU")
else:
raise ValueError(f"Invalid activation_type: {activation_type}")

def fwd():
return layer(x)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
grad_to_none=[x],
rep=10,
quantiles=QUANTILES,
)
elif mode == "backward":
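        # the MLP maps hidden_size -> hidden_size, so the output has the same shape as x and
        # the upstream gradient can be sampled with randn_like(x)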
do = torch.randn_like(x)
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(do, retain_graph=True),
grad_to_none=[x],
rep=10,
quantiles=QUANTILES,
)
else:

def full():
y = fwd()
y.backward(torch.randn_like(y), retain_graph=True)

ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
grad_to_none=[x],
rep=10,
quantiles=QUANTILES,
)

return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)


def bench_memory_tiled_mlp(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
seq_len = input.x
bsz = input.extra_benchmark_config["bsz"]
hidden_size = input.extra_benchmark_config["hidden_size"]
intermediate_size = input.extra_benchmark_config["intermediate_size"]
hidden_act = input.extra_benchmark_config["hidden_act"]
dtype = input.extra_benchmark_config["dtype"]
num_shards = input.extra_benchmark_config.get("num_shards", None)
activation_type = input.extra_benchmark_config["activation_type"]
provider = input.kernel_provider
mode = input.kernel_operation_mode

llama_config = LlamaConfig(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
)

x_shape = (bsz, seq_len, hidden_size)
# initialize input
x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)

if activation_type == "geglu":
if provider == "liger":
layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype)
elif provider == "liger_tiled":
layer = LigerTiledGEGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype)
else:
raise ValueError(f"Invalid provider: {provider} for GEGLU")
elif activation_type == "swiglu":
if provider == "liger":
layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype)
elif provider == "liger_tiled":
layer = LigerTiledSwiGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype)
else:
raise ValueError(f"Invalid provider: {provider} for SwiGLU")
else:
raise ValueError(f"Invalid activation_type: {activation_type}")

def fwd():
return layer(x)

def full():
y = fwd()
y.backward(torch.randn_like(y), retain_graph=True)

if mode == "forward":
mem_50, mem_20, mem_80 = _test_memory(
fwd,
quantiles=QUANTILES,
)
elif mode == "backward":
do = torch.randn_like(x)
y = fwd()
mem_50, mem_20, mem_80 = _test_memory(
lambda: y.backward(do, retain_graph=True),
quantiles=QUANTILES,
)
else:
mem_50, mem_20, mem_80 = _test_memory(
full,
quantiles=QUANTILES,
)

return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)


if __name__ == "__main__":
args = parse_benchmark_script_args()

# Benchmark GEGLU variants
common_configs_geglu = {
"kernel_name": "tiled_geglu",
"x_name": "T",
"x_label": "sequence length",
"x_values": [2**i for i in range(10, 15)], # 1024 to 16384
"kernel_providers": ["liger", "liger_tiled"],
"extra_benchmark_configs": [
{
"bsz": 2,
"hidden_size": 2048,
"intermediate_size": 4096,
"hidden_act": "gelu_pytorch_tanh",
"activation_type": "geglu",
"num_shards": 4,
"dtype": torch.bfloat16,
}
],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_tiled_mlp,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs_geglu,
)
run_benchmarks(
bench_test_fn=bench_memory_tiled_mlp,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="memory",
metric_unit="MB",
**common_configs_geglu,
)

# Benchmark SwiGLU variants
common_configs_swiglu = {
"kernel_name": "tiled_swiglu",
"x_name": "T",
"x_label": "sequence length",
"x_values": [2**i for i in range(10, 15)], # 1024 to 16384
"kernel_providers": ["liger", "liger_tiled"],
"extra_benchmark_configs": [
{
"bsz": 2,
"hidden_size": 2048,
"intermediate_size": 4096,
"hidden_act": "silu",
"activation_type": "swiglu",
"num_shards": 4,
"dtype": torch.bfloat16,
}
],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_tiled_mlp,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs_swiglu,
)
run_benchmarks(
bench_test_fn=bench_memory_tiled_mlp,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="memory",
metric_unit="MB",
**common_configs_swiglu,
)
147 changes: 147 additions & 0 deletions src/liger_kernel/ops/tiled_mlp.py
@@ -0,0 +1,147 @@
"""
Based on DeepSpeed's TiledMLP:
https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/sequence_parallel/ulysses_sp.py
"""

import math
from typing import Callable, List, Optional

import torch

from liger_kernel.ops.utils import ensure_contiguous


class LigerTiledMLPFunction(torch.autograd.Function):
"""
    Perform a tiled MLP computation to greatly reduce the memory needed to compute the MLP
    when using very long sequence lengths.

    This function re-computes `forward` during `backward`, so `forward` runs twice per iteration
    (and a third time if activation checkpointing is enabled).

Args:
fn: the function to call on sharded inputs (e.g., mlp.forward)
mlp_module: the MLP nn.Module object
x: the input to MLP.forward (hidden_states)
shards: how many shards to use
compute_params: a list of weights engaged in the compute (only needed when using DeepSpeed ZeRO)

Returns:
the computed hidden_states
"""

@staticmethod
@ensure_contiguous
def forward(
ctx,
fn: Callable,
mlp_module: torch.nn.Module,
x: torch.Tensor,
shards: int,
compute_params: Optional[List[torch.nn.Parameter]] = None,
) -> torch.Tensor:
ctx.fn = fn
ctx.mlp_module = mlp_module
ctx.shards = shards
ctx.compute_params = [p for p in compute_params if p.requires_grad] if compute_params else []
ctx.save_for_backward(x)

# x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
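        # run the sharded forward without building an autograd graph; activations are not kept
        # and the forward pass is re-computed shard-by-shard in backward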
with torch.no_grad():
output_shards = [fn(mlp_module, x_shard) for x_shard in x_shards]
output_unsharded = torch.cat(output_shards, dim=-2)

return output_unsharded

@staticmethod
@ensure_contiguous
def backward(ctx, *grads) -> tuple:
fn = ctx.fn
(x,) = ctx.saved_tensors
mlp_module = ctx.mlp_module
shards = ctx.shards
compute_params = ctx.compute_params

x_requires_grad = x.requires_grad
x = x.detach()
# detach() unsets x.requires_grad, so restore it
x.requires_grad_(x_requires_grad)

# x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
hidden_size = x.shape[-1]
x_shape_orig = x.shape

# flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1
x = x.view(-1, hidden_size)
incoming_grad = grads[0].view(-1, hidden_size)
x_grad = torch.zeros_like(x)

x_shards = list(torch.chunk(x, chunks=shards, dim=0))

for i, x_shard in enumerate(x_shards):
# Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run
# XXX: DDP, FSDP will need something similar to make it work
if compute_params:
if i + 1 < shards:
for param in compute_params:
param.ds_grad_is_ready = False
else:
# last shard, can add the grad
for param in compute_params:
param.ds_grad_is_ready = True

x_shard.requires_grad_(x_requires_grad)

            # if seqlen is not exactly divisible by shards, the last shard is shorter than the others
shard_step = x_shards[i].shape[0]
shard_offset = i * x_shards[0].shape[0]

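            # point x_shard.grad at the matching slice of the pre-allocated x_grad buffer so the
            # per-shard backward writes the input gradient directly into it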
x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
with torch.enable_grad():
output = fn(mlp_module, x_shard)
torch.autograd.backward(output, incoming_grad_shard)

# unflatten
x_grad = x_grad.view(x_shape_orig)

return (None, None, x_grad, None, None)


def apply_tiled_mlp(
fn: Callable,
mlp_module: torch.nn.Module,
x: torch.Tensor,
num_shards: Optional[int] = None,
compute_params: Optional[List[torch.nn.Parameter]] = None,
) -> torch.Tensor:
"""
Apply tiled MLP computation for memory efficiency.

Args:
fn: the function to call on sharded inputs (e.g., lambda module, x: module(x))
mlp_module: the MLP nn.Module object
x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
num_shards: number of shards to use. If None, automatically calculated as ceil(seqlen / hidden_size)
compute_params: list of parameters for DeepSpeed ZeRO optimization

Returns:
output tensor with the same shape as input
"""
if num_shards is None:
# x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size]
hidden_size = x.shape[-1]
seqlen = x.shape[-2]
num_shards = math.ceil(seqlen / hidden_size)

# Ensure num_shards is at least 1
num_shards = max(1, num_shards)

return LigerTiledMLPFunction.apply(
fn,
mlp_module,
x,
num_shards,
compute_params,
)
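
A minimal usage sketch (illustrative, not part of the diff), wiring apply_tiled_mlp to a LigerSwiGLUMLP the same way the benchmark script above does; the shapes and config values mirror the benchmark defaults:

import torch

from transformers.models.llama.configuration_llama import LlamaConfig

from liger_kernel.ops.tiled_mlp import apply_tiled_mlp
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from liger_kernel.utils import infer_device

device = infer_device()
dtype = torch.bfloat16

config = LlamaConfig(hidden_size=2048, intermediate_size=4096, hidden_act="silu")
mlp = LigerSwiGLUMLP(config=config).to(device).to(dtype)

# [bsz, seqlen, hidden_size]; requires_grad so the tiled backward produces x.grad
x = torch.randn(2, 8192, 2048, device=device, dtype=dtype, requires_grad=True)

# fn receives (module, shard) and returns the shard's output; num_shards=None would
# auto-derive ceil(seqlen / hidden_size) = ceil(8192 / 2048) = 4 shards here
out = apply_tiled_mlp(
    fn=lambda module, x_shard: module(x_shard),
    mlp_module=mlp,
    x=x,
    num_shards=4,
    compute_params=None,  # only needed under DeepSpeed ZeRO
)
out.sum().backward()  # gradients reach both x and the MLP weights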