240 changes: 240 additions & 0 deletions benchmark/data/all_benchmark_data.csv

Large diffs are not rendered by default.

397 changes: 397 additions & 0 deletions benchmark/scripts/benchmark_tiled_mlp.py
@@ -0,0 +1,397 @@
import math

import torch
import torch.nn as nn
import triton

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaMLP
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from liger_kernel.transformers.tiled_mlp import LigerTiledGEGLUMLP
from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP
from liger_kernel.utils import infer_device

device = infer_device()


# DeepSpeed TiledMLP implementation
# Based on: https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838
class DeepSpeedTiledMLP(torch.autograd.Function):
"""
DeepSpeed's TiledMLP implementation, reproduced here for a fair comparison.
The algorithm chunks the input along the sequence dimension and runs the MLP on one
shard at a time, which sharply reduces activation memory at very long sequence lengths.

The forward pass is re-computed during backward, so forward runs twice per iteration.
"""

@staticmethod
def forward(ctx, fn, self, x, shards, compute_params) -> torch.Tensor:
ctx.fn = fn
ctx.self = self
ctx.shards = shards
ctx.compute_params = [p for p in compute_params if p.requires_grad] if compute_params else []
ctx.save_for_backward(x)

# x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
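# run each shard's forward under no_grad: only x is saved for backward, so the
# per-shard activations are discarded here and recomputed shard by shard in backward()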
with torch.no_grad():
output_shards = [fn(self, x_shard) for x_shard in x_shards]
output_unsharded = torch.cat(output_shards, dim=-2)

return output_unsharded

@staticmethod
def backward(ctx, *grads):
fn = ctx.fn
(x,) = ctx.saved_tensors
self = ctx.self
shards = ctx.shards
compute_params = ctx.compute_params

x_requires_grad = x.requires_grad
x = x.detach()
# detach() unsets x.requires_grad, so restore it
x.requires_grad_(x_requires_grad)

# x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
hidden_size = x.shape[-1]
x_shape_orig = x.shape

# flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1
x = x.view(-1, hidden_size)
incoming_grad = grads[0].view(-1, hidden_size)
x_grad = torch.zeros_like(x)

x_shards = list(torch.chunk(x, chunks=shards, dim=0))

for i, x_shard in enumerate(x_shards):
# Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run
# XXX: DDP, FSDP will need something similar to make it work
if compute_params:
if i + 1 < shards:
for param in compute_params:
if hasattr(param, "ds_grad_is_ready"):
param.ds_grad_is_ready = False
else:
# last shard, can add the grad
for param in compute_params:
if hasattr(param, "ds_grad_is_ready"):
param.ds_grad_is_ready = True

x_shard.requires_grad_(x_requires_grad)

# if seqlen is not exactly divisible by shards, the last shard is shorter than the others
shard_step = x_shards[i].shape[0]
shard_offset = i * x_shards[0].shape[0]

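# alias this shard's slice of the preallocated x_grad buffer as x_shard.grad so that
# torch.autograd.backward() below accumulates the input gradient directly into place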
x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
with torch.enable_grad():
output = fn(self, x_shard)
torch.autograd.backward(output, incoming_grad_shard)

# unflatten
x_grad = x_grad.view(x_shape_orig)

return (None, None, x_grad, None, None)


# DeepSpeed TiledMLP wrapper to match our interface
class DeepSpeedTiledMLPWrapper(nn.Module):
"""
Wraps a standard LlamaMLP with the DeepSpeedTiledMLP autograd function above so it
matches the layer interface used by these benchmarks.
"""

def __init__(self, config, num_shards=None):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.num_shards = num_shards

self.mlp = LlamaMLP(config=config)

def forward(self, x):
# Calculate num_shards if not provided
num_shards = self.num_shards
if num_shards is None:
hidden_size = x.shape[-1]
seqlen = x.shape[-2]
num_shards = math.ceil(seqlen / hidden_size)
num_shards = max(1, num_shards)
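# e.g. seq_len=8192 with hidden_size=2048 gives ceil(8192 / 2048) = 4 shards
# (the benchmark configs below pass num_shards=4 explicitly, so this fallback is unused there)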

# Collect compute parameters for DeepSpeed ZeRO compatibility
compute_params = [
self.mlp.down_proj.weight,
self.mlp.gate_proj.weight,
self.mlp.up_proj.weight,
]

# Define the MLP forward function for DeepSpeed TiledMLP
def mlp_forward(mlp_module, x_input):
return mlp_module.down_proj(mlp_module.act_fn(mlp_module.gate_proj(x_input)) * mlp_module.up_proj(x_input))

# Use DeepSpeed's TiledMLP implementation
return DeepSpeedTiledMLP.apply(
mlp_forward,
self.mlp,
x,
num_shards,
compute_params,
)
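
# Usage sketch (illustrative only, using the classes defined/imported above):
#   cfg = LlamaConfig(hidden_size=2048, intermediate_size=4096, hidden_act="silu")
#   mlp = DeepSpeedTiledMLPWrapper(config=cfg, num_shards=4).to(device).to(torch.bfloat16)
#   out = mlp(torch.randn(2, 8192, 2048, device=device, dtype=torch.bfloat16, requires_grad=True))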


def bench_speed_tiled_mlp(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
seq_len = input.x
bsz = input.extra_benchmark_config["bsz"]
hidden_size = input.extra_benchmark_config["hidden_size"]
intermediate_size = input.extra_benchmark_config["intermediate_size"]
hidden_act = input.extra_benchmark_config["hidden_act"]
dtype = input.extra_benchmark_config["dtype"]
num_shards = input.extra_benchmark_config.get("num_shards", None)
activation_type = input.extra_benchmark_config["activation_type"]
provider = input.kernel_provider
mode = input.kernel_operation_mode

llama_config = LlamaConfig(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
)

x_shape = (bsz, seq_len, hidden_size)

# initialize input
x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)

if activation_type == "geglu":
if provider == "huggingface":
layer = LlamaMLP(config=llama_config).to(device).to(dtype)
elif provider == "liger":
layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype)
elif provider == "liger_tiled":
layer = LigerTiledGEGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype)
elif provider == "deepspeed_tiled":
layer = DeepSpeedTiledMLPWrapper(config=llama_config, num_shards=num_shards).to(device).to(dtype)
else:
raise ValueError(f"Invalid provider: {provider} for GEGLU")
elif activation_type == "swiglu":
if provider == "huggingface":
layer = LlamaMLP(config=llama_config).to(device).to(dtype)
elif provider == "liger":
layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype)
elif provider == "liger_tiled":
layer = LigerTiledSwiGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype)
elif provider == "deepspeed_tiled":
layer = DeepSpeedTiledMLPWrapper(config=llama_config, num_shards=num_shards).to(device).to(dtype)
else:
raise ValueError(f"Invalid provider: {provider} for SwiGLU")
else:
raise ValueError(f"Invalid activation_type: {activation_type}")

def fwd():
return layer(x)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
grad_to_none=[x],
rep=10,
quantiles=QUANTILES,
)
elif mode == "backward":
do = torch.randn_like(x)
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(do, retain_graph=True),
grad_to_none=[x],
rep=10,
quantiles=QUANTILES,
)
else:

def full():
y = fwd()
y.backward(torch.randn_like(y), retain_graph=True)

ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
grad_to_none=[x],
rep=10,
quantiles=QUANTILES,
)

return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)


def bench_memory_tiled_mlp(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
seq_len = input.x
bsz = input.extra_benchmark_config["bsz"]
hidden_size = input.extra_benchmark_config["hidden_size"]
intermediate_size = input.extra_benchmark_config["intermediate_size"]
hidden_act = input.extra_benchmark_config["hidden_act"]
dtype = input.extra_benchmark_config["dtype"]
num_shards = input.extra_benchmark_config.get("num_shards", None)
activation_type = input.extra_benchmark_config["activation_type"]
provider = input.kernel_provider
mode = input.kernel_operation_mode

llama_config = LlamaConfig(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
)

x_shape = (bsz, seq_len, hidden_size)
# initialize input
x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)

if activation_type == "geglu":
if provider == "huggingface":
layer = LlamaMLP(config=llama_config).to(device).to(dtype)
elif provider == "liger":
layer = LigerGEGLUMLP(config=llama_config).to(device).to(dtype)
elif provider == "liger_tiled":
layer = LigerTiledGEGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype)
elif provider == "deepspeed_tiled":
layer = DeepSpeedTiledMLPWrapper(config=llama_config, num_shards=num_shards).to(device).to(dtype)
else:
raise ValueError(f"Invalid provider: {provider} for GEGLU")
elif activation_type == "swiglu":
if provider == "huggingface":
layer = LlamaMLP(config=llama_config).to(device).to(dtype)
elif provider == "liger":
layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype)
elif provider == "liger_tiled":
layer = LigerTiledSwiGLUMLP(config=llama_config, num_shards=num_shards).to(device).to(dtype)
elif provider == "deepspeed_tiled":
layer = DeepSpeedTiledMLPWrapper(config=llama_config, num_shards=num_shards).to(device).to(dtype)
else:
raise ValueError(f"Invalid provider: {provider} for SwiGLU")
else:
raise ValueError(f"Invalid activation_type: {activation_type}")

def fwd():
return layer(x)

def full():
y = fwd()
y.backward(torch.randn_like(y), retain_graph=True)

if mode == "forward":
mem_50, mem_20, mem_80 = _test_memory(
fwd,
quantiles=QUANTILES,
)
elif mode == "backward":
do = torch.randn_like(x)
y = fwd()
mem_50, mem_20, mem_80 = _test_memory(
lambda: y.backward(do, retain_graph=True),
quantiles=QUANTILES,
)
else:
mem_50, mem_20, mem_80 = _test_memory(
full,
quantiles=QUANTILES,
)

return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)


if __name__ == "__main__":
args = parse_benchmark_script_args()

# Benchmark GEGLU variants
kernel_providers_geglu = ["huggingface", "liger", "liger_tiled", "deepspeed_tiled"]

common_configs_geglu = {
"kernel_name": "tiled_geglu",
"x_name": "T",
"x_label": "sequence length",
"x_values": [2**i for i in range(10, 15)], # 1024 to 16384
"kernel_providers": kernel_providers_geglu,
"extra_benchmark_configs": [
{
"bsz": 2,
"hidden_size": 2048,
"intermediate_size": 4096,
"hidden_act": "gelu_pytorch_tanh",
"activation_type": "geglu",
"num_shards": 4,
"dtype": torch.bfloat16,
}
],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_tiled_mlp,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs_geglu,
)
run_benchmarks(
bench_test_fn=bench_memory_tiled_mlp,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="memory",
metric_unit="MB",
**common_configs_geglu,
)

# Benchmark SwiGLU variants
kernel_providers_swiglu = ["huggingface", "liger", "liger_tiled", "deepspeed_tiled"]

common_configs_swiglu = {
"kernel_name": "tiled_swiglu",
"x_name": "T",
"x_label": "sequence length",
"x_values": [2**i for i in range(10, 15)], # 1024 to 16384
"kernel_providers": kernel_providers_swiglu,
"extra_benchmark_configs": [
{
"bsz": 2,
"hidden_size": 2048,
"intermediate_size": 4096,
"hidden_act": "silu",
"activation_type": "swiglu",
"num_shards": 4,
"dtype": torch.bfloat16,
}
],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_tiled_mlp,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs_swiglu,
)
run_benchmarks(
bench_test_fn=bench_memory_tiled_mlp,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="memory",
metric_unit="MB",
**common_configs_swiglu,
)
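
# To reproduce (invocation assumed to follow the repo's other benchmark scripts):
#   python benchmark/scripts/benchmark_tiled_mlp.py
# Results are written to benchmark/data/all_benchmark_data.csv (the CSV diff above);
# the overwrite flag from parse_benchmark_script_args controls whether existing entries are replaced.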