Thunder makes optimizing PyTorch models easy, augmenting them with custom kernels, fusions, quantization, distributed strategies, and more.
For end users, Thunder comes with plugins that provide model speed-ups out of the box and make full use of the latest generation of hardware.
For performance experts, Thunder is the most ergonomic framework for understanding, modifying, and optimizing AI models through composable transformations.
✅ Run PyTorch 40% faster ✅ Quantization ✅ Kernel fusion ✅ Training recipes ✅ FP4/FP6/FP8 precision ✅ Distributed TP/PP/DP ✅ Inference recipes ✅ Ready for NVIDIA Blackwell ✅ CUDA Graphs ✅ LLMs, non-LLMs and more ✅ Custom Triton kernels ✅ Compose all the above
Install Thunder via pip:
pip install torch==2.6.0 torchvision==0.21 nvfuser-cu124-torch26
pip install lightning-thunder
Advanced install options
For Blackwell, you'll need CUDA 12.8:
pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128
pip install --pre nvfuser-cu128 --extra-index-url https://pypi.nvidia.com
pip install lightning-thunder
The following dependencies are optional; feel free to mix and match:
# cuDNN SDPA
pip install nvidia-cudnn-frontend
# Float8 support (this will compile from source, be patient)
pip install "transformer_engine[pytorch]"
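Since the optional dependencies can be mixed and matched, it can help to check which ones are actually present before enabling the corresponding features. The sketch below uses only the standard library; the top-level import names (`nvfuser`, `cudnn`, `transformer_engine`) are assumptions about how these packages expose themselves, not something taken from the Thunder documentation.

```python
import importlib.util

# Assumed top-level import names for the optional installs above.
# Verify them in your own environment; they are not from the Thunder docs.
OPTIONAL_BACKENDS = {
    "nvfuser": "nvfuser",                        # nvfuser-cu124-torch26 / nvfuser-cu128
    "cudnn": "cudnn",                            # nvidia-cudnn-frontend
    "transformer_engine": "transformer_engine",  # transformer_engine[pytorch]
}


def available_backends() -> dict[str, bool]:
    """Report which optional packages can be located, without importing them."""
    return {
        label: importlib.util.find_spec(module) is not None
        for label, module in OPTIONAL_BACKENDS.items()
    }


if __name__ == "__main__":
    for name, found in available_backends().items():
        print(f"{name:20s} {'found' if found else 'missing'}")
```

`find_spec` only locates a module, so the check stays cheap even for packages like Transformer Engine that are slow to import.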
Install the latest version directly from the main branch:
pip install git+https://github.com/Lightning-AI/lightning-thunder.git@main
Or clone the repository and install it in editable mode to tinker and contribute:
git clone https://github.com/Lightning-AI/lightning-thunder.git
cd lightning-thunder
pip install -e .
Define a function or a torch module:
import torch
import torch.nn as nn
model = nn.Sequential(nn.Linear(2048, 4096), nn.ReLU(), nn.Linear(4096, 64))
Optimize it with Thunder:
import thunder
thunder_model = thunder.compile(model)
x = torch.randn(64, 2048)
y = thunder_model(x)
# `==` on tensors returns a tensor, not a bool; compare with tolerances instead
torch.testing.assert_close(y, model(x))
Speed up LLM training with LitGPT:
import thunder
import torch
import litgpt
with torch.device("cuda"):
    model = litgpt.GPT.from_name("Llama-3.2-1B").to(torch.bfloat16)
thunder_model = thunder.compile(model)
inp = torch.ones((1, 2048), device="cuda", dtype=torch.int64)
out = thunder_model(inp)
out.sum().backward()
Speed up Hugging Face BERT inference:
import thunder
import torch
import transformers
model_name = "bert-large-uncased"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
with torch.device("cuda"):
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16
    )
model.requires_grad_(False)
model.eval()
# The tokenizer returns CPU tensors; move them to the model's device
inp = tokenizer(["Hello world!"], return_tensors="pt").to("cuda")
thunder_model = thunder.compile(model, plugins="reduce-overhead")
out = thunder_model(**inp)
print(out)
Speed up Hugging Face DeepSeek R1 distill inference:
import torch
import transformers
import thunder
model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
with torch.device("cuda"):
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16
    )
model.requires_grad_(False)
model.eval()
# The tokenizer returns CPU tensors; move them to the model's device
inp = tokenizer(["Hello world! Here's a long story"], return_tensors="pt").to("cuda")
thunder_model = thunder.compile(
    model, recipe="hf-transformers", plugins="reduce-overhead"
)
out = thunder_model.generate(
    **inp, do_sample=False, cache_implementation="static", max_new_tokens=100
)
# `generate` returns token ids; decode them into text
print(tokenizer.batch_decode(out, skip_special_tokens=True))
Speed up Vision Transformer inference:
import thunder
import torch
import torchvision as tv
with torch.device("cuda"):
    model = tv.models.vit_b_16()
    model.requires_grad_(False)
    model.eval()
    inp = torch.randn(128, 3, 224, 224)
out = model(inp)
thunder_model = thunder.compile(model, plugins="reduce-overhead")
out = thunder_model(inp)
Plugins are a way to apply optimizations to a model, such as parallelism and quantization.
Thunder includes a few plugins out of the box, and it's easy to write new ones:
- scale up with distributed strategies such as DDP, FSDP, and TP
- optimize numerical precision with FP8, MXFP8
- save memory with quantization
- reduce latency with CUDA Graphs
- debugging and profiling
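Conceptually, a plugin is a transformation applied to the model's computation, and several plugins compose by running one after another. The toy below illustrates only that idea: the `Trace` type, the `Plugin` class and both transforms are hypothetical stand-ins written for this sketch, not Thunder's plugin API.

```python
from dataclasses import dataclass
from typing import Callable

# Hypothetical trace representation: a list of (op_name, dtype) pairs.
Trace = list[tuple[str, str]]


@dataclass
class Plugin:
    """Hypothetical plugin: a name plus a trace-to-trace transform."""
    name: str
    transform: Callable[[Trace], Trace]


def fp8_linears(trace: Trace) -> Trace:
    # Illustrative precision change: mark every linear op as running in "fp8".
    return [(op, "fp8" if op == "linear" else dtype) for op, dtype in trace]


def wrap_in_graph(trace: Trace) -> Trace:
    # Illustrative stand-in for graph capture: bracket the whole trace with markers.
    return [("graph_begin", "-")] + trace + [("graph_end", "-")]


def apply_plugins(trace: Trace, plugins: list[Plugin]) -> Trace:
    # Plugins compose by running their transforms in list order.
    for plugin in plugins:
        trace = plugin.transform(trace)
    return trace


trace = [("linear", "bf16"), ("relu", "bf16"), ("linear", "bf16")]
out = apply_plugins(trace, [Plugin("fp8", fp8_linears), Plugin("graphs", wrap_in_graph)])
print(out)
```

Because each transform takes and returns a trace, a later plugin always sees the output of the earlier ones; that shared interface is what makes "compose all the above" possible in principle.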
Thunder works in three stages:
- ⚡️ It acquires your model by interpreting Python bytecode and producing a straight-line Python program
- ⚡️ It transforms the computation trace, for example to make it distributed or to change its precision
- ⚡️ It routes parts of the trace to different executors:
  - fusion (nvFuser, torch.compile)
  - specialized libraries (e.g. cuDNN SDPA, TransformerEngine)
  - custom Triton and CUDA kernels
  - PyTorch eager operations
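The third stage can be pictured as a priority-ordered dispatch: each operation in the trace is offered to executors in order, and anything nobody claims falls back to PyTorch eager. The sketch below is a simplified model under that assumption; `Op`, `Executor`, `route` and the executor names are hypothetical. Real fusion executors also group neighbouring ops into fused regions, which this toy does not attempt.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class Op:
    """Hypothetical trace record: operator name, output name, input names."""
    name: str
    out: str
    args: tuple[str, ...]


@dataclass(frozen=True)
class Executor:
    """Hypothetical executor: a name and a predicate saying which ops it accepts."""
    name: str
    claims: Callable[[Op], bool]


def route(trace: list[Op], executors: list[Executor]) -> list[tuple[str, Op]]:
    """Assign each op to the first executor (in priority order) that claims it."""
    plan = []
    for op in trace:
        owner = next((ex.name for ex in executors if ex.claims(op)), "torch_eager")
        plan.append((owner, op))
    return plan


trace = [
    Op("linear", "t3", ("input", "w0", "b0")),
    Op("relu", "t6", ("t3",)),
    Op("sdpa", "t7", ("q", "k", "v")),
    Op("my_custom_op", "t8", ("t7",)),
]
executors = [
    Executor("cudnn", lambda op: op.name == "sdpa"),
    Executor("fusion", lambda op: op.name in {"relu", "add", "mul"}),
]
for owner, op in route(trace, executors):
    print(f"{owner:12s} {op.out} = {op.name}{op.args}")
```

Putting specialized libraries ahead of general fusion in the list mirrors the idea that a dedicated kernel, such as an SDPA implementation, should win over a generic path when both could handle an op.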
This is what the trace looks like for a simple MLP:
import thunder
import torch
import torch.nn as nn
model = nn.Sequential(nn.Linear(1024, 2048), nn.ReLU(), nn.Linear(2048, 256)).to("cuda")
thunder_model = thunder.compile(model)
y = thunder_model(torch.randn(4, 1024, device="cuda"))
# The first trace is the acquired one; the last is the final execution trace
print(thunder.last_traces(thunder_model)[0])
This is the acquired trace, ready to be transformed and executed:
def computation(input, t_0_bias, t_0_weight, t_2_bias, t_2_weight):
  # input: "cuda:0 f32[4, 1024]"
  # t_0_bias: "cuda:0 f32[2048]"
  # t_0_weight: "cuda:0 f32[2048, 1024]"
  # t_2_bias: "cuda:0 f32[256]"
  # t_2_weight: "cuda:0 f32[256, 2048]"
  t3 = ltorch.linear(input, t_0_weight, t_0_bias)  # t3: "cuda:0 f32[4, 2048]"
  t6 = ltorch.relu(t3, False)  # t6: "cuda:0 f32[4, 2048]"
  t10 = ltorch.linear(t6, t_2_weight, t_2_bias)  # t10: "cuda:0 f32[4, 256]"
  return (t10,)
Note how Thunder's intermediate representation is just (a subset of) Python!
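Because the intermediate representation is ordinary Python, a recorded trace can be printed, inspected and executed with nothing more than `exec`. The toy below illustrates that property on scalars. It records operations by operator overloading on proxy objects, which is a simpler technique than the bytecode interpretation Thunder uses for acquisition. `Proxy`, `Tracer`, `acquire` and `tiny_mlp` are hypothetical names for this sketch only.

```python
class Proxy:
    """Stands in for a value; arithmetic on it is recorded instead of computed."""

    def __init__(self, name: str, tracer: "Tracer"):
        self.name = name
        self.tracer = tracer

    def __mul__(self, other):
        return self.tracer.emit("mul", self, other)

    def __add__(self, other):
        return self.tracer.emit("add", self, other)


class Tracer:
    def __init__(self):
        self.lines: list[str] = []
        self.counter = 0

    def emit(self, op: str, *args) -> Proxy:
        # Record one straight-line statement: t<n> = op(arg, ...)
        out = Proxy(f"t{self.counter}", self)
        self.counter += 1
        rendered = ", ".join(a.name if isinstance(a, Proxy) else repr(a) for a in args)
        self.lines.append(f"    {out.name} = {op}({rendered})")
        return out

    def relu(self, x: Proxy) -> Proxy:
        return self.emit("relu", x)


def acquire(fn, *arg_names: str) -> str:
    """Run fn on proxies and return the source of a straight-line function."""
    tracer = Tracer()
    proxies = [Proxy(n, tracer) for n in arg_names]
    result = fn(tracer, *proxies)
    header = f"def computation({', '.join(arg_names)}):"
    return "\n".join([header, *tracer.lines, f"    return ({result.name},)"])


def tiny_mlp(tr, x, w0, b0, w1, b1):
    # Scalar stand-in for Linear -> ReLU -> Linear.
    h = tr.relu(x * w0 + b0)
    return h * w1 + b1


source = acquire(tiny_mlp, "x", "w0", "b0", "w1", "b1")
print(source)

# The trace is plain Python: bind its primitives, compile it and call it.
namespace = {"mul": lambda a, b: a * b, "add": lambda a, b: a + b, "relu": lambda a: max(a, 0.0)}
exec(source, namespace)
print(namespace["computation"](2.0, 3.0, -1.0, 0.5, 1.0))  # prints (3.5,)
```

Swapping the functions bound in `namespace` changes how the same trace executes, which is a small-scale version of why a Python-level trace is convenient to transform and route.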
Thunder is fast. Here are the speed-ups obtained on a pre-training task using LitGPT on H100 and B200 hardware, relative to PyTorch eager.
Thunder is an open source project, developed in collaboration with the community with significant contributions from NVIDIA.