Official implementation for paper "Efficient Length-Generalizable Attention via Causal Retrieval for Long-Context Language Modeling" (ICML 2025)
When generating the current chunk (c7), GCA (grouped cross attention) retrieves past chunks using the landmark representation of c6 to assist token prediction for the next chunk. The key to GCA's length generalization is an end-to-end differentiable retrieval mechanism, realized through a two-stage attention mechanism. After the top-k chunks are selected:
In the first stage, each token in c7 attends to the tokens within each retrieved chunk separately to gather information from that chunk, as illustrated in the diagram.
In the second stage, the softmax-normalized retrieval scores of the chunks are used as weights to compute a weighted sum of the information gathered from each retrieved chunk, which is then integrated into the token's representation.
During backpropagation (BP), the weights of past chunks that contribute more to predicting the next chunk are strengthened, enabling end-to-end learning of causal retrieval.
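To make the two-stage mechanism concrete, here is a minimal PyTorch sketch. It is illustrative only (not the fused Triton kernel in ltriton/gca.py); the function name, tensor shapes, and single-head setup are assumptions:

```python
# Minimal sketch of GCA's two-stage attention (illustrative, single head).
import torch

def gca_two_stage_attention(q, ret_k, ret_v, retrieval_scores):
    """
    q:                (T, d)    queries of the T tokens in the current chunk (e.g. c7)
    ret_k, ret_v:     (k, S, d) keys/values of the top-k retrieved chunks, S tokens each
    retrieval_scores: (k,)      landmark-based retrieval scores of the k chunks
    returns:          (T, d)    retrieval-augmented output for each token
    """
    d = q.size(-1)

    # Stage 1: every token attends to the tokens *within* each retrieved chunk
    # separately, producing one summary per (chunk, token) pair.
    attn = torch.einsum('td,ksd->kts', q, ret_k) / d ** 0.5   # (k, T, S)
    attn = attn.softmax(dim=-1)
    per_chunk = torch.einsum('kts,ksd->ktd', attn, ret_v)     # (k, T, d)

    # Stage 2: fuse the per-chunk summaries, weighted by the softmax-normalized
    # retrieval scores. Because these weights stay in the computation graph,
    # backprop can strengthen chunks that help predict the next chunk.
    w = retrieval_scores.softmax(dim=-1)                      # (k,)
    return torch.einsum('k,ktd->td', w, per_chunk)            # (T, d)

# Toy usage
T, S, k, d = 64, 64, 3, 128
out = gca_two_stage_attention(torch.randn(T, d), torch.randn(k, S, d),
                              torch.randn(k, S, d), torch.randn(k))
print(out.shape)  # torch.Size([64, 128])
```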
All models were pre-trained on contexts of no more than 16K tokens, and all attention spans are limited to no more than 728 tokens. Our model (DRT) achieves 1000x length extrapolation on the needle-in-a-haystack task, maintaining high accuracy even at a 16M context length.

Requirements: torch==2.4.0, transformers>=4.36.0, triton==3.0.0
pip install -r requirements.txt
Datasets: ArXiv-math, PG19, XSUM, CNN/DailyMail
Before pre-training, ensure that the corpus is indexed. Pre-processing scripts:
PG19: python preprocess/pg19_prepare.py
ArXiv: python preprocess/arxiv_math_prepare.py
Summarization: python preprocess/summarization_preprocess.py
Test triton kernel:
pytest ltriton/gca.py
Test DRT generation:
python -m unittest tests/generation_unittest.py
sh scripts/pretrain_pg19.sh
Summarization tasks
sh scripts/xsum_ft.sh
NIAH tests
sh scripts/niah_ft.sh
Please note that we have observed that enabling softmax_off_by_one affects the results. Therefore, when fine-tuning for the NIAH task, we use the vanilla softmax by setting enable_softmax_one to false in the config.
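For reference, here is a minimal sketch of how one might toggle this flag before launching the NIAH fine-tuning script. It assumes the config is a JSON file containing an enable_softmax_one field; the actual config path and format in this repo may differ:

```python
# Illustrative only: switch back to vanilla softmax for NIAH fine-tuning,
# assuming a JSON config with an `enable_softmax_one` field.
import json

config_path = "config/drt_niah.json"  # hypothetical path, replace with your config

with open(config_path) as f:
    cfg = json.load(f)

cfg["enable_softmax_one"] = False      # disable softmax_off_by_one

with open(config_path, "w") as f:
    json.dump(cfg, f, indent=2)
```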
Eval perplexity:
python slidewin_eval.py \
--config_path PATH_TO_YOUR_CONFIG \
--vocab_dir config/gpt2-small \
--corpus_path PATH_TO_VALID_SET \
--max_seq_len MAX_SEQ_LEN \
--stride -1 \
--checkpoint_path PATH_TO_YOUR_CHECKPOINT \
--model_type MODEL_TYPE(DRT/slide_window_lm/rpt_contriever/blk_rec_tfm/llama_with_landmark)
Eval passkey-retrieval:
python slidewin_eval.py \
--config_path PATH_TO_YOUR_CONFIG \
--vocab_dir config/gpt2-small \
--corpus_path PATH_TO_VALID_SET \
--max_seq_len MAX_SEQ_LEN \
--passkey_retrieval single/multihop/multi \
--stride -1 \
--checkpoint_path PATH_TO_YOUR_CHECKPOINT \
--model_type MODEL_TYPE(DRT/slide_window_lm/rpt_contriever/blk_rec_tfm/llama_with_landmark)
To evaluate the summarization task, you need to first generate summaries and then score the generated results. The generation script is as follows:
python eval/gen_summarization.py \
--model_path PATH_TO_FINETUNED_MODEL \
--model_type MODEL_TYPE \
--config_path PATH_TO_YOUR_CONFIG \
--vocab_dir config/gpt2-small \
--corpus_path /PATH_TO_PREPROCESSED_CORPUS/test.pkl \
--output_path /PATH_TO_OUTPUT/OUTPUT_NAME.pkl
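For the scoring part, ROUGE is the standard metric for XSUM and CNN/DailyMail. Below is a minimal sketch of how the generated pickle could be scored; it assumes the output file stores a list of dicts with "reference" and "generated" fields, which may not match the actual structure written by gen_summarization.py:

```python
# Illustrative ROUGE scoring of generated summaries (pip install rouge-score).
import pickle
from rouge_score import rouge_scorer

with open("/PATH_TO_OUTPUT/OUTPUT_NAME.pkl", "rb") as f:
    samples = pickle.load(f)  # assumed: [{"reference": ..., "generated": ...}, ...]

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
totals = {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0}
for s in samples:
    scores = scorer.score(s["reference"], s["generated"])
    for key in totals:
        totals[key] += scores[key].fmeasure

for key, total in totals.items():
    print(f"{key}: {total / len(samples):.4f}")
```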
If you encounter any problems, please feel free to contact us: aaron.hx AT antgroup.com