diff --git a/examples/multidataset/dataset_histogram_plot.py b/examples/multidataset/dataset_histogram_plot.py
new file mode 100644
index 000000000..44972ebd5
--- /dev/null
+++ b/examples/multidataset/dataset_histogram_plot.py
@@ -0,0 +1,368 @@
+# Plot histograms of atom counts, edge counts, energies, and forces for the
+# multidataset collection stored as ADIOS2 (.bp) files.
+import adios2 as ad2
+import numpy as np
+import pickle
+import os
+from tqdm import tqdm
+
+import matplotlib
+import matplotlib.pyplot as plt
+from mpl_toolkits.axes_grid1.inset_locator import inset_axes
+
+font = {"size": 12}
+matplotlib.rc("font", **font)
+
+
+def histplot(dataset_list):
+    for dataname in dataset_list:
+        x = np.concatenate(dataset_list[dataname])
+        if len(x) > 0:
+            # print(dataname, x.min(), x.max(), x.mean(), x.std())
+            h, bins = np.histogram(x, bins=50)
+            plt.figure(figsize=[6, 3])
+            plt.hist(x, bins=50, density=True, log=True)
+            plt.title(dataname)
+            plt.close()
+        else:
+            print(dataname, "no data")
+
+
+def histplot2(dataset_list, name):
+    datasetname = ["trainset", "valset", "testset"]
+    for dataname in tqdm(dataset_list, desc="hist2"):
+        fname = f"hist-3set-{name}-h-{dataname}.npz"
+        if not os.path.exists(fname):
+            xa = np.concatenate(dataset_list[dataname])
+            h, bins = np.histogram(xa, bins=50)
+            np.savez(fname, h=h, bins=bins)
+        else:
+            with np.load(fname) as f:
+                h = f["h"]
+                bins = f["bins"]
+
+        plt.figure(figsize=[6, 3])
+        for i in range(3):
+            x = dataset_list[dataname][i]
+            if len(x) > 0:
+                # print(dataname, x.min(), x.max(), x.mean(), x.std())
+                fname = f"hist-3set-{name}-h-{dataname}-{i}.npz"
+                if not os.path.exists(fname):
+                    h, _ = np.histogram(x, bins=bins, density=True)
+                    np.savez(fname, h=h)
+                else:
+                    with np.load(fname) as f:
+                        h = f["h"]
+                plt.bar(
+                    0.5 * bins[:-1] + 0.5 * bins[1:],
+                    h,
+                    width=bins[1] - bins[0],
+                    alpha=0.2,
+                    label="_Hidden",
+                )
+                # h, _, _ = plt.hist(x, bins=bins, alpha=0.2, label="_Hidden")
+                # Trace the histogram as a step outline over the bin edges
+                xs = list()
+                ys = list()
+                xs.append(bins[0])
+                ys.append(0)
+                for k in range(len(h)):
+                    xs.append(bins[k])
+                    xs.append(bins[k + 1])
+                    ys.append(h[k])
+                    ys.append(h[k])
+                xs.append(bins[-1])
+                ys.append(0)
+                plt.plot(xs, ys, label=datasetname[i])
+            else:
+                print(dataname, "no data")
+        plt.yscale("log")
+        plt.title(dataname.replace("-v2", ""))
+        plt.legend()
+        plt.tight_layout()
+        plt.savefig(f"hist_3set-{name}-{dataname}.pdf")
+        plt.close()
+
+
+def histplot3(dataset_list, name):
+    datasetname = ["trainset", "valset", "testset"]
+    fig, ax = plt.subplots(1, len(dataset_list), sharey=True, figsize=[16, 3])
+    for p, dataname in tqdm(
+        enumerate(dataset_list), desc="hist3", total=len(dataset_list)
+    ):
+        fname = f"hist-3set-{name}-h-{dataname}.npz"
+        if not os.path.exists(fname):
+            xa = np.concatenate(dataset_list[dataname])
+            h, bins = np.histogram(xa, bins=50)
+            np.savez(fname, h=h, bins=bins)
+        else:
+            with np.load(fname) as f:
+                h = f["h"]
+                bins = f["bins"]
+
+        for i in range(3):
+            x = dataset_list[dataname][i]
+            if len(x) > 0:
+                # print(dataname, x.min(), x.max(), x.mean(), x.std())
+                fname = f"hist-3set-{name}-h-{dataname}-{i}.npz"
+                if not os.path.exists(fname):
+                    h, _ = np.histogram(x, bins=bins, density=True)
+                    np.savez(fname, h=h)
+                else:
+                    with np.load(fname) as f:
+                        h = f["h"]
+                ax[p].bar(
+                    0.5 * bins[:-1] + 0.5 * bins[1:],
+                    h,
+                    width=bins[1] - bins[0],
+                    alpha=0.2,
+                    label="_Hidden",
+                )
+                xs = list()
+                ys = list()
+                xs.append(bins[0])
+                ys.append(0)
+                for k in range(len(h)):
+                    xs.append(bins[k])
+                    xs.append(bins[k + 1])
+                    ys.append(h[k])
+                    ys.append(h[k])
+                xs.append(bins[-1])
+                ys.append(0)
+                ax[p].plot(xs, ys, label=datasetname[i])
+            else:
+                print(dataname, "no data")
+
+        ax[p].set_yscale("log")
+        ax[p].set_title(dataname.replace("-v2", ""))
+        ax[p].tick_params(axis="x", labelrotation=30)
+    fig.subplots_adjust(wspace=0, hspace=0)
+    plt.legend(loc=1, prop={"size": 10})
+    plt.savefig(f"hist_3set-{name}-all.pdf")
+    plt.close()
+
+
+if __name__ == "__main__":
+    dirpwd = os.path.dirname(os.path.abspath(__file__))
+    prefix = os.path.join(dirpwd, "dataset")
+    # dataname_list = ["ANI1x", "MPTrj", "qm7x", "OC2022", "OC2020", "OC2020-20M"]
+    # dataname_list = ["ANI1x-v2", "MPTrj-v2", "qm7x-v2", "OC2022-v2", "OC2020-v2", "OC2020-20M-v2"]
+    dataname_list = ["ANI1x-v2", "MPTrj-v2", "qm7x-v2", "OC2022-v2", "OC2020-v2"]
+    suffix = "-v2"
+
+    ## atoms
+    natom_list = dict()
+    for dataname in tqdm(dataname_list, desc="atom"):
+        natom_list[dataname] = list()
+        for label in ["trainset", "valset", "testset"]:
+            with ad2.open(os.path.join(prefix, dataname + ".bp"), "r") as f:
+                f.__next__()
+                natom = f.read(f"{label}/pos/variable_count")
+                natom_list[dataname].append(natom)
+
+    for dataname in dataname_list:
+        x = np.concatenate(natom_list[dataname])
+        # print(dataname, x.min(), x.max(), x.mean(), x.std())
+        h, bins = np.histogram(x, bins=50)
+        plt.figure(figsize=[6, 3])
+        plt.hist(x, bins=50, density=True, log=True)
+        plt.title(dataname)
+        plt.close()
+
+    plt.figure(figsize=[6, 3])
+    for dataname in dataname_list:
+        x = np.concatenate(natom_list[dataname])
+        h, bins = np.histogram(x, bins=50, density=True)
+        # density * bin width = probability mass per bin; x100 converts to percent
+        plt.plot(
+            bins[:-1], h * (bins[1] - bins[0]) * 100, label=dataname.replace("-v2", "")
+        )
+        plt.fill_between(
+            bins[:-1], h * (bins[1] - bins[0]) * 100, alpha=0.5, label="_nolegend_"
+        )
+    plt.xlabel("Num. of atoms")
+    plt.ylabel("Ratio (%)")
+    plt.legend()
+    plt.tight_layout()
+    plt.savefig(f"hist-atoms{suffix}.pdf")
+    plt.close()
+
+    ## edges
+    edge_list = dict()
+    for dataname in tqdm(dataname_list, desc="edge"):
+        edge_list[dataname] = list()
+        for label in ["trainset", "valset", "testset"]:
+            with ad2.open(os.path.join(prefix, dataname + ".bp"), "r") as f:
+                f.__next__()
+                nedge = f.read(f"{label}/edge_attr/variable_count")
+                edge_list[dataname].append(nedge)
+
+    for dataname in dataname_list:
+        x = np.concatenate(edge_list[dataname])
+        # print(dataname, x.min(), x.max(), x.mean(), x.std())
+        h, bins = np.histogram(x, bins=50)
+        plt.figure(figsize=[6, 3])
+        plt.hist(x, bins=50, density=True, log=True)
+        plt.title(dataname)
+        plt.close()
+
+    plt.figure(figsize=[6, 3])
+    for dataname in dataname_list:
+        x = np.concatenate(edge_list[dataname])
+        h, bins = np.histogram(x, bins=50, density=True)
+        plt.plot(
+            bins[:-1], h * (bins[1] - bins[0]) * 100, label=dataname.replace("-v2", "")
+        )
+        plt.fill_between(
+            bins[:-1], h * (bins[1] - bins[0]) * 100, alpha=0.5, label="_nolegend_"
+        )
+    plt.xlabel("Num. of edges")
+    plt.ylabel("Ratio (%)")
+    plt.legend()
+    plt.tight_layout()
+    plt.savefig(f"hist-edges{suffix}.pdf")
+    plt.close()
of edges") + plt.ylabel("Ratio (%)") + plt.legend() + plt.tight_layout() + plt.savefig(f"hist-edges{suffix}.pdf") + plt.close() + + ## energy + energy_list = dict() + for dataname in tqdm(dataname_list, desc="energy"): + energy_list[dataname] = list() + for label in ["trainset", "valset", "testset"]: + with ad2.open(os.path.join(prefix, dataname + ".bp"), "r") as f: + f.__next__() + energy = f.read(f"{label}/energy") + energy_list[dataname].append(energy) + + for dataname in dataname_list: + x = np.concatenate(energy_list[dataname]) + if len(x) > 0: + # print(dataname, x.min(), x.max(), x.mean(), x.std()) + h, bins = np.histogram(x, bins=50) + plt.figure(figsize=[6, 3]) + plt.hist(x, bins=50, density=True, log=True) + plt.title(dataname) + plt.close() + else: + print(dataname, "no data") + + plt.figure(figsize=[6, 3]) + min_list = list() + max_list = list() + for dataname in dataname_list: + min_list.append(energy_list[dataname][0].min()) + min_list.append(energy_list[dataname][1].min()) + min_list.append(energy_list[dataname][2].min()) + max_list.append(energy_list[dataname][0].max()) + max_list.append(energy_list[dataname][1].max()) + max_list.append(energy_list[dataname][2].max()) + mn, mx = min(min_list), max(max_list) + bins = np.arange(mn, mx, 0.2) + + for dataname in dataname_list: + x = np.concatenate(energy_list[dataname]) + h, bins = np.histogram(x, bins=bins, density=True) + plt.plot( + bins[:-1], h * (bins[1] - bins[0]) * 100, label=dataname.replace("-v2", "") + ) + plt.fill_between( + bins[:-1], h * (bins[1] - bins[0]) * 100, alpha=0.5, label="_nolegend_" + ) + plt.xlabel("Energy") + plt.ylabel("Ratio (%)") + plt.xlim([-1000, +10]) + plt.legend(loc="upper right", bbox_to_anchor=(0.95, 1.0)) + + ax = plt.gca() + ax = inset_axes(ax, width="40%", height="40%", loc="upper left") + for dataname in dataname_list: + x = np.concatenate(energy_list[dataname]) + h, bins = np.histogram(x, bins=bins, density=True) + ax.plot( + bins[:-1], h * (bins[1] - bins[0]) * 100, label=dataname.replace("-v2", "") + ) + ax.fill_between( + bins[:-1], h * (bins[1] - bins[0]) * 100, alpha=0.5, label="_nolegend_" + ) + ax.set_yticks([]) + + plt.tight_layout() + plt.savefig(f"hist-energy{suffix}.pdf") + plt.close() + + ## force + force_list = dict() + for dataname in tqdm(dataname_list, desc="force"): + force_list[dataname] = list() + for label in ["trainset", "valset", "testset"]: + with ad2.open(os.path.join(prefix, dataname + ".bp"), "r") as f: + f.__next__() + force = f.read(f"{label}/force") + force_list[dataname].append(force) + + # for dataname in dataname_list: + # x = np.concatenate(force_list[dataname]) + # x = np.linalg.norm(x, axis=-1) + # # print(dataname, x.min(), x.max(), x.mean(), x.std()) + # h, bins = np.histogram(x, bins=50) + # plt.figure(figsize=[6, 3]) + # plt.hist(x, bins=50, density=True, log=True) + # plt.title(dataname) + # plt.close() + + min_list = list() + max_list = list() + for dataname in dataname_list: + min_list.append(force_list[dataname][0].min()) + min_list.append(force_list[dataname][1].min()) + min_list.append(force_list[dataname][2].min()) + max_list.append(force_list[dataname][0].max()) + max_list.append(force_list[dataname][1].max()) + max_list.append(force_list[dataname][2].max()) + mn, mx = min(min_list), max(max_list) + bins = np.arange(mn, mx, 0.2) + + h_list = dict() + for dataname in tqdm(dataname_list, desc="hist"): + fname = f"hist-h-{dataname}.npz" + if not os.path.exists(fname): + x = np.concatenate(force_list[dataname]) + x = np.linalg.norm(x, axis=-1) + 
+
+    ## force
+    force_list = dict()
+    for dataname in tqdm(dataname_list, desc="force"):
+        force_list[dataname] = list()
+        for label in ["trainset", "valset", "testset"]:
+            with ad2.open(os.path.join(prefix, dataname + ".bp"), "r") as f:
+                f.__next__()
+                force = f.read(f"{label}/force")
+                force_list[dataname].append(force)
+
+    # for dataname in dataname_list:
+    #     x = np.concatenate(force_list[dataname])
+    #     x = np.linalg.norm(x, axis=-1)
+    #     # print(dataname, x.min(), x.max(), x.mean(), x.std())
+    #     h, bins = np.histogram(x, bins=50)
+    #     plt.figure(figsize=[6, 3])
+    #     plt.hist(x, bins=50, density=True, log=True)
+    #     plt.title(dataname)
+    #     plt.close()
+
+    min_list = list()
+    max_list = list()
+    for dataname in dataname_list:
+        min_list.append(force_list[dataname][0].min())
+        min_list.append(force_list[dataname][1].min())
+        min_list.append(force_list[dataname][2].min())
+        max_list.append(force_list[dataname][0].max())
+        max_list.append(force_list[dataname][1].max())
+        max_list.append(force_list[dataname][2].max())
+    mn, mx = min(min_list), max(max_list)
+    bins = np.arange(mn, mx, 0.2)
+
+    h_list = dict()
+    for dataname in tqdm(dataname_list, desc="hist"):
+        fname = f"hist-h-{dataname}.npz"
+        if not os.path.exists(fname):
+            x = np.concatenate(force_list[dataname])
+            x = np.linalg.norm(x, axis=-1)
+            h, bins = np.histogram(x, bins=bins, density=True)
+            np.savez(fname, h=h)
+            h_list[dataname] = h
+        else:
+            with np.load(fname) as f:
+                h = f["h"]
+            h_list[dataname] = h
+
+    plt.figure(figsize=[6, 3])
+    for dataname in dataname_list:
+        h = h_list[dataname]
+        plt.plot(
+            bins[:-1], h * (bins[1] - bins[0]) * 100, label=dataname.replace("-v2", "")
+        )
+        plt.fill_between(
+            bins[:-1], h * (bins[1] - bins[0]) * 100, alpha=0.5, label="_nolegend_"
+        )
+    plt.xlabel("Force")
+    plt.ylabel("Ratio (%)")
+    plt.xlim([-0.5, 10])
+    plt.legend(loc="upper right")
+    plt.tight_layout()
+
+    # Create an inset axis within the main plot
+    ax = plt.gca()
+    ax = inset_axes(ax, width="40%", height="40%", loc="upper center")
+    for dataname in dataname_list:
+        h = h_list[dataname]
+        ax.plot(
+            bins[:-1], h * (bins[1] - bins[0]) * 100, label=dataname.replace("-v2", "")
+        )
+        ax.fill_between(
+            bins[:-1], h * (bins[1] - bins[0]) * 100, alpha=0.5, label="_nolegend_"
+        )
+
+    plt.savefig(f"hist-force{suffix}.pdf")
+    plt.close()
+
+    histplot2(energy_list, "energy")
+    histplot2(force_list, "force")
+
+    histplot3(energy_list, "energy")
+    histplot3(force_list, "force")
diff --git a/hydragnn/models/PAINNStack.py b/hydragnn/models/PAINNStack.py
new file mode 100644
index 000000000..dd9f9ebf2
--- /dev/null
+++ b/hydragnn/models/PAINNStack.py
@@ -0,0 +1,311 @@
+##############################################################################
+# Copyright (c) 2024, Oak Ridge National Laboratory                          #
+# All rights reserved.                                                       #
+#                                                                            #
+# This file is part of HydraGNN and is distributed under a BSD 3-clause      #
+# license. For the licensing terms see the LICENSE file in the top-level     #
+# directory.                                                                 #
+#                                                                            #
+# SPDX-License-Identifier: BSD-3-Clause                                      #
+##############################################################################
+
+# Adapted from the following:
+# GitHub: https://github.com/nityasagarjena/PaiNN-model/blob/main/PaiNN/model.py
+# Paper: https://arxiv.org/pdf/2102.03150
+
+
+import torch
+from torch import nn
+from torch_geometric import nn as geom_nn
+from torch.utils.checkpoint import checkpoint
+
+from .Base import Base
+
+
+class PAINNStack(Base):
+    """
+    PaiNN-style equivariant stack: alternating scalar/vector message and
+    update blocks built from pairwise distances, to/from edge indices, and
+    radial basis functions.
+    """
+
+    def __init__(
+        self,
+        # edge_dim: int,  # To-Do: Add edge_features
+        num_radial: int,
+        radius: float,
+        *args,
+        **kwargs
+    ):
+        # self.edge_dim = edge_dim
+        self.num_radial = num_radial
+        self.radius = radius
+
+        super().__init__(*args, **kwargs)
+
+    def _init_conv(self):
+        last_layer = 1 == self.num_conv_layers
+        self.graph_convs.append(
+            self.get_conv(self.input_dim, self.hidden_dim, last_layer)
+        )
+        self.feature_layers.append(nn.Identity())
+        for i in range(self.num_conv_layers - 1):
+            last_layer = i == self.num_conv_layers - 2
+            conv = self.get_conv(self.hidden_dim, self.hidden_dim, last_layer)
+            self.graph_convs.append(conv)
+            self.feature_layers.append(nn.Identity())
+
+    def get_conv(self, input_dim, output_dim, last_layer=False):
+        hidden_dim = output_dim if input_dim == 1 else input_dim
+        assert (
+            hidden_dim > 1
+        ), "PAINNStack requires more than one hidden dimension between input_dim and output_dim."
+        self_inter = PainnMessage(
+            node_size=input_dim, edge_size=self.num_radial, cutoff=self.radius
+        )
+        cross_inter = PainnUpdate(node_size=input_dim, last_layer=last_layer)
+        # The linear layers below size the embeddings so that HydraGNN's stacked
+        # conv layers can use hidden_dim and output_dim correctly, because
+        # node_scalar and node_vector are updated through a sum.
+        node_embed_out = nn.Sequential(
+            nn.Linear(input_dim, output_dim),
+            nn.Tanh(),
+            nn.Linear(output_dim, output_dim),
+        )  # Tanh activation is necessary to prevent exploding gradients when learning from random signals in test_graphs.py
+        vec_embed_out = nn.Linear(input_dim, output_dim) if not last_layer else None
+
+        if not last_layer:
+            return geom_nn.Sequential(
+                "x, v, pos, edge_index, diff, dist",
+                [
+                    (self_inter, "x, v, edge_index, diff, dist -> x, v"),
+                    (cross_inter, "x, v -> x, v"),
+                    (node_embed_out, "x -> x"),
+                    (vec_embed_out, "v -> v"),
+                    (lambda x, v, pos: [x, v, pos], "x, v, pos -> x, v, pos"),
+                ],
+            )
+        else:
+            return geom_nn.Sequential(
+                "x, v, pos, edge_index, diff, dist",
+                [
+                    (self_inter, "x, v, edge_index, diff, dist -> x, v"),
+                    (
+                        cross_inter,
+                        "x, v -> x",
+                    ),  # v is not updated in the last layer to avoid hanging gradients
+                    (
+                        node_embed_out,
+                        "x -> x",
+                    ),  # No need to embed down v because it is not used anymore
+                    (lambda x, v, pos: [x, v, pos], "x, v, pos -> x, v, pos"),
+                ],
+            )
+
+    def forward(self, data):
+        # _conv_args also attaches data.v, the vector features the PAINN stack needs
+        data, conv_args = self._conv_args(data)
+        x = data.x
+        v = data.v
+        pos = data.pos
+
+        ### encoder part ####
+        for conv, feat_layer in zip(self.graph_convs, self.feature_layers):
+            if not self.conv_checkpointing:
+                c, v, pos = conv(x=x, v=v, pos=pos, **conv_args)
+            else:
+                c, v, pos = checkpoint(
+                    conv, use_reentrant=False, x=x, v=v, pos=pos, **conv_args
+                )
+            x = self.activation_function(feat_layer(c))
+
+        #### multi-head decoder part ####
+        # shared dense layers for graph level output
+        if data.batch is None:
+            x_graph = x.mean(dim=0, keepdim=True)
+        else:
+            x_graph = geom_nn.global_mean_pool(x, data.batch.to(x.device))
+        outputs = []
+        outputs_var = []
+        for head_dim, headloc, type_head in zip(
+            self.head_dims, self.heads_NN, self.head_type
+        ):
+            if type_head == "graph":
+                x_graph_head = self.graph_shared(x_graph)
+                output_head = headloc(x_graph_head)
+                outputs.append(output_head[:, :head_dim])
+                outputs_var.append(output_head[:, head_dim:] ** 2)
+            else:
+                if self.node_NN_type == "conv":
+                    for conv, batch_norm in zip(headloc[0::2], headloc[1::2]):
+                        c, v, pos = conv(x=x, v=v, pos=pos, **conv_args)
+                        c = batch_norm(c)
+                        x = self.activation_function(c)
+                    x_node = x
+                else:
+                    x_node = headloc(x=x, batch=data.batch)
+                outputs.append(x_node[:, :head_dim])
+                outputs_var.append(x_node[:, head_dim:] ** 2)
+        if self.var_output:
+            return outputs, outputs_var
+        return outputs
+
+    def _conv_args(self, data):
+        assert (
+            data.pos is not None
+        ), "PAINNStack requires node positions (data.pos) to be set."
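+# Orientation note (editorial summary, not from the adapted source): each PAINN
+# layer alternates the two blocks below. PainnMessage mixes neighbor scalars
+# and vectors, gated by filters built from a sinc radial basis scaled by a
+# cosine cutoff; PainnUpdate then mixes each node's own scalar and vector
+# channels through the update_U/update_V projections (Schuett et al.,
+# arXiv:2102.03150).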
+
+        # Calculate relative vectors and distances
+        i, j = data.edge_index[0], data.edge_index[1]
+        diff = data.pos[i] - data.pos[j]
+        dist = diff.pow(2).sum(dim=-1).sqrt()
+        norm_diff = diff / dist.unsqueeze(-1)
+
+        # Instantiate tensor to hold equivariant traits
+        v = torch.zeros(data.x.size(0), 3, data.x.size(1), device=data.x.device)
+        data.v = v
+
+        conv_args = {
+            "edge_index": data.edge_index.t().to(torch.long),
+            "diff": norm_diff,
+            "dist": dist,
+        }
+
+        return data, conv_args
+
+
+class PainnMessage(nn.Module):
+    """Message function"""
+
+    def __init__(self, node_size: int, edge_size: int, cutoff: float):
+        super().__init__()
+
+        self.node_size = node_size
+        self.edge_size = edge_size
+        self.cutoff = cutoff
+
+        self.scalar_message_mlp = nn.Sequential(
+            nn.Linear(node_size, node_size),
+            nn.SiLU(),
+            nn.Linear(node_size, node_size * 3),
+        )
+
+        self.filter_layer = nn.Linear(edge_size, node_size * 3)
+
+    def forward(self, node_scalar, node_vector, edge, edge_diff, edge_dist):
+        # remember to use v_j, s_j but not v_i, s_i
+        filter_weight = self.filter_layer(
+            sinc_expansion(edge_dist, self.edge_size, self.cutoff)
+        )
+        filter_weight = filter_weight * cosine_cutoff(edge_dist, self.cutoff).unsqueeze(
+            -1
+        )
+        scalar_out = self.scalar_message_mlp(node_scalar)
+        filter_out = filter_weight * scalar_out[edge[:, 1]]
+
+        gate_state_vector, gate_edge_vector, message_scalar = torch.split(
+            filter_out,
+            self.node_size,
+            dim=1,
+        )
+
+        # num_pairs * 3 * node_size, num_pairs * node_size
+        message_vector = node_vector[edge[:, 1]] * gate_state_vector.unsqueeze(1)
+        edge_vector = gate_edge_vector.unsqueeze(1) * (
+            edge_diff / edge_dist.unsqueeze(-1)
+        ).unsqueeze(-1)
+        message_vector = message_vector + edge_vector
+
+        # sum message
+        residual_scalar = torch.zeros_like(node_scalar)
+        residual_vector = torch.zeros_like(node_vector)
+        residual_scalar.index_add_(0, edge[:, 0], message_scalar)
+        residual_vector.index_add_(0, edge[:, 0], message_vector)
+
+        # new node state
+        new_node_scalar = node_scalar + residual_scalar
+        new_node_vector = node_vector + residual_vector
+
+        return new_node_scalar, new_node_vector
+
+
+class PainnUpdate(nn.Module):
+    """Update function"""
+
+    def __init__(self, node_size: int, last_layer=False):
+        super().__init__()
+
+        self.update_U = nn.Linear(node_size, node_size)
+        self.update_V = nn.Linear(node_size, node_size)
+        self.last_layer = last_layer
+
+        if not self.last_layer:
+            self.update_mlp = nn.Sequential(
+                nn.Linear(node_size * 2, node_size),
+                nn.SiLU(),
+                nn.Linear(node_size, node_size * 3),
+            )
+        else:
+            self.update_mlp = nn.Sequential(
+                nn.Linear(node_size * 2, node_size),
+                nn.SiLU(),
+                nn.Linear(node_size, node_size * 2),
+            )
+
+    def forward(self, node_scalar, node_vector):
+        Uv = self.update_U(node_vector)
+        Vv = self.update_V(node_vector)
+
+        Vv_norm = torch.linalg.norm(Vv, dim=1)
+        mlp_input = torch.cat((Vv_norm, node_scalar), dim=1)
+        mlp_output = self.update_mlp(mlp_input)
+
+        if not self.last_layer:
+            a_vv, a_sv, a_ss = torch.split(
+                mlp_output,
+                node_vector.shape[-1],
+                dim=1,
+            )
+
+            delta_v = a_vv.unsqueeze(1) * Uv
+            inner_prod = torch.sum(Uv * Vv, dim=1)
+            delta_s = a_sv * inner_prod + a_ss
+
+            return node_scalar + delta_s, node_vector + delta_v
+        else:
+            a_sv, a_ss = torch.split(
+                mlp_output,
+                node_vector.shape[-1],
+                dim=1,
+            )
+
+            inner_prod = torch.sum(Uv * Vv, dim=1)
+            delta_s = a_sv * inner_prod + a_ss
+
+            return node_scalar + delta_s
+
+
+def sinc_expansion(edge_dist: torch.Tensor, edge_size: int, cutoff: float):
+    """
+    Calculate sinc radial basis function:
+
+    sin(n * pi * d / d_cut) / d
+    """
+    n = torch.arange(edge_size, device=edge_dist.device) + 1
+    return torch.sin(
+        edge_dist.unsqueeze(-1) * n * torch.pi / cutoff
+    ) / edge_dist.unsqueeze(-1)
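+
+# Shape sketch (illustrative comment, not from the adapted source): for
+# edge_dist of shape (num_edges,), sinc_expansion returns a tensor of shape
+# (num_edges, edge_size), one feature per n = 1..edge_size; e.g.
+# sinc_expansion(torch.ones(4), 20, 5.0) yields a (4, 20) tensor.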
+
+
+def cosine_cutoff(edge_dist: torch.Tensor, cutoff: float):
+    """
+    Calculate cutoff value based on distance.
+    This uses the cosine Behler-Parrinello cutoff function:
+
+    f(d) = 0.5 * (cos(pi * d / d_cut) + 1) for d < d_cut and 0 otherwise
+    """
+    return torch.where(
+        edge_dist < cutoff,
+        0.5 * (torch.cos(torch.pi * edge_dist / cutoff) + 1),
+        torch.tensor(0.0, device=edge_dist.device, dtype=edge_dist.dtype),
+    )
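+
+# Endpoint check (illustrative comment): f(0) = 0.5 * (cos(0) + 1) = 1 and
+# f(d) = 0 for d >= d_cut, so the message filters fade smoothly to zero at
+# the interaction radius.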
diff --git a/hydragnn/models/create.py b/hydragnn/models/create.py
index 08d21b77e..086bd0692 100644
--- a/hydragnn/models/create.py
+++ b/hydragnn/models/create.py
@@ -25,6 +25,7 @@
 from hydragnn.models.DIMEStack import DIMEStack
 from hydragnn.models.EGCLStack import EGCLStack
 from hydragnn.models.PNAEqStack import PNAEqStack
+from hydragnn.models.PAINNStack import PAINNStack
 from hydragnn.models.MACEStack import MACEStack
 
 from hydragnn.utils.distributed import get_device
@@ -345,6 +346,25 @@
             num_nodes=num_nodes,
         )
 
+    elif model_type == "PAINN":
+        model = PAINNStack(
+            # edge_dim,  # To-do: add edge_features
+            num_radial,
+            radius,
+            input_dim,
+            hidden_dim,
+            output_dim,
+            output_type,
+            output_heads,
+            activation_function,
+            loss_function_type,
+            equivariance,
+            loss_weights=task_weights,
+            freeze_conv=freeze_conv,
+            num_conv_layers=num_conv_layers,
+            num_nodes=num_nodes,
+        )
+
     elif model_type == "PNAEq":
         assert pna_deg is not None, "PNAEq requires degree input."
         model = PNAEqStack(
diff --git a/hydragnn/utils/input_config_parsing/config_utils.py b/hydragnn/utils/input_config_parsing/config_utils.py
index 1be58afeb..a165ee7b7 100644
--- a/hydragnn/utils/input_config_parsing/config_utils.py
+++ b/hydragnn/utils/input_config_parsing/config_utils.py
@@ -136,7 +136,7 @@ def update_config(config, train_loader, val_loader, test_loader):
 
 
 def update_config_equivariance(config):
-    equivariant_models = ["EGNN", "SchNet", "PNAEq", "MACE"]
+    equivariant_models = ["EGNN", "SchNet", "PNAEq", "PAINN", "MACE"]
     if "equivariance" in config and config["equivariance"]:
         assert (
             config["model_type"] in equivariant_models
diff --git a/tests/test_forces_equivariant.py b/tests/test_forces_equivariant.py
index cd0bc8364..d6df7e20d 100644
--- a/tests/test_forces_equivariant.py
+++ b/tests/test_forces_equivariant.py
@@ -17,7 +17,7 @@
 @pytest.mark.parametrize("example", ["LennardJones"])
 @pytest.mark.parametrize(
-    "model_type", ["SchNet", "EGNN", "DimeNet", "PNAPlus", "PNAEq", "MACE"]
+    "model_type", ["SchNet", "EGNN", "DimeNet", "PAINN", "PNAPlus", "PNAEq", "MACE"]
 )
 @pytest.mark.mpi_skip()
 def pytest_examples(example, model_type):
diff --git a/tests/test_graphs.py b/tests/test_graphs.py
index 2aec7fc2d..177cd11c8 100755
--- a/tests/test_graphs.py
+++ b/tests/test_graphs.py
@@ -148,6 +148,7 @@ def unittest_train_model(
         "DimeNet": [0.50, 0.50],
         "EGNN": [0.20, 0.20],
         "PNAEq": [0.60, 0.60],
+        "PAINN": [0.60, 0.60],
         "MACE": [0.60, 0.70],
     }
     if use_lengths and ("vector" not in ci_input):
@@ -209,6 +210,7 @@
         "DimeNet",
         "EGNN",
         "PNAEq",
+        "PAINN",
         "MACE",
     ],
 )
@@ -226,7 +228,7 @@ def pytest_train_model_lengths(model_type, overwrite_data=False):
 
 
 # Test across equivariant models
-@pytest.mark.parametrize("model_type", ["EGNN", "SchNet", "PNAEq", "MACE"])
+@pytest.mark.parametrize("model_type", ["EGNN", "SchNet", "PNAEq", "PAINN", "MACE"])
 def pytest_train_equivariant_model(model_type, overwrite_data=False):
     unittest_train_model(model_type, "ci_equivariant.json", False, overwrite_data)
 
@@ -250,6 +252,7 @@
         "DimeNet",
         "EGNN",
         "PNAEq",
+        "PAINN",
     ],
 )
 def pytest_train_model_conv_head(model_type, overwrite_data=False):
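
Usage sketch: with these changes, PAINN is selected like any other HydraGNN architecture through the `model_type` field. The key names below are assumptions that mirror the arguments `create_model` forwards to `PAINNStack`; check the repo's example JSON configs for the authoritative layout.

    # Hypothetical minimal architecture fragment for a HydraGNN config
    architecture = {
        "model_type": "PAINN",  # routes create_model to PAINNStack
        "num_radial": 5,        # size of the sinc radial basis (edge_size)
        "radius": 5.0,          # interaction cutoff used by cosine_cutoff
    }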