Official PyTorch implementation of Gated Delta Networks: Improving Mamba2 with Delta Rule (ICLR '25).
Songlin Yang, Jan Kautz and Ali Hatamizadeh.
For business inquiries, please visit our website and submit the form: NVIDIA Research Licensing
For additional functionality, such as varlen training and inference support, see the FLA implementation.
- 02/23/2024: 🔥🔥 Check out the optimized Gated DeltaNet FLA kernels with significantly faster speed.
- 02/22/2024: 🔥 Gated DeltaNet is available in FLA!
- 01/22/2024: 🔥🔥 Gated DeltaNet has been accepted to ICLR '25.
- 12/09/2024: Code release: train your own Gated DeltaNet on the SlimPajama dataset.
- Watch this space for more exciting updates!
Yes! You can import the Gated DeltaNet block directly from FLA. The following script demonstrates how to do so using either FLA or our repository:
>>> USE_FLA = True
>>> import torch
>>> if USE_FLA:
...     from fla.layers import GatedDeltaNet
... else:
...     from .gated_delta_net import GatedDeltaNet
...
>>> bs, num_heads, seq_len, hidden_size = 16, 4, 2048, 512
>>> gated_deltanet = GatedDeltaNet(hidden_size=hidden_size, num_heads=num_heads, mode='chunk').bfloat16().cuda()
>>> gated_deltanet
GatedDeltaNet(
  (silu): SiLU()
  (q_proj): Linear(in_features=512, out_features=1024, bias=False)
  (k_proj): Linear(in_features=512, out_features=1024, bias=False)
  (v_proj): Linear(in_features=512, out_features=2048, bias=False)
  (b_proj): Linear(in_features=512, out_features=4, bias=False)
  (a_proj): Linear(in_features=512, out_features=4, bias=False)
  (q_conv1d): ShortConvolution(1024, 1024, kernel_size=(4,), stride=(1,), padding=(3,), groups=1024, bias=False, activation=silu)
  (k_conv1d): ShortConvolution(1024, 1024, kernel_size=(4,), stride=(1,), padding=(3,), groups=1024, bias=False, activation=silu)
  (v_conv1d): ShortConvolution(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048, bias=False, activation=silu)
  (g_proj): Linear(in_features=512, out_features=2048, bias=False)
  (o_norm): FusedRMSNormSwishGate(512, eps=1e-05)
  (o_proj): Linear(in_features=2048, out_features=512, bias=False)
)
>>> x = torch.randn(bs, seq_len, hidden_size).bfloat16().cuda()
>>> y, _, _ = gated_deltanet(x)
>>> y.shape
torch.Size([16, 2048, 512])
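The projection sizes in the printout above follow from the constructor arguments. The sketch below reproduces that arithmetic in plain Python. It assumes a per-head key dimension of 256 and a value expansion factor of 2; both are inferred from the printed shapes rather than stated in this README, and the helper name `gated_deltanet_dims` is hypothetical.

```python
def gated_deltanet_dims(hidden_size, num_heads, head_dim=256, expand_v=2):
    """Return (in_features, out_features) for each projection.

    head_dim and expand_v are assumed defaults, inferred from the
    printed module above; check the layer's signature before relying on them.
    """
    key_dim = num_heads * head_dim        # total query/key width
    value_dim = int(key_dim * expand_v)   # total value/gate width
    return {
        "q_proj": (hidden_size, key_dim),
        "k_proj": (hidden_size, key_dim),
        "v_proj": (hidden_size, value_dim),
        "g_proj": (hidden_size, value_dim),
        "a_proj": (hidden_size, num_heads),  # one decay gate per head
        "b_proj": (hidden_size, num_heads),  # one write strength per head
        "o_proj": (value_dim, hidden_size),
    }


dims = gated_deltanet_dims(hidden_size=512, num_heads=4)
for name, shape in dims.items():
    print(name, shape)
```

Under these assumptions the computed shapes match the `Linear` layers listed in the printed module.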
FLA kernels are faster and also support variable-length (varlen) training. We strongly recommend using FLA for better performance.
For reference, we also provide FLA-based kernels in this repository. You can find the optimized FLA Gated DeltaNet kernels here.
No, we only provide code implementations.
4️⃣ The dataloader in this repository is designed for SlimPajama-627B, but your models were trained on FineWeb-Edu. Why is that, and should I expect similar results?
For the code release, we used the original Samba repository and included the SlimPajama-627B dataloader to maintain consistency.
Our experiments confirm that SlimPajama-627B produces similar results and trends to those reported in our paper. You can expect comparable performance.
Since this codebase is primarily adapted from the Samba codebase, which is designed mainly for training, evaluation can be inconvenient. Notably, the Samba codebase lacks the generation utilities required for many generation-based evaluation tasks.
We recommend first converting your trained model weights to the Hugging Face format provided in the FLA repo. Once converted, you can use FLA for streamlined evaluation.
- For Single Needle in a Haystack (S-NIAH) tasks:
  Please install NVIDIA/RULER. The installation process can be challenging; we suggest installing any missing dependencies individually to ensure success. S-NIAH tasks are zero-shot tasks, and since RULER supports Hugging Face format models, you can easily evaluate your converted FLA models in this case.
- For zero-shot commonsense reasoning tasks (Table 3):
  Follow the FLA instructions for evaluation details.
- For zero-shot, in-context recall-intensive tasks (Table 4):
  Use the official evaluation script from their repository.
  ⚠️ Important: Avoid directly using lm-eval-harness with the task name alone, as this can lead to significant performance differences. These retrieval tasks are highly prompt-sensitive for instruction-untuned models in zero-shot settings.
Gated DeltaNet introduces a novel approach to linear transformers by combining:
- 🧠 Smart Memory Management: Intelligent memory management that knows what to keep and what to forget
- ⚡ Precise Updates: Targeted memory updates that enhance model efficiency
- 💻 Hardware Efficiency: Optimized implementation for real-world deployment
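The first two bullets correspond to the gated delta rule described in the paper, S_t = α_t S_{t-1} (I − β_t k_t k_t^T) + β_t v_t k_t^T: the gate α_t decays old memory, and the delta term overwrites the value stored under key k_t. Below is a minimal pure-Python sketch of one recurrent step, for illustration only. It is not the chunked kernel used in this repository, and the function names are hypothetical.

```python
def matvec(S, k):
    """Return S @ k for a d_v x d_k matrix S and a length-d_k vector k."""
    return [sum(s_ij * k_j for s_ij, k_j in zip(row, k)) for row in S]


def gated_delta_step(S, k, v, alpha, beta):
    """One recurrent step of the gated delta rule.

    S     : d_v x d_k state matrix (list of lists)
    k, v  : key (length d_k, assumed L2-normalized) and value (length d_v)
    alpha : decay gate in [0, 1]; 0 clears the memory, 1 keeps it
    beta  : write strength in [0, 1]; 1 fully replaces the value under k
    """
    # Gating: uniformly decay everything currently stored.
    decayed = [[alpha * s for s in row] for row in S]
    # Delta rule: compare what the memory returns for k with the target v.
    predicted = matvec(decayed, k)
    error = [p - v_i for p, v_i in zip(predicted, v)]
    # Rank-1 correction along k: S_t = alpha*S - beta * (alpha*S k - v) k^T.
    return [
        [d_ij - beta * e_i * k_j for d_ij, k_j in zip(row, k)]
        for row, e_i in zip(decayed, error)
    ]


S = [[0.0, 0.0], [0.0, 0.0]]
S = gated_delta_step(S, k=[1.0, 0.0], v=[1.0, 2.0], alpha=1.0, beta=1.0)
S = gated_delta_step(S, k=[0.0, 1.0], v=[3.0, 4.0], alpha=1.0, beta=1.0)
read_a = matvec(S, [1.0, 0.0])   # value stored under the first key

# Writing again with the first key replaces its value only.
S = gated_delta_step(S, k=[1.0, 0.0], v=[5.0, 6.0], alpha=1.0, beta=1.0)
read_b = matvec(S, [1.0, 0.0])
read_c = matvec(S, [0.0, 1.0])   # the second key's value is untouched

# With alpha = 0 the gate clears the memory before the new write.
S_reset = gated_delta_step(S, k=[1.0, 0.0], v=[7.0, 8.0], alpha=0.0, beta=1.0)
```

With orthonormal keys, a write with `beta=1` replaces exactly one association, while `alpha` controls how much of the whole memory survives each step.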
Gated DeltaNet shows exceptional performance in terms of training throughput compared to models like Mamba2 and Samba:
Our model outperforms competitors of various types (e.g., Transformer, RNN, hybrid) in terms of perplexity and zero-shot accuracy on reasoning benchmarks:
Gated DeltaNet also achieves favorable perplexity scores on long-context benchmarks:
Launch your training with our streamlined command:
python ../pretrain.py \
--train_data_dir ${TRAIN_DATA} \
--val_data_dir ${VALIDATION_DATA} \
--output_root ${SAVE_DIR} \
--exp_name ${NAME} \
--model_name ${MODEL} \
--train_config ${CONFIG} \
--eval_iters ${EVAL_ITERS} \
--learning_rate ${LR} \
--micro_batch_size ${MICRO_BATCH_SIZE}
💡 Pro Tip: Add --interactive_job --debug for interactive debugging sessions!
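If you prefer to assemble the launch command from a script rather than shell variables, the following sketch builds the same argument list in Python. Only the flag names come from the command above; every value is a placeholder you must replace, and the helper `build_pretrain_command` is hypothetical.

```python
import shlex


def build_pretrain_command(settings, script="../pretrain.py", extra_flags=()):
    """Turn a {flag: value} mapping into an argument list for pretrain.py."""
    cmd = ["python", script]
    for flag, value in settings.items():
        cmd += [f"--{flag}", str(value)]
    cmd += list(extra_flags)  # value-less switches such as --debug
    return cmd


# Placeholder values: substitute your own paths and hyperparameters.
settings = {
    "train_data_dir": "/path/to/train_data",
    "val_data_dir": "/path/to/validation_data",
    "output_root": "/path/to/save_dir",
    "exp_name": "YOUR_EXPERIMENT_NAME",
    "model_name": "YOUR_MODEL_NAME",
    "train_config": "YOUR_TRAIN_CONFIG",
    "eval_iters": "YOUR_EVAL_ITERS",
    "learning_rate": "YOUR_LEARNING_RATE",
    "micro_batch_size": "YOUR_MICRO_BATCH_SIZE",
}

command = build_pretrain_command(
    settings, extra_flags=["--interactive_job", "--debug"]
)
print(shlex.join(command))  # shell-quoted string; pass `command` to subprocess.run to launch
```

Keeping the settings in a dictionary makes it easy to log or version the exact configuration used for a run.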
Please see this Slurm script for training the GatedDeltaNet_H1 model with 0.4B parameters on 15B tokens. The training requires 4 nodes and can be finished in approximately 4 hours. For this run, the validation loss and perplexity curves (1x & 2x for length extrapolation) are expected as follows:
Copyright © 2025, NVIDIA Corporation. All rights reserved.
Licensed under the NVIDIA Source Code License-NC. See LICENSE for details.
Built on the shoulders of giants:
If you find this work useful, please consider:
- Starring the repository
- Citing our paper
- Contributing to the codebase
Join us in pushing the boundaries of linear transformers! 🚀
If you find Gated DeltaNet to be useful for your work, please consider citing our paper:
@inproceedings{yang2025gated,
title={Gated Delta Networks: Improving Mamba2 with Delta Rule},
author={Songlin Yang and Jan Kautz and Ali Hatamizadeh},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025},
url={https://openreview.net/forum?id=r8H7xhYPwz}
}