From ca29030e6b68e4a76442c99cd34d786f1f9b7458 Mon Sep 17 00:00:00 2001
From: Massimiliano Lupo Pasini
Date: Sat, 9 Nov 2024 11:18:15 -0500
Subject: [PATCH] omat24 example (#295)

* omat24 example
* formatting fixed
* remove unused parsed arguments
* exists set to True for data download
* exists set to True for data download
* updated download script to move tar files into subdirectories
* bug fixed in normalization of the energy
* fixes on omat24 example
* black formatting fixed
* Update examples/omat24/train.py
* formatting fixed

---------

Co-authored-by: Massimiliano Lupo Pasini
Co-authored-by: Massimiliano Lupo Pasini
---
 examples/omat24/download_dataset.py      |  85 +++++
 examples/omat24/example_requirements.txt |   1 +
 examples/omat24/omat24_energy.json       |  58 +++
 examples/omat24/omat24_forces.json       |  58 +++
 examples/omat24/train.py                 | 454 +++++++++++++++++++++++
 5 files changed, 656 insertions(+)
 create mode 100644 examples/omat24/download_dataset.py
 create mode 100644 examples/omat24/example_requirements.txt
 create mode 100644 examples/omat24/omat24_energy.json
 create mode 100644 examples/omat24/omat24_forces.json
 create mode 100644 examples/omat24/train.py

diff --git a/examples/omat24/download_dataset.py b/examples/omat24/download_dataset.py
new file mode 100644
index 000000000..521aa0567
--- /dev/null
+++ b/examples/omat24/download_dataset.py
@@ -0,0 +1,85 @@
+import argparse
+import logging
+import os
+import shutil
+
+
+# Plain URLs: the "wget " prefix previously embedded in these values made the
+# download command below expand to "wget wget <url>".
+DOWNLOAD_LINKS = {
+    "train": {
+        "rattled-1000": "https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/rattled-1000.tar.gz",
+        "rattled-1000-subsampled": "https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/rattled-1000-subsampled.tar.gz",
+        "rattled-500": "https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/rattled-500.tar.gz",
+        "rattled-500-subsampled": "https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/rattled-500-subsampled.tar.gz",
+        "rattled-300": "https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/rattled-300.tar.gz",
+        "rattled-300-subsampled": "https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/rattled-300-subsampled.tar.gz",
+        "aimd-from-PBE-1000-npt": "https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/aimd-from-PBE-1000-npt.tar.gz",
+        "aimd-from-PBE-1000-nvt": "https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/aimd-from-PBE-1000-nvt.tar.gz",
+        "aimd-from-PBE-3000-npt": "https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/aimd-from-PBE-3000-npt.tar.gz",
+        "aimd-from-PBE-3000-nvt": "https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/aimd-from-PBE-3000-nvt.tar.gz",
+        "rattled-relax": "https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/train/rattled-relax.tar.gz",
+    },
+    "val": {
+        "rattled-1000": "https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/rattled-1000.tar.gz",
+        "rattled-1000-subsampled": "https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/rattled-1000-subsampled.tar.gz",
+        "rattled-500": "https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/rattled-500.tar.gz",
+        "rattled-500-subsampled": "https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/rattled-500-subsampled.tar.gz",
"rattled-300": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/rattled-300.tar.gz", + "rattled-300-subsampled": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/rattled-300-subsampled.tar.gz", + "aimd-from-PBE-1000-npt": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/aimd-from-PBE-1000-npt.tar.gz", + "aimd-from-PBE-1000-nvt": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/aimd-from-PBE-1000-nvt.tar.gz", + "aimd-from-PBE-3000-npt": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/aimd-from-PBE-3000-npt.tar.gz", + "aimd-from-PBE-3000-nvt": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/aimd-from-PBE-3000-nvt.tar.gz", + "rattled-relax": "wget https://dl.fbaipublicfiles.com/opencatalystproject/data/omat/241018/omat/val/rattled-relax.tar.gz", + }, +} + + +assert ( + DOWNLOAD_LINKS["train"].keys() == DOWNLOAD_LINKS["val"].keys() +), "data partition names in train do not match with equivalent names in val" +dataset_names = list(DOWNLOAD_LINKS["train"].keys()) + + +def get_data(datadir, task, split): + os.makedirs(datadir, exist_ok=True) + + if (task == "train" or task == "val") and split is None: + raise NotImplementedError(f"{task} requires a split to be defined.") + + assert ( + split in DOWNLOAD_LINKS[task] + ), f'{task}/{split}" split not defined, please specify one of the following: {list(DOWNLOAD_LINKS[task].keys())}' + download_link = DOWNLOAD_LINKS[task][split] + + os.makedirs(os.path.join(datadir, task), exist_ok=True) + + os.system(f"wget {download_link} -P {datadir}") + filename = os.path.join(datadir, os.path.basename(download_link)) + + # Move the directory + new_filename = os.path.join(datadir, task, os.path.basename(download_link)) + shutil.move(filename, new_filename) + + logging.info("Extracting contents...") + os.system(f"tar -xvf {new_filename} -C {os.path.join(datadir, task)}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data-path", + type=str, + default="./dataset", + help="Specify path to save datasets. 
Defaults to './dataset'", + ) + + args, _ = parser.parse_known_args() + + for task in ["train", "val"]: + for split in dataset_names: + get_data( + datadir=args.data_path, + task=task, + split=split, + ) diff --git a/examples/omat24/example_requirements.txt b/examples/omat24/example_requirements.txt new file mode 100644 index 000000000..5d9c15106 --- /dev/null +++ b/examples/omat24/example_requirements.txt @@ -0,0 +1 @@ +fairchem-core \ No newline at end of file diff --git a/examples/omat24/omat24_energy.json b/examples/omat24/omat24_energy.json new file mode 100644 index 000000000..378a01cd7 --- /dev/null +++ b/examples/omat24/omat24_energy.json @@ -0,0 +1,58 @@ +{ + "Verbosity": { + "level": 2 + }, + "NeuralNetwork": { + "Architecture": { + "model_type": "EGNN", + "equivariance": true, + "radius": 5.0, + "max_neighbours": 100000, + "num_gaussians": 50, + "envelope_exponent": 5, + "int_emb_size": 64, + "basis_emb_size": 8, + "out_emb_size": 128, + "num_after_skip": 2, + "num_before_skip": 1, + "num_radial": 6, + "num_spherical": 7, + "num_filters": 126, + "edge_features": ["length"], + "hidden_dim": 50, + "num_conv_layers": 3, + "output_heads": { + "graph":{ + "num_sharedlayers": 2, + "dim_sharedlayers": 50, + "num_headlayers": 2, + "dim_headlayers": [50,25] + } + }, + "task_weights": [1.0] + }, + "Variables_of_interest": { + "input_node_features": [0, 1, 2, 3], + "output_names": ["energy"], + "output_index": [0], + "output_dim": [1], + "type": ["graph"] + }, + "Training": { + "num_epoch": 50, + "perc_train": 0.8, + "loss_function_type": "mae", + "batch_size": 32, + "continue": 0, + "Optimizer": { + "type": "AdamW", + "learning_rate": 1e-3 + } + } + }, + "Visualization": { + "plot_init_solution": true, + "plot_hist_solution": false, + "create_plots": true + } +} diff --git a/examples/omat24/omat24_forces.json b/examples/omat24/omat24_forces.json new file mode 100644 index 000000000..6eec50b28 --- /dev/null +++ b/examples/omat24/omat24_forces.json @@ -0,0 +1,58 @@ +{ + "Verbosity": { + "level": 2 + }, + "NeuralNetwork": { + "Architecture": { + "model_type": "EGNN", + "equivariance": true, + "radius": 5.0, + "max_neighbours": 100000, + "num_gaussians": 50, + "envelope_exponent": 5, + "int_emb_size": 64, + "basis_emb_size": 8, + "out_emb_size": 128, + "num_after_skip": 2, + "num_before_skip": 1, + "num_radial": 6, + "num_spherical": 7, + "num_filters": 126, + "edge_features": ["length"], + "hidden_dim": 50, + "num_conv_layers": 3, + "output_heads": { + "node": { + "num_headlayers": 2, + "dim_headlayers": [200,200], + "type": "mlp" + } + }, + "task_weights": [1.0] + }, + "Variables_of_interest": { + "input_node_features": [0, 1, 2, 3], + "output_names": ["forces"], + "output_index": [2], + "output_dim": [3], + "type": ["node"] + }, + "Training": { + "num_epoch": 50, + "EarlyStopping": true, + "perc_train": 0.9, + "loss_function_type": "mae", + "batch_size": 32, + "continue": 0, + "Optimizer": { + "type": "AdamW", + "learning_rate": 1e-3 + } + } + }, + "Visualization": { + "plot_init_solution": true, + "plot_hist_solution": false, + "create_plots": true + } +} diff --git a/examples/omat24/train.py b/examples/omat24/train.py new file mode 100644 index 000000000..ac32973b7 --- /dev/null +++ b/examples/omat24/train.py @@ -0,0 +1,454 @@ +import os, re, json +import logging +import sys +from mpi4py import MPI +import argparse + +import random + +import torch + +# FIX random seed +random_state = 0 +torch.manual_seed(random_state) + +from torch_geometric.data import Data +from 
torch_geometric.transforms import RadiusGraph, Distance
+
+import hydragnn
+from hydragnn.utils.profiling_and_tracing.time_utils import Timer
+from hydragnn.utils.model import print_model
+from hydragnn.utils.datasets.abstractbasedataset import AbstractBaseDataset
+from hydragnn.utils.datasets.distdataset import DistDataset
+from hydragnn.utils.datasets.pickledataset import (
+    SimplePickleWriter,
+    SimplePickleDataset,
+)
+from hydragnn.preprocess.graph_samples_checks_and_updates import gather_deg
+from hydragnn.preprocess.load_data import split_dataset
+
+import hydragnn.utils.profiling_and_tracing.tracer as tr
+
+from hydragnn.utils.print.print_utils import iterate_tqdm, log
+
+try:
+    from hydragnn.utils.datasets.adiosdataset import AdiosWriter, AdiosDataset
+except ImportError:
+    pass
+
+from hydragnn.utils.distributed import nsplit
+
+## FIXME
+torch.backends.cudnn.enabled = False
+
+from fairchem.core.datasets import AseDBDataset
+
+
+def info(*args, logtype="info", sep=" "):
+    getattr(logging, logtype)(sep.join(map(str, args)))
+
+
+# FIXME: this radius cutoff overrides the radius cutoff currently written in the JSON file
+create_graph_fromXYZ = RadiusGraph(r=5.0)  # radius cutoff in angstrom
+compute_edge_lengths = Distance(norm=False, cat=True)
+
+
+dataset_names = [
+    "rattled-1000",
+    "rattled-1000-subsampled",
+    "rattled-500",
+    "rattled-500-subsampled",
+    "rattled-300",
+    "rattled-300-subsampled",
+    "aimd-from-PBE-1000-npt",
+    "aimd-from-PBE-1000-nvt",
+    "aimd-from-PBE-3000-npt",
+    "aimd-from-PBE-3000-nvt",
+    "rattled-relax",
+]
+
+
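+# OMat2024 wraps fairchem's AseDBDataset: each ASE structure is converted into
+# a torch_geometric Data object whose node features concatenate atomic number,
+# Cartesian coordinates, and forces, with edges built by the RadiusGraph and
+# Distance transforms defined above. Samples in which any atom's force L2-norm
+# reaches forces_norm_threshold are discarded.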
", flush=True) + + dataset = AseDBDataset( + config=dict( + src=os.path.join(dirpath, data_type, dataname), **config_kwargs + ) + ) + + rx = list(nsplit(list(range(dataset.num_samples)), self.world_size))[ + self.rank + ] + + for index in iterate_tqdm(rx, verbosity_level=2): + try: + xyz = torch.tensor( + dataset.get_atoms(index).get_positions(), dtype=torch.float32 + ) + natoms = torch.IntTensor([xyz.shape[0]]) + Z = torch.tensor( + dataset.get_atoms(index).get_atomic_numbers(), + dtype=torch.float32, + ).unsqueeze(1) + energy = torch.tensor( + dataset.get_atoms(index).get_total_energy(), dtype=torch.float32 + ).unsqueeze(0) + forces = torch.tensor( + dataset.get_atoms(index).get_forces(), dtype=torch.float32 + ) + chemical_formula = dataset.get_atoms(index).get_chemical_formula() + + if self.energy_per_atom: + energy /= natoms.item() + + data = Data(pos=xyz, x=Z, force=forces, energy=energy, y=energy) + data.x = torch.cat((data.x, xyz, forces), dim=1) + data = create_graph_fromXYZ(data) + + # Add edge length as edge feature + data = compute_edge_lengths(data) + data.edge_attr = data.edge_attr.to(torch.float32) + if self.check_forces_values(data.force): + self.dataset.append(data) + else: + print( + f"L2-norm of force tensor is {data.force.norm()} and exceeds threshold {self.forces_norm_threshold} - atomistic structure: {chemical_formula}", + flush=True, + ) + + except Exception as e: + print(f"Rank {self.rank} reading - exception: ", e) + + torch.distributed.barrier() + + random.shuffle(self.dataset) + + def check_forces_values(self, forces): + + # Calculate the L2 norm for each row + norms = torch.norm(forces, p=2, dim=1) + # Check if all norms are less than the threshold + + return torch.all(norms < self.forces_norm_threshold).item() + + def len(self): + return len(self.dataset) + + def get(self, idx): + return self.dataset[idx] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--sampling", type=float, help="sampling ratio", default=None) + parser.add_argument( + "--preonly", + action="store_true", + help="preprocess only (no training)", + ) + parser.add_argument( + "--inputfile", help="input file", type=str, default="omat24_energy.json" + ) + parser.add_argument( + "--energy_per_atom", + help="option to normalize energy by number of atoms", + type=bool, + default=True, + ) + parser.add_argument("--ddstore", action="store_true", help="ddstore dataset") + parser.add_argument("--ddstore_width", type=int, help="ddstore width", default=None) + parser.add_argument("--shmem", action="store_true", help="shmem") + parser.add_argument("--log", help="log name") + parser.add_argument("--batch_size", type=int, help="batch_size", default=None) + parser.add_argument("--num_epoch", type=int, help="num_epoch", default=None) + parser.add_argument("--everyone", action="store_true", help="gptimer") + parser.add_argument("--modelname", help="model name") + + group = parser.add_mutually_exclusive_group() + group.add_argument( + "--adios", + help="Adios dataset", + action="store_const", + dest="format", + const="adios", + ) + group.add_argument( + "--pickle", + help="Pickle dataset", + action="store_const", + dest="format", + const="pickle", + ) + parser.set_defaults(format="adios") + args = parser.parse_args() + + graph_feature_names = ["energy"] + graph_feature_dims = [1] + node_feature_names = ["atomic_number", "cartesian_coordinates", "forces"] + node_feature_dims = [1, 3, 3] + dirpwd = 
+    datadir = os.path.join(dirpwd, "dataset")
+    ##################################################################################################################
+    input_filename = os.path.join(dirpwd, args.inputfile)
+    ##################################################################################################################
+    # Configurable run choices (JSON file that accompanies this example script).
+    with open(input_filename, "r") as f:
+        config = json.load(f)
+    verbosity = config["Verbosity"]["level"]
+    var_config = config["NeuralNetwork"]["Variables_of_interest"]
+    var_config["graph_feature_names"] = graph_feature_names
+    var_config["graph_feature_dims"] = graph_feature_dims
+    var_config["node_feature_names"] = node_feature_names
+    var_config["node_feature_dims"] = node_feature_dims
+
+    if args.batch_size is not None:
+        config["NeuralNetwork"]["Training"]["batch_size"] = args.batch_size
+
+    if args.num_epoch is not None:
+        config["NeuralNetwork"]["Training"]["num_epoch"] = args.num_epoch
+
+    ##################################################################################################################
+    # Always initialize for multi-rank training.
+    comm_size, rank = hydragnn.utils.distributed.setup_ddp()
+    ##################################################################################################################
+
+    comm = MPI.COMM_WORLD
+
+    ## Set up logging
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%%(levelname)s (rank %d): %%(message)s" % (rank),
+        datefmt="%H:%M:%S",
+    )
+
+    log_name = "OMat24" if args.log is None else args.log
+    hydragnn.utils.print.setup_log(log_name)
+    writer = hydragnn.utils.model.get_summary_writer(log_name)
+
+    log("Command: {0}\n".format(" ".join(sys.argv)), rank=0)
+
+    modelname = "OMat24" if args.modelname is None else args.modelname
+    if args.preonly:
+        ## local data
+        trainset = OMat2024(
+            datadir,
+            var_config,
+            data_type="train",
+            energy_per_atom=args.energy_per_atom,
+            dist=True,
+        )
+        ## This is a local split
+        trainset, valset1, valset2 = split_dataset(
+            dataset=trainset,
+            perc_train=0.9,
+            stratify_splitting=False,
+        )
+        valset = [*valset1, *valset2]
+        testset = OMat2024(
+            datadir,
+            var_config,
+            data_type="val",
+            energy_per_atom=args.energy_per_atom,
+            dist=True,
+        )
+        ## Need as a list
+        testset = testset[:]
+        print(rank, "Local splitting: ", len(trainset), len(valset), len(testset))
+
+        deg = gather_deg(trainset)
+        config["pna_deg"] = deg
+
+        setnames = ["trainset", "valset", "testset"]
+
+        ## adios
+        if args.format == "adios":
+            fname = os.path.join(
+                os.path.dirname(__file__), "./dataset/%s.bp" % modelname
+            )
+            adwriter = AdiosWriter(fname, comm)
+            adwriter.add("trainset", trainset)
+            adwriter.add("valset", valset)
+            adwriter.add("testset", testset)
+            # adwriter.add_global("minmax_node_feature", total.minmax_node_feature)
+            # adwriter.add_global("minmax_graph_feature", total.minmax_graph_feature)
+            adwriter.add_global("pna_deg", deg)
+            adwriter.save()
+
+        ## pickle
+        elif args.format == "pickle":
+            basedir = os.path.join(
+                os.path.dirname(__file__), "dataset", "%s.pickle" % modelname
+            )
+            attrs = dict()
+            attrs["pna_deg"] = deg
+            SimplePickleWriter(
+                trainset,
+                basedir,
+                "trainset",
+                # minmax_node_feature=total.minmax_node_feature,
+                # minmax_graph_feature=total.minmax_graph_feature,
+                use_subdir=True,
+                attrs=attrs,
+            )
+            SimplePickleWriter(
+                valset,
+                basedir,
+                "valset",
+                # minmax_node_feature=total.minmax_node_feature,
+                # minmax_graph_feature=total.minmax_graph_feature,
+                use_subdir=True,
+            )
+            SimplePickleWriter(
+                testset,
+                basedir,
+                "testset",
+                # minmax_node_feature=total.minmax_node_feature,
+                # minmax_graph_feature=total.minmax_graph_feature,
+                use_subdir=True,
+            )
+        sys.exit(0)
+
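+    # The example runs in two phases: a first pass with --preonly (e.g.
+    # "mpirun -np 8 python train.py --preonly") converts the raw ASE databases
+    # into the ADIOS (.bp) or pickle dataset written above and exits; a second
+    # pass without --preonly reloads that preprocessed dataset below and runs
+    # the training loop.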
+    tr.initialize()
+    tr.disable()
+    timer = Timer("load_data")
+    timer.start()
+
+    if args.format == "adios":
+        info("Adios load")
+        assert not (args.shmem and args.ddstore), "Cannot use both ddstore and shmem"
+        opt = {
+            "preload": False,
+            "shmem": args.shmem,
+            "ddstore": args.ddstore,
+            "ddstore_width": args.ddstore_width,
+        }
+        fname = os.path.join(os.path.dirname(__file__), "./dataset/%s.bp" % modelname)
+        trainset = AdiosDataset(fname, "trainset", comm, **opt, var_config=var_config)
+        valset = AdiosDataset(fname, "valset", comm, **opt, var_config=var_config)
+        testset = AdiosDataset(fname, "testset", comm, **opt, var_config=var_config)
+    elif args.format == "pickle":
+        info("Pickle load")
+        basedir = os.path.join(
+            os.path.dirname(__file__), "dataset", "%s.pickle" % modelname
+        )
+        trainset = SimplePickleDataset(
+            basedir=basedir, label="trainset", var_config=var_config
+        )
+        valset = SimplePickleDataset(
+            basedir=basedir, label="valset", var_config=var_config
+        )
+        testset = SimplePickleDataset(
+            basedir=basedir, label="testset", var_config=var_config
+        )
+        # minmax_node_feature = trainset.minmax_node_feature
+        # minmax_graph_feature = trainset.minmax_graph_feature
+        pna_deg = trainset.pna_deg
+        if args.ddstore:
+            opt = {"ddstore_width": args.ddstore_width}
+            trainset = DistDataset(trainset, "trainset", comm, **opt)
+            valset = DistDataset(valset, "valset", comm, **opt)
+            testset = DistDataset(testset, "testset", comm, **opt)
+            # trainset.minmax_node_feature = minmax_node_feature
+            # trainset.minmax_graph_feature = minmax_graph_feature
+            trainset.pna_deg = pna_deg
+    else:
+        raise NotImplementedError("Unsupported format: %s" % (args.format))
+
+    info(
+        "trainset,valset,testset size: %d %d %d"
+        % (len(trainset), len(valset), len(testset))
+    )
+
+    if args.ddstore:
+        os.environ["HYDRAGNN_AGGR_BACKEND"] = "mpi"
+        os.environ["HYDRAGNN_USE_ddstore"] = "1"
+
+    (train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders(
+        trainset, valset, testset, config["NeuralNetwork"]["Training"]["batch_size"]
+    )
+
+    config = hydragnn.utils.input_config_parsing.update_config(
+        config, train_loader, val_loader, test_loader
+    )
+    ## Synchronize all ranks right after the DDStore setup
+    comm.Barrier()
+
+    hydragnn.utils.input_config_parsing.save_config(config, log_name)
+
+    timer.stop()
+
+    model = hydragnn.models.create_model_config(
+        config=config["NeuralNetwork"],
+        verbosity=verbosity,
+    )
+    model = hydragnn.utils.distributed.get_distributed_model(model, verbosity)
+
+    # Print details of neural network architecture
+    print_model(model)
+
+    learning_rate = config["NeuralNetwork"]["Training"]["Optimizer"]["learning_rate"]
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer, mode="min", factor=0.5, patience=5, min_lr=0.00001
+    )
+
+    hydragnn.utils.model.load_existing_model_config(
+        model, config["NeuralNetwork"]["Training"], optimizer=optimizer
+    )
+
+    ##################################################################################################################
+
+    hydragnn.train.train_validate_test(
+        model,
+        optimizer,
+        train_loader,
+        val_loader,
+        test_loader,
+        writer,
+        scheduler,
+        config["NeuralNetwork"],
+ log_name, + verbosity, + create_plots=False, + ) + + hydragnn.utils.model.save_model(model, optimizer, log_name) + hydragnn.utils.profiling_and_tracing.print_timers(verbosity) + + if tr.has("GPTLTracer"): + import gptl4py as gp + + eligible = rank if args.everyone else 0 + if rank == eligible: + gp.pr_file(os.path.join("logs", log_name, "gp_timing.p%d" % rank)) + gp.pr_summary_file(os.path.join("logs", log_name, "gp_timing.summary")) + gp.finalize() + sys.exit(0)
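
---

Usage sketch, under stated assumptions (a working MPI launcher plus hydragnn
and the packages in example_requirements.txt installed; the rank count below
is illustrative, not prescribed by the patch):

    cd examples/omat24
    pip install -r example_requirements.txt

    # download and extract every OMat24 train/val split into ./dataset
    python download_dataset.py --data-path ./dataset

    # phase 1: preprocess the ASE databases into ./dataset/OMat24.bp, then exit
    mpirun -np 8 python train.py --preonly

    # phase 2: train the energy model configured in omat24_energy.json
    mpirun -np 8 python train.py --inputfile omat24_energy.json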