Skip to content

Commit

Permalink
omat24 example (#295)
Browse files Browse the repository at this point in the history
* omat24 example

* formatting fixed

* remove unused parsed arguments

* exists set to True for data download

* exists set to True for data download

* updated download script to move tar files into subdirectories

* bug fixed in normalization of the energy

* fixes on omat24 example

* black formatting fixed

* Update examples/omat24/train.py

* formatting fixed

---------

Co-authored-by: Massimiliano Lupo Pasini <[email protected]>
Co-authored-by: Massimiliano Lupo Pasini <[email protected]>
  • Loading branch information
3 people authored Nov 9, 2024
1 parent 98e8bc3 commit ca29030
Show file tree
Hide file tree
Showing 5 changed files with 656 additions and 0 deletions.
85 changes: 85 additions & 0 deletions examples/omat24/download_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import argparse
import glob
import logging
import os
import shutil


DOWNLOAD_LINKS = {
"train": {
"rattled-1000": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/rattled-1000.tar.gz",
"rattled-1000-subsampled": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/rattled-1000-subsampled.tar.gz",
"rattled-500": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/rattled-500.tar.gz",
"rattled-500-subsampled": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/rattled-500-subsampled.tar.gz",
"rattled-300": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/rattled-300.tar.gz",
"rattled-300-subsampled": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/rattled-300-subsampled.tar.gz",
"aimd-from-PBE-1000-npt": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/aimd-from-PBE-1000-npt.tar.gz",
"aimd-from-PBE-1000-nvt": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/aimd-from-PBE-1000-nvt.tar.gz",
"aimd-from-PBE-3000-npt": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/aimd-from-PBE-3000-npt.tar.gz",
"aimd-from-PBE-3000-nvt": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/aimd-from-PBE-3000-nvt.tar.gz",
"rattled-relax": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/rattled-relax.tar.gz",
},
"val": {
"rattled-1000": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/rattled-1000.tar.gz",
"rattled-1000-subsampled": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/rattled-1000-subsampled.tar.gz",
"rattled-500": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/rattled-500.tar.gz",
"rattled-500-subsampled": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/rattled-500-subsampled.tar.gz",
"rattled-300": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/rattled-300.tar.gz",
"rattled-300-subsampled": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/rattled-300-subsampled.tar.gz",
"aimd-from-PBE-1000-npt": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/aimd-from-PBE-1000-npt.tar.gz",
"aimd-from-PBE-1000-nvt": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/aimd-from-PBE-1000-nvt.tar.gz",
"aimd-from-PBE-3000-npt": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/aimd-from-PBE-3000-npt.tar.gz",
"aimd-from-PBE-3000-nvt": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/aimd-from-PBE-3000-nvt.tar.gz",
"rattled-relax": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/rattled-relax.tar.gz",
},
}


assert (
DOWNLOAD_LINKS["train"].keys() == DOWNLOAD_LINKS["val"].keys()
), "data partition names in train do not match with equivalent names in val"
dataset_names = list(DOWNLOAD_LINKS["train"].keys())


def get_data(datadir, task, split):
os.makedirs(datadir, exist_ok=True)

if (task == "train" or task == "val") and split is None:
raise NotImplementedError(f"{task} requires a split to be defined.")

assert (
split in DOWNLOAD_LINKS[task]
), f'{task}/{split}" split not defined, please specify one of the following: {list(DOWNLOAD_LINKS[task].keys())}'
download_link = DOWNLOAD_LINKS[task][split]

os.makedirs(os.path.join(datadir, task), exist_ok=True)

os.system(f"wget {download_link} -P {datadir}")
filename = os.path.join(datadir, os.path.basename(download_link))

# Move the directory
new_filename = os.path.join(datadir, task, os.path.basename(download_link))
shutil.move(filename, new_filename)

logging.info("Extracting contents...")
os.system(f"tar -xvf {new_filename} -C {os.path.join(datadir, task)}")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--data-path",
type=str,
default="./dataset",
help="Specify path to save datasets. Defaults to './dataset'",
)

args, _ = parser.parse_known_args()

for task in ["train", "val"]:
for split in dataset_names:
get_data(
datadir=args.data_path,
task=task,
split=split,
)
1 change: 1 addition & 0 deletions examples/omat24/example_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
fairchem-core
58 changes: 58 additions & 0 deletions examples/omat24/omat24_energy.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
{
"Verbosity": {
"level": 2
},
"NeuralNetwork": {
"Architecture": {
"model_type": "EGNN",
"equivariance": true,
"radius": 5.0,
"max_neighbours": 100000,
"num_gaussians": 50,
"envelope_exponent": 5,
"int_emb_size": 64,
"basis_emb_size": 8,
"out_emb_size": 128,
"num_after_skip": 2,
"num_before_skip": 1,
"num_radial": 6,
"num_spherical": 7,
"num_filters": 126,
"edge_features": ["length"],
"hidden_dim": 50,
"num_conv_layers": 3,
"output_heads": {
"graph":{
"num_sharedlayers": 2,
"dim_sharedlayers": 50,
"num_headlayers": 2,
"dim_headlayers": [50,25]
}
},
"task_weights": [1.0]
},
"Variables_of_interest": {
"input_node_features": [0, 1, 2, 3],
"output_names": ["energy"],
"output_index": [0],
"output_dim": [1],
"type": ["graph"]
},
"Training": {
"num_epoch": 50,
"perc_train": 0.8,
"loss_function_type": "mae",
"batch_size": 32,
"continue": 0,
"Optimizer": {
"type": "AdamW",
"learning_rate": 1e-3
}
}
},
"Visualization": {
"plot_init_solution": true,
"plot_hist_solution": false,
"create_plots": true
}
}
58 changes: 58 additions & 0 deletions examples/omat24/omat24_forces.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
{
"Verbosity": {
"level": 2
},
"NeuralNetwork": {
"Architecture": {
"model_type": "EGNN",
"equivariance": true,
"radius": 5.0,
"max_neighbours": 100000,
"num_gaussians": 50,
"envelope_exponent": 5,
"int_emb_size": 64,
"basis_emb_size": 8,
"out_emb_size": 128,
"num_after_skip": 2,
"num_before_skip": 1,
"num_radial": 6,
"num_spherical": 7,
"num_filters": 126,
"edge_features": ["length"],
"hidden_dim": 50,
"num_conv_layers": 3,
"output_heads": {
"node": {
"num_headlayers": 2,
"dim_headlayers": [200,200],
"type": "mlp"
}
},
"task_weights": [1.0]
},
"Variables_of_interest": {
"input_node_features": [0, 1, 2, 3],
"output_names": ["forces"],
"output_index": [2],
"output_dim": [3],
"type": ["node"]
},
"Training": {
"num_epoch": 50,
"EarlyStopping": true,
"perc_train": 0.9,
"loss_function_type": "mae",
"batch_size": 32,
"continue": 0,
"Optimizer": {
"type": "AdamW",
"learning_rate": 1e-3
}
}
},
"Visualization": {
"plot_init_solution": true,
"plot_hist_solution": false,
"create_plots": true
}
}
Loading

0 comments on commit ca29030

Please sign in to comment.