
Gated Delta Networks: Improving Mamba2 with Delta Rule


Official PyTorch implementation of Gated Delta Networks: Improving Mamba2 with Delta Rule (ICLR '25).


Songlin Yang, Jan Kautz and Ali Hatamizadeh.

For business inquiries, please visit our website and submit the form: NVIDIA Research Licensing

For additional functionality, such as variable-length (varlen) training and inference support, see the FLA implementation.

📢 Latest Updates

  • 02/23/2025: 🔥🔥 Check out the optimized Gated DeltaNet FLA kernels with significantly faster speed.
  • 02/22/2025: 🔥 Gated DeltaNet is available in FLA!
  • 01/22/2025: 🔥🔥 Gated DeltaNet has been accepted to ICLR '25.
  • 12/09/2024: Code release: train your own Gated DeltaNet on the SlimPajama dataset.
  • Watch this space for more exciting updates!

❓ Frequently Asked Questions (FAQ)

1️⃣ Can I use Gated DeltaNet directly from FLA?

Yes! You can import the Gated DeltaNet block directly from FLA. The following script demonstrates how to do so using either FLA or our repository:

>>> USE_FLA = True
>>> import torch
>>> if USE_FLA:
...     from fla.layers import GatedDeltaNet
... else:
...     from .gated_delta_net import GatedDeltaNet
>>> 
>>> bs, num_heads, seq_len, hidden_size = 16, 4, 2048, 512
>>> gated_deltanet = GatedDeltaNet(hidden_size=hidden_size, num_heads=num_heads, mode='chunk').bfloat16().cuda()
>>> gated_deltanet
GatedDeltaNet(
  (silu): SiLU()
  (q_proj): Linear(in_features=512, out_features=1024, bias=False)
  (k_proj): Linear(in_features=512, out_features=1024, bias=False)
  (v_proj): Linear(in_features=512, out_features=2048, bias=False)
  (b_proj): Linear(in_features=512, out_features=4, bias=False)
  (a_proj): Linear(in_features=512, out_features=4, bias=False)
  (q_conv1d): ShortConvolution(1024, 1024, kernel_size=(4,), stride=(1,), padding=(3,), groups=1024, bias=False, activation=silu)
  (k_conv1d): ShortConvolution(1024, 1024, kernel_size=(4,), stride=(1,), padding=(3,), groups=1024, bias=False, activation=silu)
  (v_conv1d): ShortConvolution(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048, bias=False, activation=silu)
  (g_proj): Linear(in_features=512, out_features=2048, bias=False)
  (o_norm): FusedRMSNormSwishGate(512, eps=1e-05)
  (o_proj): Linear(in_features=2048, out_features=512, bias=False)
)
>>> x = torch.randn(bs, seq_len, hidden_size).bfloat16().cuda()
>>> y, _, _ = gated_deltanet(x)
>>> y.shape
torch.Size([16, 2048, 512])
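
If you want to drop the layer into a small model of your own, a plain pre-norm residual wrapper around the block is enough. The sketch below is illustrative only: the wrapper class, norm choice, and hyperparameters are assumptions rather than part of the repository or FLA API; the mixer call itself mirrors the snippet above (it returns three values, and the first is the token-mixed output).

import torch
import torch.nn as nn
from fla.layers import GatedDeltaNet  # or the in-repo gated_delta_net module

class GatedDeltaNetBlock(nn.Module):
    """Illustrative pre-norm residual block around a GatedDeltaNet token mixer."""
    def __init__(self, hidden_size: int = 512, num_heads: int = 4):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.mixer = GatedDeltaNet(hidden_size=hidden_size, num_heads=num_heads, mode='chunk')

    def forward(self, x):
        y, _, _ = self.mixer(self.norm(x))   # same call convention as in the snippet above
        return x + y                         # residual connection

block = GatedDeltaNetBlock().bfloat16().cuda()
x = torch.randn(16, 2048, 512).bfloat16().cuda()
print(block(x).shape)                        # torch.Size([16, 2048, 512])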

2️⃣ What is the difference between the FLA Gated DeltaNet kernels and the NVLabs implementation?

FLA kernels are faster and also support variable-length (varlen) training. We strongly recommend using FLA for better performance.

For reference, we also provide FLA-based kernels in this repository. You can find the optimized FLA Gated DeltaNet kernels here.


3️⃣ Will you release the pretrained model weights?

No, we only provide code implementations.


4️⃣ The dataloader in this repository is designed for SlimPajama-672B, but your models were trained on FineWeb-Edu. Why is that, and should I expect similar results?

For the code release, we used the original Samba repository and included the SlimPajama-672B dataloader to maintain consistency.

Our experiments confirm that SlimPajama-672B produces similar results and trends to those reported in our paper. You can expect comparable performance.


5️⃣ Any guidance for evaluating the models?

Since this codebase is primarily adapted from the Samba codebase, which is designed mainly for training, evaluation can be inconvenient. Notably, the Samba codebase lacks the generation utilities required for many generation-based evaluation tasks.

We recommend first converting your trained model weights to the Hugging Face format using the conversion utilities provided in the FLA repo. Once converted, you can leverage FLA for streamlined evaluation.

  • For Single Needle in a Haystack (S-NIAH) tasks:
    Please install NVIDIA/RULER. The installation process can be challenging; we suggest installing any missing dependencies individually to ensure success. S-NIAH tasks are zero-shot tasks, and since RULER supports Hugging Face format models, you can easily evaluate your converted FLA models in this case.

  • For zero-shot commonsense reasoning tasks (Table 3):
    Follow the FLA instructions for evaluation details.

  • For zero-shot, in-context recall-intensive tasks (Table 4):
    Use the official evaluation script from their repository.
    ⚠️ Important: Avoid directly using lm-eval-harness with the task name alone, as this can lead to significant performance differences. These retrieval tasks are highly prompt-sensitive for instruction-untuned models in zero-shot settings.

🌟 Why Gated DeltaNet?

Gated DeltaNet introduces a novel approach to linear transformers by combining:

  • 🧠 Smart Memory Management: a data-dependent gate that controls what the memory keeps and what it forgets
  • 🎯 Precise Updates: delta-rule writes that correct only the memory content associated with the current key
  • 💻 Hardware Efficiency: a chunkwise implementation optimized for real-world deployment
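
Concretely, the layer maintains a matrix-valued memory S_t that is decayed by a data-dependent gate α_t and corrected with a delta-rule write of strength β_t: S_t = α_t S_{t-1}(I − β_t k_t k_tᵀ) + β_t v_t k_tᵀ, with output o_t = S_t q_t. The sketch below is a minimal sequential-form implementation for intuition only; the function name and shapes are illustrative, and the actual kernels in this repository and FLA use a hardware-efficient chunkwise-parallel formulation instead of a token-by-token loop.

import torch

def gated_delta_rule(q, k, v, alpha, beta):
    """Sequential (non-chunked) gated delta rule, for illustration only.

    q, k:  (B, T, Dk)  queries / keys (keys assumed L2-normalized)
    v:     (B, T, Dv)  values
    alpha: (B, T)      per-token decay gate in (0, 1)
    beta:  (B, T)      per-token writing strength in (0, 1)
    Returns o: (B, T, Dv)
    """
    B, T, Dk = k.shape
    Dv = v.shape[-1]
    S = v.new_zeros(B, Dv, Dk)                               # matrix-valued memory state
    outs = []
    for t in range(T):
        q_t, k_t, v_t = q[:, t], k[:, t], v[:, t]
        a_t = alpha[:, t].view(B, 1, 1)
        b_t = beta[:, t].view(B, 1, 1)
        Sk = torch.einsum('bvd,bd->bv', S, k_t)              # S_{t-1} k_t: current prediction for key k_t
        # S_t = a_t * S_{t-1} (I - b_t k_t k_t^T) + b_t v_t k_t^T
        S = a_t * (S - b_t * torch.einsum('bv,bd->bvd', Sk, k_t)) \
            + b_t * torch.einsum('bv,bd->bvd', v_t, k_t)
        outs.append(torch.einsum('bvd,bd->bv', S, q_t))      # o_t = S_t q_t
    return torch.stack(outs, dim=1)

# Toy usage (shapes only):
B, T, Dk, Dv = 2, 8, 16, 32
q = torch.randn(B, T, Dk)
k = torch.nn.functional.normalize(torch.randn(B, T, Dk), dim=-1)
v = torch.randn(B, T, Dv)
alpha, beta = torch.rand(B, T), torch.rand(B, T)
print(gated_delta_rule(q, k, v, alpha, beta).shape)          # torch.Size([2, 8, 32])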

Architecture Overview

Efficiency

Gated DeltaNet achieves higher training throughput than models such as Mamba2 and Samba:

Language Modeling and Reasoning

Our model outperforms competitors of various types (e.g., Transformer, RNN, and hybrid models) in terms of perplexity and zero-shot accuracy on reasoning benchmarks:

Long-context

Gated DeltaNet also achieves favorable perplexity scores on long-context benchmarks:

🚀 Getting Started

Training Your Model

Launch your training with our streamlined command:

python ../pretrain.py \
--train_data_dir ${TRAIN_DATA} \
--val_data_dir ${VALIDATION_DATA} \
--output_root ${SAVE_DIR} \
--exp_name ${NAME} \
--model_name ${MODEL} \
--train_config ${CONFIG} \
--eval_iters ${EVAL_ITERS} \
--learning_rate ${LR} \
--micro_batch_size ${MICRO_BATCH_SIZE}

💡 Pro Tip: Add --interactive_job --debug for interactive debugging sessions!
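
For example, a filled-in invocation might look like the following. All values here are placeholders: substitute your own data paths, and use a model name and training config that are actually defined in this repository (e.g. those referenced in the slurm script mentioned below).

python ../pretrain.py \
--train_data_dir /path/to/slimpajama/train \
--val_data_dir /path/to/slimpajama/val \
--output_root ./checkpoints \
--exp_name gated_deltanet_h1_0.4b \
--model_name <model_name_from_this_repo> \
--train_config <train_config_from_this_repo> \
--eval_iters 20 \
--learning_rate 4e-4 \
--micro_batch_size 8 \
--interactive_job --debug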

Please see this slurm script for training the GatedDeltaNet_H1 model with 0.4B parameters on 15B tokens. The training requires 4 nodes and finishes in approximately 4 hours. For this run, the validation loss and perplexity curves (1x & 2x for length extrapolation) are expected to look as follows:

(Figure: validation loss and perplexity curves)

📜 License

Copyright © 2025, NVIDIA Corporation. All rights reserved.

Licensed under the NVIDIA Source Code License-NC. See LICENSE for details.

🙏 Acknowledgements

Built on the shoulders of giants:

  • Samba — the training pipeline in this repository is adapted from the original Samba codebase.
  • FLA (Flash Linear Attention) — provides the importable Gated DeltaNet layer and optimized kernels referenced above.

⭐ Support Us

If you find this work useful, please consider:

  • Starring the repository
  • Citing our paper
  • Contributing to the codebase

Join us in pushing the boundaries of linear transformers! 🚀

Citation

If you find Gated DeltaNet to be useful for your work, please consider citing our paper:

@inproceedings{yang2025gated,
  title={Gated Delta Networks: Improving Mamba2 with Delta Rule},
  author={Songlin Yang and Jan Kautz and Ali Hatamizadeh},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=r8H7xhYPwz}
}

Star History

(Star history chart for NVlabs/GatedDeltaNet)
