The repository contains code to:
- Train the embedding model
- Build a nearest neighbor index on the Pile
- Run distributed servers on top of the index
- Query the servers
- Evaluate TTT-NN on the Pile
- Run baselines
This project is based on the paper *Test-Time Training on Nearest Neighbors for Large Language Models* by Moritz Hardt and Yu Sun, published at ICLR 2024. Please cite as:
```bibtex
@inproceedings{hardt2024test,
  title={Test-time training on nearest neighbors for large language models},
  author={Hardt, Moritz and Sun, Yu},
  booktitle={International Conference on Learning Representations},
  year={2024}
}
```
To evaluate TTT-NN, you ultimately need the following directory structure:

```
indexes/
    roberta-large/
        00.jsonl.index
        01.jsonl.index
        ...
        29.jsonl.index
models/
    roberta-large-pile-lr2e-5-bs16-8gpu/
        checkpoint-1700000/
pile/
    train/
        00.jsonl
        01.jsonl
        ...
        29.jsonl
    val.jsonl
    test.jsonl
servers/
    addresses.txt
```
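Once all files are in place, a quick sanity check of the layout can look like this (a minimal sketch that only verifies the paths from the tree above exist):

```python
from pathlib import Path

# Check the directory layout described above before running evaluation.
splits = [f"{i:02d}.jsonl" for i in range(30)]

for name in splits:
    assert (Path("pile/train") / name).exists(), f"missing pile/train/{name}"
    assert (Path("indexes/roberta-large") / f"{name}.index").exists(), \
        f"missing indexes/roberta-large/{name}.index"

assert Path("models/roberta-large-pile-lr2e-5-bs16-8gpu/checkpoint-1700000").is_dir()
assert Path("pile/val.jsonl").exists() and Path("pile/test.jsonl").exists()
print("Directory layout looks complete.")
```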
Download the dataset here and place the files in the `pile/` subdirectory.
You can download the pretrained embedding model from HuggingFace. Place the files in the directory `models/roberta-large-pile-lr2e-5-bs16-8gpu/checkpoint-1700000`.
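To verify the download, the checkpoint should load with the standard `transformers` API (a minimal sketch, assuming the checkpoint was saved with `save_pretrained`; the tokenizer is the stock `roberta-large` tokenizer):

```python
from transformers import AutoModel, AutoTokenizer

ckpt = "models/roberta-large-pile-lr2e-5-bs16-8gpu/checkpoint-1700000"

# Load the embedding model from the local checkpoint directory.
model = AutoModel.from_pretrained(ckpt)
tokenizer = AutoTokenizer.from_pretrained("roberta-large")

inputs = tokenizer("Hello, Pile!", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, seq_len, 1024) for roberta-large
```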
To train the embedding model yourself, see `code/trainer_lm.py`. This is a standard HuggingFace training setup. This code was used to produce the model checkpoint `checkpoint-1700000` in the `models` directory.
The model trained for approximately one month on 8 A100 GPUs, making one pass over the data.
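For orientation, a standard HuggingFace training setup of this kind looks roughly as follows. This is a sketch only, assuming a masked language modeling objective (the usual choice for RoBERTa); `code/trainer_lm.py` defines the actual data pipeline and hyperparameters. The learning rate and batch size below are read off the checkpoint name:

```python
from datasets import load_dataset
from transformers import (AutoModelForMaskedLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer,
                          TrainingArguments)

tokenizer = AutoTokenizer.from_pretrained("roberta-large")
model = AutoModelForMaskedLM.from_pretrained("roberta-large")

# Stream the Pile training files; each JSON line carries a "text" field.
dataset = load_dataset("json", data_files="pile/train/*.jsonl", streaming=True)["train"]
dataset = dataset.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=512),
    batched=True,
    remove_columns=["text", "meta"],
)

args = TrainingArguments(
    output_dir="models/roberta-large-pile-lr2e-5-bs16-8gpu",
    learning_rate=2e-5,             # matches the checkpoint name
    per_device_train_batch_size=2,  # 2 per GPU x 8 GPUs = global batch size 16
    max_steps=1_700_000,
    save_steps=100_000,
)

Trainer(
    model=model,
    args=args,
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer),
).train()
```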
Make sure to have the checkpoint `models/roberta-large-pile-lr2e-5-bs16-8gpu/checkpoint-1700000` before you proceed.
You can download the index files here. The download size is approximately 800GB. Place the files in the directory `indexes/roberta-large`.
To build the index yourself, use the code in `code/build_database.py` to build an index on top of the Pile dataset. This is a time-consuming operation. Specify `--data_file` to build the index for a given data file:
```
python3 code/build_database.py \
    --data_file pile/train/00.jsonl \
    --output_dir indexes/roberta-large
```
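Conceptually, index building embeds every document in a data file and stores the vectors in a nearest neighbor index. A minimal sketch of that idea, assuming a FAISS flat index and mean-pooled embeddings (the actual `code/build_database.py` may use a different index type, pooling, and batching):

```python
import json

import faiss
import torch
from transformers import AutoModel, AutoTokenizer

ckpt = "models/roberta-large-pile-lr2e-5-bs16-8gpu/checkpoint-1700000"
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
model = AutoModel.from_pretrained(ckpt).eval()

index = faiss.IndexFlatL2(1024)  # roberta-large hidden size

with open("pile/train/00.jsonl") as f:
    for line in f:
        text = json.loads(line)["text"]
        inputs = tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
        with torch.no_grad():
            # Mean-pool token embeddings into one vector per document (an assumption).
            emb = model(**inputs).last_hidden_state.mean(dim=1)
        index.add(emb.numpy())

faiss.write_index(index, "indexes/roberta-large/00.jsonl.index")
```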
Make sure you have all index files in `indexes/roberta-large` before you proceed.
The following command launches a server with 6 replicas, each serving one split of the data. It appends the 6 IP addresses and ports to the file specified as `--address_path`:
```
python3 code/pile_server.py \
    --address_path servers/addresses.txt \
    --data_file pile/train/00.jsonl \
    --num_servers 6
```
To serve from all Pile data files, start one server per data file. We recommend starting 30 servers with 6 replicas each, resulting in 180 running instances.
Make sure all servers are up and running before launching the evaluation.
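One way to check that the replicas are reachable is to probe every address in `servers/addresses.txt` (a sketch only; it assumes each line holds a host and a port separated by whitespace, so adjust the parsing to the actual file format):

```python
import socket

# Hypothetical liveness probe: attempt a TCP connection to every server address.
with open("servers/addresses.txt") as f:
    for line in f:
        host, port = line.split()
        try:
            with socket.create_connection((host, int(port)), timeout=5):
                print(f"{host}:{port} is up")
        except OSError as err:
            print(f"{host}:{port} unreachable: {err}")
```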
Use `code/pile_client.py` to query the servers. Specify `--address_path` to indicate which servers to query. The client queries every server it finds under the address path, then builds a local nearest neighbor structure to find the nearest neighbors among all retrieved results.
The client code can be used standalone, but it is also called from the evaluation code.
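The aggregation step is simple: each server returns candidate neighbors with distances, and the client keeps the globally closest ones. A minimal sketch of that merge (illustrative only; the real client's retrieval and RPC layer differ):

```python
import heapq

def merge_neighbors(per_server_results, k=10):
    """Keep the k globally nearest neighbors across all server responses.

    per_server_results: one list of (distance, text) pairs per server replica.
    """
    all_hits = [hit for hits in per_server_results for hit in hits]
    return heapq.nsmallest(k, all_hits, key=lambda hit: hit[0])

# Example: two servers each return two candidates; keep the 2 nearest overall.
results = [[(0.3, "doc a"), (0.9, "doc b")], [(0.1, "doc c"), (0.7, "doc d")]]
print(merge_neighbors(results, k=2))  # [(0.1, 'doc c'), (0.3, 'doc a')]
```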
To evaluate on GPT-Neo with default parameters:

```
python3 code/eval_tttlm.py \
    --address_path servers/addresses.txt \
    --results_dir results/
```
To evaluate on GPT2:

```
python3 code/eval_tttlm.py \
    --model gpt2 \
    --tokenizer gpt2 \
    --embedding_model_checkpoint models/roberta-large-pile-lr2e-5-bs16-8gpu/checkpoint-1700000 \
    --max_length 1024 \
    --stride 1024 \
    --learning_rate 2e-5 \
    --address_path servers/addresses.txt \
    --results_dir results/
```
Replace `gpt2` with `gpt2-large` to evaluate on GPT2-Large.
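For intuition, TTT-NN fine-tunes the language model on the retrieved neighbors of each test instance before scoring it, then resets the weights. A conceptual sketch of that loop, under simplifying assumptions (one gradient step per neighbor, a fresh model copy per instance); `code/eval_tttlm.py` implements the actual procedure:

```python
import copy

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
base_model = AutoModelForCausalLM.from_pretrained("gpt2")

def ttt_score(neighbors, test_text, lr=2e-5):
    """Fine-tune a fresh model copy on retrieved neighbors, then score test_text."""
    model = copy.deepcopy(base_model)  # reset weights for every test instance
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    for text in neighbors:
        batch = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
        loss = model(batch.input_ids, labels=batch.input_ids).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    model.eval()
    with torch.no_grad():
        batch = tokenizer(test_text, return_tensors="pt")
        return model(batch.input_ids, labels=batch.input_ids).loss.item()
```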
Use `code/process_results.py` to merge results from distributed evaluation with multiple processes, aggregate statistics, and make plots.
The evaluation code uses modified parts of the `lm-evaluation-harness` package by EleutherAI.
The folder `lm_eval` contains the modified parts (and `lm_eval_interpolation` for the interpolation baseline), as well as the unmodified parts included for compatibility.
- See `code/baseline_context.py` for the in-context baseline (sketched below); see also `code/process_results_context.py`.
- See `code/baseline_interpolation.py` for the interpolation baseline; see also `code/process_results_interpolation.py`.
- Run `code/eval_tttlm.py` with the option `--dynamic_eval` for the dynamic evaluation baseline.
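As a rough illustration of the in-context idea (a sketch only, assuming the baseline prepends retrieved neighbor text to the evaluation context and masks it out of the loss; the actual `code/baseline_context.py` may differ):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

def in_context_loss(neighbor_text, eval_text):
    """Score eval_text with neighbor_text prepended, ignoring the prefix in the loss."""
    neighbor_ids = tokenizer(neighbor_text, return_tensors="pt").input_ids
    eval_ids = tokenizer(eval_text, return_tensors="pt").input_ids
    input_ids = torch.cat([neighbor_ids, eval_ids], dim=1)
    labels = input_ids.clone()
    labels[:, : neighbor_ids.shape[1]] = -100  # mask loss on the retrieved prefix
    with torch.no_grad():
        return model(input_ids, labels=labels).loss.item()

print(in_context_loss("Some retrieved neighbor text.", "The text we evaluate."))
```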