CrystaLLM has been evaluated on the Perov-5, Carbon-24, MP-20, and MPTS-52 benchmarks. This document explains how to reproduce the benchmark evaluation experiments described in the paper. The Perov-5, Carbon-24, and MP-20 datasets are from the CDVAE repository, while the MPTS-52 dataset is from the DiffCSP repository.
Assessing CrystaLLM on a benchmark requires the following steps:
- pre-process the original benchmark dataset
- tokenize the pre-processed CIF files
- train the model
- generate CIF files
- post-process the generated CIF files
- compute the benchmark metrics
Below, we will provide instructions using the Perov-5 benchmark as an example.
Prepare the benchmark CSV files:
python bin/prepare_csv_benchmark.py resources/benchmarks/perov_5/train.csv perov_5_train_orig.tar.gz
python bin/prepare_csv_benchmark.py resources/benchmarks/perov_5/val.csv perov_5_val_orig.tar.gz
python bin/prepare_csv_benchmark.py resources/benchmarks/perov_5/test.csv perov_5_test_orig.tar.gz
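For reference, the conversion performed by `prepare_csv_benchmark.py` amounts to extracting one CIF string per CSV row into a `.tar.gz` archive. A minimal sketch, assuming CDVAE-style columns named `material_id` and `cif` (the column names and archive layout here are assumptions, not necessarily the script's exact behavior):

```python
import csv
import io
import tarfile

def csv_to_cif_tar(csv_path: str, tar_path: str) -> int:
    """Extract CIF strings from a benchmark CSV into a .tar.gz archive.

    Assumes each row has 'material_id' and 'cif' columns (as in the
    CDVAE-style benchmark CSVs); returns the number of CIFs written.
    """
    n = 0
    with open(csv_path, newline="") as f, tarfile.open(tar_path, "w:gz") as tar:
        for row in csv.DictReader(f):
            data = row["cif"].encode("utf-8")
            # one archive member per structure, named after its ID
            info = tarfile.TarInfo(name=f"{row['material_id']}.cif")
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))
            n += 1
    return n
```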
Convert the .tar.gz files to .pkl.gz files for more efficient processing:
python bin/tar_to_pickle.py perov_5_train_orig.tar.gz perov_5_train_orig.pkl.gz
python bin/tar_to_pickle.py perov_5_val_orig.tar.gz perov_5_val_orig.pkl.gz
python bin/tar_to_pickle.py perov_5_test_orig.tar.gz perov_5_test_orig.pkl.gz
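The pickled form lets later stages load the whole dataset in a single read instead of iterating over tar members. A sketch of such a conversion, assuming the pickle simply holds a gzip-compressed list of (name, CIF text) pairs (the real script's internal format may differ):

```python
import gzip
import pickle
import tarfile

def tar_to_pickle(tar_path: str, pkl_path: str) -> list:
    """Read every .cif member of a .tar.gz into memory and pickle the
    resulting list of (name, cif_text) pairs with gzip compression."""
    cifs = []
    with tarfile.open(tar_path, "r:gz") as tar:
        for member in tar.getmembers():
            if member.isfile() and member.name.endswith(".cif"):
                text = tar.extractfile(member).read().decode("utf-8")
                cifs.append((member.name, text))
    with gzip.open(pkl_path, "wb") as f:
        pickle.dump(cifs, f)
    return cifs
```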
Pre-process the benchmark CIF files:
python bin/preprocess.py perov_5_train_orig.pkl.gz --out perov_5_train_prep.pkl.gz --workers 4
python bin/preprocess.py perov_5_val_orig.pkl.gz --out perov_5_val_prep.pkl.gz --workers 4
python bin/preprocess.py perov_5_test_orig.pkl.gz --out perov_5_test_prep.pkl.gz --workers 4
Tokenize the benchmark training and validation sets:
python bin/tokenize_cifs.py \
--train_fname perov_5_train_prep.pkl.gz \
--val_fname perov_5_val_prep.pkl.gz \
--out_dir tokens_perov_5/ \
--workers 4
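To give a feel for what tokenization does (the toy example below is NOT the project's actual tokenizer, which uses a curated CIF vocabulary; see `bin/tokenize_cifs.py`), a CIF file can be broken into whitespace-delimited tokens with newlines kept as explicit tokens:

```python
def toy_tokenize(cif_text: str) -> list:
    """Toy CIF tokenizer for illustration only: splits each line on
    whitespace and keeps newlines as explicit tokens, so the line
    structure of the CIF survives in the token stream."""
    tokens = []
    for line in cif_text.splitlines():
        tokens.extend(line.split())
        tokens.append("\n")
    return tokens
```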
Train a CrystaLLM model from scratch using only the benchmark training set:
python bin/train.py --config=config/crystallm_perov_5_small.yaml device=cuda dtype=float16
Generate the prompts from the CIF files of the test set:
python bin/make_prompts.py perov_5_test_prep.pkl.gz -o prompts_perov_5_test.tar.gz
Generate the CIF files, performing 20 generation attempts from each of the prompts:
python bin/generate_cifs.py \
--model crystallm_perov_5_small \
--prompts prompts_perov_5_test.tar.gz \
--out gen_perov_5_small_raw.tar.gz \
--device cuda \
--num-gens 20
Post-process the generated CIF files:
python bin/postprocess.py gen_perov_5_small_raw.tar.gz gen_perov_5_small.tar.gz
To compute the benchmark metrics using all available generations:
python bin/benchmark_metrics.py gen_perov_5_small.tar.gz perov_5_test_orig.tar.gz
To compute the benchmark metrics for the first generation attempt only (i.e. the n=1 case):
python bin/benchmark_metrics.py gen_perov_5_small.tar.gz perov_5_test_orig.tar.gz --num-gens 1
The `benchmark_metrics.py` script will process the given collection of generated CIF files and compare them to the original CIF files. When all the processing is complete, the match rate and RMSE will be printed to the console.
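The reported metrics follow the usual CDVAE-style convention: a test structure counts as matched if at least one of its n generations is structurally equivalent to the ground truth (e.g. as judged by pymatgen's `StructureMatcher`), and the RMSE is averaged over the matched structures. A sketch of that aggregation step, assuming the per-pair comparisons have already been produced (the dictionary layout below is illustrative, not the script's internal format):

```python
def aggregate_metrics(results):
    """Aggregate per-structure comparison results into benchmark metrics.

    `results` maps each test-set ID to a list of (is_match, rms_dist)
    pairs, one per generation attempt. A structure is matched if any
    attempt matches; its RMSE contribution is the best (lowest) RMS
    distance among its matching attempts.
    """
    matched_rms = []
    for attempts in results.values():
        rms = [d for ok, d in attempts if ok]
        if rms:
            matched_rms.append(min(rms))
    match_rate = len(matched_rms) / len(results)
    rmse = sum(matched_rms) / len(matched_rms) if matched_rms else float("nan")
    return match_rate, rmse
```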
The steps above can be performed with any of the other benchmark datasets as well, by simply substituting `perov_5` with `carbon_24`, `mp_20`, or `mpts_52`. For example, to apply this pipeline to Carbon-24, use `carbon_24_train_orig.tar.gz` in place of `perov_5_train_orig.tar.gz`, and so forth. The benchmark datasets are located in the `resources/benchmarks` folder.
Since the benchmarking pipeline has already been applied to all the benchmark datasets, we've uploaded the generated artifacts for convenience and reproducibility. Any artifact produced by the pipeline can therefore be downloaded directly using `bin/download.py`. For example, to download the tokenized Perov-5 dataset:

python bin/download.py tokens_perov_5.tar.gz

The generated CIF files for a benchmark can also be downloaded directly:

python bin/download.py gen_perov_5_small.tar.gz
All models come in two sizes: `small` and `large`. The one exception is the model trained on the full 2.3M-compound dataset minus the MPTS-52 validation and test sets, `crystallm_v1_minus_mpts_52_small.tar.gz`, which is available in the small size only. Its generated CIF files can be found in `gen_v1_minus_mpts_52_small.tar.gz` (and its tokens in `tokens_v1_minus_mpts_52.tar.gz`).