diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d148ee2bb..24d88444a 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -40,7 +40,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install --upgrade -r requirements.txt -r requirements-dev.txt - python -m pip install --upgrade -r requirements-torch.txt --index-url https://download.pytorch.org/whl/cpu + python -m pip install --upgrade -r requirements-torch.txt --index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pypi.org/simple python -m pip install --upgrade -r requirements-pyg.txt --find-links https://data.pyg.org/whl/torch-2.0.1+cpu.html python -m pip install --upgrade -r requirements-deepspeed.txt - name: Format black diff --git a/examples/LennardJones/LJ.json b/examples/LennardJones/LJ.json index a6b18f12b..942052003 100644 --- a/examples/LennardJones/LJ.json +++ b/examples/LennardJones/LJ.json @@ -30,6 +30,8 @@ "num_before_skip": 1, "num_after_skip": 1, "envelope_exponent": 5, + "max_ell": 1, + "node_max_ell": 1, "num_radial": 5, "num_spherical": 2, "hidden_dim": 20, diff --git a/hydragnn/models/MACEStack.py b/hydragnn/models/MACEStack.py new file mode 100644 index 000000000..d61696a6f --- /dev/null +++ b/hydragnn/models/MACEStack.py @@ -0,0 +1,741 @@ +############################################################################## +# Copyright (c) 2024, Oak Ridge National Laboratory                          # +# All rights reserved.                                                       # +#                                                                            # +# This file is part of HydraGNN and is distributed under a BSD 3-clause     # +# license. For the licensing terms see the LICENSE file in the top-level    # +# directory.                                                                 # +#                                                                            # +# SPDX-License-Identifier: BSD-3-Clause                                      # +############################################################################## +# Adapted From: +# GitHub: https://github.com/ACEsuit/mace +# ArXiV: https://arxiv.org/pdf/2206.07697 +# Date: August 27, 2024 | 12:37 (EST) +########################################################################################### +# Implementation of MACE models and other models based on E(3)-Equivariant MPNNs +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +# NOTE MACE Architecture: +## There are two key ideas of MACE: +### (1) Message passing and interaction blocks are equivariant to the O(3) group and invariant to the T(3) group (translations). +### (2) Predictions are made in an n-body expansion, where n is the number of convolutional layers+1. This is done by creating +### multi-body interactions and then decoding them. Decoding happens before anything else with 1-body interactions; interaction layer 1 +### will decode 2-body interactions, layer 2 will decode 3-body interactions, and so on. So, for a 3-convolutional-layer model +### predicting energy, there are 4 outputs for energy: 1 before convolution + 3*(1 after each layer). These outputs are summed +### at the end (a short schematic of this pattern is sketched just below).
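As a rough schematic of this layer-wise readout-and-sum pattern (generic convs/decoders callables are assumed here for illustration; the actual logic lives in MACEStack.forward further down in this file):

# Sketch only: one readout before any convolution, one readout after each convolution,
# with per-head predictions accumulated by summation.
def n_body_predict(node_feats, convs, decoders):
    outputs = decoders[0](node_feats)  # 1-body readout (before message passing)
    for conv, decoder in zip(convs, decoders[1:]):
        node_feats = conv(node_feats)  # each layer raises the body order by one
        outputs = [out + pred for out, pred in zip(outputs, decoder(node_feats))]
    return outputs  # one summed prediction per output head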
This requires some adjustment to the behavior from Base.py + + # from typing import Any, Callable, Dict, List, Optional, Type, Union + import warnings + + # Torch + import torch + from torch.nn import ModuleList, Sequential + from torch.utils.checkpoint import checkpoint + from torch_scatter import scatter + + # Torch Geo + from torch_geometric.nn import ( + Sequential as PyGSequential, + ) # This naming is because there is torch.nn.Sequential and torch_geometric.nn.Sequential + from torch_geometric.nn import global_mean_pool + + # MACE + from hydragnn.utils.model.mace_utils.modules.blocks import ( + EquivariantProductBasisBlock, + LinearNodeEmbeddingBlock, + RadialEmbeddingBlock, + RealAgnosticAttResidualInteractionBlock, + ) + from hydragnn.utils.model.operations import ( + get_edge_vectors_and_lengths, + ) + + # E3NN + from e3nn import nn, o3 + from e3nn.util.jit import compile_mode + + + # HydraGNN + from .Base import Base + + # Etc + import numpy as np + import math + + + @compile_mode("script") + class MACEStack(Base): + def __init__( + self, + r_max: float, # The cutoff radius for the radial basis functions and edge_index + radial_type: str, # The type of radial basis function to use + distance_transform: str, # The distance transform to use + num_bessel: int, # The number of radial Bessel functions. This dictates the richness of radial information in message-passing. + edge_dim: int, # The dimension of HYDRA's optional edge attributes + max_ell: int, # Max l-type for CG-tensor product. Theoretically, there is no max l-type, but in practice, we need to truncate the CG-tensor product to keep the computation tractable + node_max_ell: int, # Max l-type for node features + avg_num_neighbors: float, + num_polynomial_cutoff, # The polynomial cutoff function ensures that the function goes to zero at the cutoff radius smoothly. Same as envelope_exponent for DimeNet + correlation, # Used in the product basis block and *roughly* determines the richness of interaction in the n-body interaction of layer 'n'. + *args, + **kwargs, + ): + """Notes On MACEStack Arguments:""" + # MACE args that we have given definitions for and the reasons why: + ## Note: These can be changed in the future if the desired argument options change + ## interaction_cls / interaction_cls_first: The choice of interaction block type should not make much of a difference and would require more imports in create.py and/or string handling + ## Atomic Energies: This is not agnostic to what we're predicting, which is a requirement of HYDRA. We also don't have base atomic energies to load, so we simply one-hot encode the atomic numbers and train. + ## Atomic Numbers / num_elements: It's more robust in preventing errors to just cover the entire periodic table (1-118) + + # MACE args that we have dropped and the reasons why: + ## pair repulsion, distance_transform, compute_virials, etc: HYDRA's framework is meant to compute based on graph or node type, so must be agnostic to these property-specific types of computations + + # MACE args constructed by HYDRA args + ## Reasoning: Oftentimes, MACE arguments show similarity to HYDRA arguments, but are labelled differently + ## num_interactions is represented by num_conv_layers + ## radial_MLP uses ceil(hidden_dim/3) for its layer sizes + ## hidden_irreps and MLP_irreps are constructed from hidden_dim + ## - Note that this is a nontrivial choice...
reconstructing irreps allows users to be unfamiliar with the e3nn library, and is more attached to the HYDRA framework, but limits customization slightly + ## - I use a hidden_max_ell argument to allow the user to set max ell in the hidden dimensions as well + """""" + + # Init Args + ## Passed + self.node_max_ell = node_max_ell + num_interactions = kwargs["num_conv_layers"] + self.edge_dim = edge_dim + self.avg_num_neighbors = avg_num_neighbors + ## Defined + self.interaction_cls = RealAgnosticAttResidualInteractionBlock + self.interaction_cls_first = RealAgnosticAttResidualInteractionBlock + atomic_numbers = list(range(1, 119)) # 118 elements in the periodic table + self.num_elements = len(atomic_numbers) + # Optional + num_polynomial_cutoff = ( + 5 if num_polynomial_cutoff is None else num_polynomial_cutoff + ) + self.correlation = [2] if correlation is None else correlation + radial_type = "bessel" if radial_type is None else radial_type + + # Making Irreps + self.sh_irreps = o3.Irreps.spherical_harmonics( + max_ell + ) # This makes the irreps string + self.edge_feats_irreps = o3.Irreps(f"{num_bessel}x0e") + + super().__init__(*args, **kwargs) + + self.spherical_harmonics = o3.SphericalHarmonics( + self.sh_irreps, + normalize=True, + normalization="component", # This makes the spherical harmonic class to be called with forward + ) + + # Register buffers are made when parameters need to be saved and transferred with the model, but not trained. + self.register_buffer( + "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) + ) + self.register_buffer( + "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) + ) + self.register_buffer( + "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) + ) + if isinstance(correlation, int): + self.correlation = [self.correlation] * self.num_interactions + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + radial_type=radial_type, + distance_transform=distance_transform, + ) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=self.node_attr_irreps, + irreps_out=create_irreps_string( + self.hidden_dim, 0 + ), # Changed this to hidden_dim because no longer had node_feats_irreps + ) + + def _init_conv(self): + # Multihead Decoders + ## This integrates HYDRA multihead nature with MACE's layer-wise readouts + ## NOTE Norm techniques (feature_layers in HYDRA) are not advised for use in equivariant models as it can break equivariance + self.multihead_decoders = ModuleList() + # attr_irreps for node and edges are created here because we need input_dim, which requires super(base) to be called, which calls _init_conv + self.node_attr_irreps = o3.Irreps([(self.num_elements, (0, 1))]) + # Edge Attributes are by default the spherical harmoncis but should be extended to include HYDRA's edge_attr is desired + if self.use_edge_attr: + self.edge_attrs_irreps = ( + o3.Irreps(f"{self.edge_dim}x0e") + self.sh_irreps + ).simplify() # Simplify combines irreps of the same type + else: + self.edge_attrs_irreps = self.sh_irreps + hidden_irreps = o3.Irreps( + create_irreps_string(self.hidden_dim, self.node_max_ell) + ) + final_hidden_irreps = o3.Irreps( + create_irreps_string(self.hidden_dim, 0) + ) # Only scalars are outputted in the last layer + + last_layer = 1 == self.num_conv_layers + + self.multihead_decoders.append( + MultiheadDecoderBlock( + self.node_attr_irreps, + self.node_max_ell, + self.config_heads, + self.head_dims, + self.head_type, + 
self.num_heads, + self.activation_function, + self.num_nodes, + nonlinear=True, + ) + ) # For base-node traits + self.graph_convs.append( + self.get_conv(self.input_dim, self.hidden_dim, first_layer=True) + ) + irreps = hidden_irreps if not last_layer else final_hidden_irreps + self.multihead_decoders.append( + MultiheadDecoderBlock( + irreps, + self.node_max_ell, + self.config_heads, + self.head_dims, + self.head_type, + self.num_heads, + self.activation_function, + self.num_nodes, + nonlinear=last_layer, + ) + ) + for i in range(self.num_conv_layers - 1): + last_layer = i == self.num_conv_layers - 2 + conv = self.get_conv( + self.hidden_dim, self.hidden_dim, last_layer=last_layer + ) + self.graph_convs.append(conv) + irreps = hidden_irreps if not last_layer else final_hidden_irreps + self.multihead_decoders.append( + MultiheadDecoderBlock( + irreps, + self.node_max_ell, + self.config_heads, + self.head_dims, + self.head_type, + self.num_heads, + self.activation_function, + self.num_nodes, + nonlinear=last_layer, + ) + ) # Last layer will be nonlinear node decoding + + def get_conv(self, input_dim, output_dim, first_layer=False, last_layer=False): + hidden_dim = output_dim if input_dim == 1 else input_dim + + # All of these should be constructed with HYDRA dimensional arguments + ## Radial + radial_MLP_dim = math.ceil( + float(hidden_dim) / 3 + ) # Go based off hidden_dim for radial_MLP + radial_MLP = [radial_MLP_dim, radial_MLP_dim, radial_MLP_dim] + ## Input, Hidden, and Output irreps sizing (this is usually just hidden in MACE) + ### Input dimensions are handled implicitly + ### Hidden + hidden_irreps = create_irreps_string(hidden_dim, self.node_max_ell) + hidden_irreps = o3.Irreps(hidden_irreps) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + num_features = hidden_irreps.count( + o3.Irrep(0, 1) + ) # Multiple copies of spherical harmonics for multiple interactions. They are 'combined' in a certain way during .simplify() ## This makes it a requirement that hidden irreps all have the same number of channels + interaction_irreps = ( + (self.sh_irreps * num_features) + .sort()[0] + .simplify() # Kept as sh_irreps for the output of reshape irreps, whether or not edge_attr irreps are added from HYDRA functionality + ) # .sort() is a tuple, so we need the [0] element for the sorted result + ### Output + output_irreps = create_irreps_string(output_dim, self.node_max_ell) + output_irreps = o3.Irreps(output_irreps) + + # Constructing convolutional layers + if first_layer: + hidden_irreps_out = hidden_irreps + inter = self.interaction_cls_first( + node_attrs_irreps=self.node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=self.edge_attrs_irreps, + edge_feats_irreps=self.edge_feats_irreps, + target_irreps=interaction_irreps, # Replace with output? 
+ hidden_irreps=hidden_irreps_out, + avg_num_neighbors=self.avg_num_neighbors, + radial_MLP=radial_MLP, + ) + # Use the appropriate self connection at the first layer for proper E0 + use_sc_first = False + if "Residual" in str(self.interaction_cls_first): + use_sc_first = True + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps, + correlation=self.correlation[0], + num_elements=self.num_elements, + use_sc=use_sc_first, + ) + sizing = o3.Linear( + hidden_irreps_out, output_irreps + ) # Change sizing to output_irreps + elif last_layer: + # Select only scalars output for last layer + hidden_irreps_out = str(hidden_irreps[0]) + output_irreps = str(output_irreps[0]) + inter = self.interaction_cls( + node_attrs_irreps=self.node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=self.edge_attrs_irreps, + edge_feats_irreps=self.edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=self.avg_num_neighbors, + radial_MLP=radial_MLP, + ) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=self.correlation[0], + num_elements=self.num_elements, + use_sc=True, + ) + sizing = o3.Linear( + hidden_irreps_out, output_irreps + ) # Change sizing to output_irreps + else: + hidden_irreps_out = hidden_irreps + inter = self.interaction_cls( + node_attrs_irreps=self.node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=self.edge_attrs_irreps, + edge_feats_irreps=self.edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=self.avg_num_neighbors, + radial_MLP=radial_MLP, + ) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=self.correlation[0], # Should this be i+1? 
+ num_elements=self.num_elements, + use_sc=True, + ) + sizing = o3.Linear( + hidden_irreps_out, output_irreps + ) # Change sizing to output_irreps + + input_args = "node_attributes, pos, node_features, edge_attributes, edge_features, edge_index" + conv_args = "node_attributes, edge_attributes, edge_features, edge_index" # node_features is not used here because it's passed through in the forward + + if not last_layer: + return PyGSequential( + input_args, + [ + (inter, "node_features, " + conv_args + " -> node_features, sc"), + (prod, "node_features, sc, node_attributes -> node_features"), + (sizing, "node_features -> node_features"), + ( + lambda node_features, pos: [node_features, pos], + "node_features, pos -> node_features, pos", + ), + ], + ) + else: + return PyGSequential( + input_args, + [ + (inter, "node_features, " + conv_args + " -> node_features, sc"), + (prod, "node_features, sc, node_attributes -> node_features"), + (sizing, "node_features -> node_features"), + ( + lambda node_features, pos: [node_features, pos], + "node_features, pos -> node_features, pos", + ), + ], + ) + + def forward(self, data): + data, conv_args = self._conv_args(data) + node_features = data.node_features + node_attributes = data.node_attributes + pos = data.pos + + ### encoder / decoder part #### + ## NOTE Norm techniques (feature_layers in HYDRA) are not advised for use in equivariant models as it can break equivariance + + ### There is a readout before the first convolution layer ### + outputs = [] + output = self.multihead_decoders[0]( + data, node_attributes + ) # [index][n_output, size_output] + # Create outputs first + outputs = output + + ### Do conv --> readout --> repeat for each convolution layer ### + for conv, readout in zip(self.graph_convs, self.multihead_decoders[1:]): + if not self.conv_checkpointing: + node_features, pos = conv( + node_features=node_features, pos=pos, **conv_args + ) + output = readout(data, node_features) # [index][n_output, size_output] + else: + node_features, pos = checkpoint( + conv, + use_reentrant=False, + node_features=node_features, + pos=pos, + **conv_args, + ) + output = readout( + data, node_features + ) # output is a list of tensors with [index][n_output, size_output] + # Sum predictions for each index, taking care of size differences + for idx, prediction in enumerate(output): + outputs[idx] = outputs[idx] + prediction + + return outputs + + def _conv_args(self, data): + assert ( + data.pos is not None + ), "MACE requires node positions (data.pos) to be set." + + # Center positions at 0 per graph. This is a requirement for equivariant models that + # initialize the spherical harmonics, since the initial spherical harmonic projection + # uses the nodal position vector x/||x|| as the input to the spherical harmonics. + # If we didn't center at 0, these models wouldn't even be invariant to translation. + mean_pos = scatter(data.pos, data.batch, dim=0, reduce="mean") + data.pos = data.pos - mean_pos[data.batch] + + # Create node_attrs from atomic numbers. Later on it may contain more information + ## Node attrs are intrinsic properties of the atoms. 
Currently, MACE only supports atomic number node attributes + ## data.node_attrs is already used in another place, so has been renamed to data.node_attributes from MACE and same with other data variable names + data.node_attributes = process_node_attributes(data["x"], self.num_elements) + data.shifts = torch.zeros( + (data.edge_index.shape[1], 3), dtype=data.pos.dtype, device=data.pos.device + ) # Shifts takes into account pbc conditions, but I believe we already generate data.pos to take it into account + + # Embeddings + node_feats = self.node_embedding(data["node_attributes"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["pos"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attributes = self.spherical_harmonics(vectors) + if self.use_edge_attr: + edge_attributes = torch.cat([data.edge_attr, edge_attributes], dim=1) + edge_features = self.radial_embedding( + lengths, data["node_attributes"], data["edge_index"], self.atomic_numbers + ) + + # Variable names + data.node_features = node_feats + data.edge_attributes = edge_attributes + data.edge_features = edge_features + + conv_args = { + "node_attributes": data.node_attributes, + "edge_attributes": data.edge_attributes, + "edge_features": data.edge_features, + "edge_index": data.edge_index, + } + + return data, conv_args + + def _multihead(self): + # NOTE Multihead is skipped as it's an integral part of MACE's architecture to have a decoder after every layer, + # and a convolutional layer in decoding is not supported. Therefore, this final step is not necessary for MACE. + # However, various parts of multihead are applied in the MultiheadLinearBlock and MultiheadNonLinearBlock classes. + pass + + def __str__(self): + return "MACEStack" + + +def create_irreps_string( + n: int, ell: int +): # Custom function to allow for use of HYDRA arguments in creating irreps + irreps = [f"{n}x{ell}{'e' if ell % 2 == 0 else 'o'}" for ell in range(ell + 1)] + return " + ".join(irreps) + + +def process_node_attributes(node_attributes, num_elements): + # Check that node attributes are atomic numbers and process accordingly + node_attributes = node_attributes.squeeze() # Squeeze all unnecessary dimensions + assert ( + node_attributes.dim() == 1 + ), "MACE only supports raw atomic numbers as node_attributes. Your data.x \ + isn't a 1D tensor after squeezing, are you using vector features?" + + # Check that all elements are integers or integer-like (e.g., 1.0, 2.0), not floats like 1.1 + # This is only a warning so that we don't enforce this requirement on the tests. + if not torch.all(node_attributes == node_attributes.round()): + warnings.warn( + "MACE only supports raw atomic numbers as node_attributes. Your data.x \ + contains floats, which does not align with atomic numbers." + ) + + # Check that all atomic numbers are within the valid range (1 to num_elements) + # This is only a warning so that we don't enforce this requirement on the tests. + if not torch.all((node_attributes >= 1) & (node_attributes <= num_elements)): + warnings.warn( + "MACE only supports raw atomic numbers as node_attributes. Your data.x \ + is not in the range 1-118, which does not align with atomic numbers." 
+ ) + node_attributes = torch.clamp(node_attributes, min=1, max=118) + + # Perform one-hot encoding + one_hot = torch.nn.functional.one_hot( + (node_attributes - 1).long(), + num_classes=num_elements, # Subtract 1 to make atomic numbers 0-indexed for one-hot encoding + ).float() # [n_atoms, 118] + + return one_hot + + +@compile_mode("script") +class MultiheadDecoderBlock(torch.nn.Module): + def __init__( + self, + input_irreps, + node_max_ell, + config_heads, + head_dims, + head_type, + num_heads, + activation_function, + num_nodes, + nonlinear=False, + ): + super(MultiheadDecoderBlock, self).__init__() + self.input_irreps = input_irreps + self.node_max_ell = node_max_ell if not nonlinear else 0 + self.config_heads = config_heads + self.head_dims = head_dims + self.head_type = head_type + self.num_heads = num_heads + self.activation_function = activation_function + self.num_nodes = num_nodes + + self.graph_shared = None + self.node_NN_type = None + self.heads = ModuleList() + + # Create shared dense layers for graph-level output if applicable + if "graph" in self.config_heads: + graph_input_irreps = o3.Irreps( + f"{self.input_irreps.count(o3.Irrep(0, 1))}x0e" + ) + dim_sharedlayers = self.config_heads["graph"]["dim_sharedlayers"] + sharedlayers_irreps = o3.Irreps(f"{dim_sharedlayers}x0e") + denselayers = [] + denselayers.append(o3.Linear(graph_input_irreps, sharedlayers_irreps)) + denselayers.append( + nn.Activation( + irreps_in=sharedlayers_irreps, acts=[self.activation_function] + ) + ) + for _ in range(self.config_heads["graph"]["num_sharedlayers"] - 1): + denselayers.append(o3.Linear(sharedlayers_irreps, sharedlayers_irreps)) + denselayers.append( + nn.Activation( + irreps_in=sharedlayers_irreps, acts=[self.activation_function] + ) + ) + self.graph_shared = Sequential(*denselayers) + + # Create layers for each head + for ihead in range(self.num_heads): + if self.head_type[ihead] == "graph": + num_layers_graph = self.config_heads["graph"]["num_headlayers"] + hidden_dim_graph = self.config_heads["graph"]["dim_headlayers"] + denselayers = [] + head_hidden_irreps = o3.Irreps(f"{hidden_dim_graph[0]}x0e") + denselayers.append(o3.Linear(sharedlayers_irreps, head_hidden_irreps)) + denselayers.append( + nn.Activation( + irreps_in=head_hidden_irreps, acts=[self.activation_function] + ) + ) + for ilayer in range(num_layers_graph - 1): + input_irreps = o3.Irreps(f"{hidden_dim_graph[ilayer]}x0e") + output_irreps = o3.Irreps(f"{hidden_dim_graph[ilayer + 1]}x0e") + denselayers.append(o3.Linear(input_irreps, output_irreps)) + denselayers.append( + nn.Activation( + irreps_in=output_irreps, acts=[self.activation_function] + ) + ) + input_irreps = o3.Irreps(f"{hidden_dim_graph[-1]}x0e") + output_irreps = o3.Irreps(f"{self.head_dims[ihead]}x0e") + denselayers.append(o3.Linear(input_irreps, output_irreps)) + self.heads.append(Sequential(*denselayers)) + elif self.head_type[ihead] == "node": + self.node_NN_type = self.config_heads["node"]["type"] + head = ModuleList() + if self.node_NN_type == "mlp" or self.node_NN_type == "mlp_per_node": + self.num_mlp = 1 if self.node_NN_type == "mlp" else self.num_nodes + assert ( + self.num_nodes is not None + ), "num_nodes must be a positive integer for MLP" + num_layers_node = self.config_heads["node"]["num_headlayers"] + hidden_dim_node = self.config_heads["node"]["dim_headlayers"] + head = MLPNode( + self.input_irreps, + self.node_max_ell, + self.config_heads, + num_layers_node, + hidden_dim_node, + self.head_dims[ihead], + self.num_mlp, + self.num_nodes, + 
self.config_heads["node"]["type"], + self.activation_function, + nonlinear=nonlinear, + ) + self.heads.append(head) + else: + raise ValueError( + f"Unknown head NN structure for node features: {self.node_NN_type}" + ) + else: + raise ValueError( + f"Unknown head type: {self.head_type[ihead]}; supported types are 'graph' or 'node'" + ) + + def forward(self, data, node_features): + if data.batch is None: + graph_features = node_features[:, : self.hidden_dim].mean( + dim=0, keepdim=True + ) # Need to take only the type-0 irreps for aggregation + else: + graph_features = global_mean_pool( + node_features[:, : self.input_irreps.count(o3.Irrep(0, 1))], + data.batch.to(node_features.device), + ) + outputs = [] + for headloc, type_head in zip(self.heads, self.head_type): + if type_head == "graph": + x_graph_head = self.graph_shared(graph_features) + outputs.append(headloc(x_graph_head)) + else: # Node-level output + if self.node_NN_type == "conv": + raise ValueError( + "Node-level convolutional layers are not supported in MACE" + ) + else: + x_node = headloc(node_features, data.batch) + outputs.append(x_node) + return outputs + + +@compile_mode("script") +class MLPNode(torch.nn.Module): + def __init__( + self, + input_irreps, + node_max_ell, + config_heads, + num_layers, + hidden_dims, + output_dim, + num_mlp, + num_nodes, + node_type, + activation_function, + nonlinear=False, + ): + super().__init__() + self.input_irreps = input_irreps + self.hidden_dims = hidden_dims + self.output_dim = output_dim + self.node_max_ell = node_max_ell if not nonlinear else 0 + self.config_heads = config_heads + self.num_layers = num_layers + self.node_type = node_type + self.num_mlp = num_mlp + self.num_nodes = num_nodes + self.activation_function = activation_function + + self.mlp = ModuleList() + + # Create dense layers for each MLP based on node_type ("mlp" or "mlp_per_node") + for _ in range(self.num_mlp): + denselayers = [] + + # Input and hidden irreps for each MLP layer + input_irreps = input_irreps + hidden_irreps = o3.Irreps(f"{hidden_dims[0]}x0e") # Hidden irreps + + denselayers.append(o3.Linear(input_irreps, hidden_irreps)) + denselayers.append( + nn.Activation(irreps_in=hidden_irreps, acts=[self.activation_function]) + ) + + # Add intermediate layers + for ilayer in range(self.num_layers - 1): + input_irreps = o3.Irreps(f"{hidden_dims[ilayer]}x0e") + hidden_irreps = o3.Irreps(f"{hidden_dims[ilayer + 1]}x0e") + denselayers.append(o3.Linear(input_irreps, hidden_irreps)) + denselayers.append( + nn.Activation( + irreps_in=hidden_irreps, acts=[self.activation_function] + ) + ) + + # Last layer + hidden_irreps = o3.Irreps(f"{hidden_dims[-1]}x0e") + output_irreps = o3.Irreps( + f"{self.output_dim}x0e" + ) # Assuming head_dims has been passed for the final output + denselayers.append(o3.Linear(hidden_irreps, output_irreps)) + + # Append to MLP + self.mlp.append(Sequential(*denselayers)) + + def node_features_reshape(self, node_features, batch): + """Reshape node_features from [batch_size*num_nodes, num_features] to [batch_size, num_features, num_nodes]""" + num_features = node_features.shape[1] + batch_size = batch.max() + 1 + out = torch.zeros( + (batch_size, num_features, self.num_nodes), + dtype=node_features.dtype, + device=node_features.device, + ) + for inode in range(self.num_nodes): + inode_index = [i for i in range(inode, batch.shape[0], self.num_nodes)] + out[:, :, inode] = node_features[inode_index, :] + return out + + def forward(self, node_features: torch.Tensor, batch: torch.Tensor): + if 
self.node_type == "mlp": + outs = self.mlp[0](node_features) + else: + outs = torch.zeros( + ( + node_features.shape[0], + self.head_dims[0], + ), # Assuming `head_dims` defines the final output dimension + dtype=node_features.dtype, + device=node_features.device, + ) + x_nodes = self.node_features_reshape(x, batch) + for inode in range(self.num_nodes): + inode_index = [i for i in range(inode, batch.shape[0], self.num_nodes)] + outs[inode_index, :] = self.mlp[inode](x_nodes[:, :, inode]) + return outs + + def __str__(self): + return "MLPNode" diff --git a/hydragnn/models/create.py b/hydragnn/models/create.py index b3fcbb1ea..086bd0692 100644 --- a/hydragnn/models/create.py +++ b/hydragnn/models/create.py @@ -12,6 +12,7 @@ import os import torch from torch_geometric.data import Data +from typing import List, Union from hydragnn.models.GINStack import GINStack from hydragnn.models.PNAStack import PNAStack @@ -25,6 +26,7 @@ from hydragnn.models.EGCLStack import EGCLStack from hydragnn.models.PNAEqStack import PNAEqStack from hydragnn.models.PAINNStack import PAINNStack +from hydragnn.models.MACEStack import MACEStack from hydragnn.utils.distributed import get_device from hydragnn.utils.profiling_and_tracing.time_utils import Timer @@ -55,6 +57,8 @@ def create_model_config( config["Architecture"]["num_before_skip"], config["Architecture"]["num_after_skip"], config["Architecture"]["num_radial"], + config["Architecture"]["radial_type"], + config["Architecture"]["distance_transform"], config["Architecture"]["basis_emb_size"], config["Architecture"]["int_emb_size"], config["Architecture"]["out_emb_size"], @@ -64,6 +68,10 @@ def create_model_config( config["Architecture"]["num_filters"], config["Architecture"]["radius"], config["Architecture"]["equivariance"], + config["Architecture"]["correlation"], + config["Architecture"]["max_ell"], + config["Architecture"]["node_max_ell"], + config["Architecture"]["avg_num_neighbors"], config["Training"]["conv_checkpointing"], verbosity, use_gpu, @@ -91,6 +99,8 @@ def create_model( num_before_skip: int = None, num_after_skip: int = None, num_radial: int = None, + radial_type: str = None, + distance_transform: str = None, basis_emb_size: int = None, int_emb_size: int = None, out_emb_size: int = None, @@ -100,6 +110,10 @@ def create_model( num_filters: int = None, radius: float = None, equivariance: bool = False, + correlation: Union[int, List[int]] = None, + max_ell: int = None, + node_max_ell: int = None, + avg_num_neighbors: int = None, conv_checkpointing: bool = False, verbosity: int = 0, use_gpu: bool = True, @@ -371,6 +385,39 @@ def create_model( num_conv_layers=num_conv_layers, num_nodes=num_nodes, ) + + elif model_type == "MACE": + assert radius is not None, "MACE requires radius input." + assert num_radial is not None, "MACE requires num_radial input." + assert max_ell is not None, "MACE requires max_ell input." + assert node_max_ell is not None, "MACE requires node_max_ell input." + assert max_ell >= 1, "MACE requires max_ell >= 1." + assert node_max_ell >= 1, "MACE requires node_max_ell >= 1." 
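For orientation, a small self-contained illustration of what max_ell and node_max_ell translate into in e3nn terms, using the values from the LJ.json example above (this snippet only assumes e3nn is installed and is not part of create.py):

from e3nn import o3

# max_ell bounds the spherical-harmonic order of the edge attributes:
sh_irreps = o3.Irreps.spherical_harmonics(1)  # 1x0e+1x1o, dimension 1 + 3 = 4
# node_max_ell together with hidden_dim shapes the hidden node features; with
# hidden_dim=20 and node_max_ell=1 this matches create_irreps_string(20, 1) in MACEStack.py:
hidden_irreps = o3.Irreps("20x0e + 20x1o")  # dimension 20 + 3 * 20 = 80
print(sh_irreps.dim, hidden_irreps.dim)  # 4 80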
+ model = MACEStack( + radius, + radial_type, + distance_transform, + num_radial, + edge_dim, + max_ell, + node_max_ell, + avg_num_neighbors, + envelope_exponent, + correlation, + input_dim, + hidden_dim, + output_dim, + output_type, + output_heads, + activation_function, + loss_function_type, + equivariance, + loss_weights=task_weights, + freeze_conv=freeze_conv, + initial_bias=initial_bias, + num_conv_layers=num_conv_layers, + num_nodes=num_nodes, + ) else: raise ValueError("Unknown model_type: {0}".format(model_type)) diff --git a/hydragnn/utils/input_config_parsing/config_utils.py b/hydragnn/utils/input_config_parsing/config_utils.py index 2ac7bf392..a165ee7b7 100644 --- a/hydragnn/utils/input_config_parsing/config_utils.py +++ b/hydragnn/utils/input_config_parsing/config_utils.py @@ -14,6 +14,7 @@ check_if_graph_size_variable, gather_deg, ) +from hydragnn.utils.model.model import calculate_avg_deg from hydragnn.utils.distributed import get_comm_size_and_rank from copy import deepcopy import json @@ -56,8 +57,22 @@ def update_config(config, train_loader, val_loader, test_loader): else: config["NeuralNetwork"]["Architecture"]["pna_deg"] = None + if config["NeuralNetwork"]["Architecture"]["model_type"] == "MACE": + if hasattr(train_loader.dataset, "avg_num_neighbors"): + ## Use avg neighbours used in the dataset. + avg_num_neighbors = torch.tensor(train_loader.dataset.avg_num_neighbors) + else: + avg_num_neighbors = float(calculate_avg_deg(train_loader.dataset)) + config["NeuralNetwork"]["Architecture"]["avg_num_neighbors"] = avg_num_neighbors + else: + config["NeuralNetwork"]["Architecture"]["avg_num_neighbors"] = None + if "radius" not in config["NeuralNetwork"]["Architecture"]: config["NeuralNetwork"]["Architecture"]["radius"] = None + if "radial_type" not in config["NeuralNetwork"]["Architecture"]: + config["NeuralNetwork"]["Architecture"]["radial_type"] = None + if "distance_transform" not in config["NeuralNetwork"]["Architecture"]: + config["NeuralNetwork"]["Architecture"]["distance_transform"] = None if "num_gaussians" not in config["NeuralNetwork"]["Architecture"]: config["NeuralNetwork"]["Architecture"]["num_gaussians"] = None if "num_filters" not in config["NeuralNetwork"]["Architecture"]: @@ -78,6 +93,14 @@ def update_config(config, train_loader, val_loader, test_loader): config["NeuralNetwork"]["Architecture"]["num_radial"] = None if "num_spherical" not in config["NeuralNetwork"]["Architecture"]: config["NeuralNetwork"]["Architecture"]["num_spherical"] = None + if "radial_type" not in config["NeuralNetwork"]["Architecture"]: + config["NeuralNetwork"]["Architecture"]["radial_type"] = None + if "correlation" not in config["NeuralNetwork"]["Architecture"]: + config["NeuralNetwork"]["Architecture"]["correlation"] = None + if "max_ell" not in config["NeuralNetwork"]["Architecture"]: + config["NeuralNetwork"]["Architecture"]["max_ell"] = None + if "node_max_ell" not in config["NeuralNetwork"]["Architecture"]: + config["NeuralNetwork"]["Architecture"]["node_max_ell"] = None config["NeuralNetwork"]["Architecture"] = update_config_edge_dim( config["NeuralNetwork"]["Architecture"] @@ -113,11 +136,11 @@ def update_config(config, train_loader, val_loader, test_loader): def update_config_equivariance(config): - equivariant_models = ["EGNN", "SchNet", "PNAEq", "PAINN"] + equivariant_models = ["EGNN", "SchNet", "PNAEq", "PAINN", "MACE"] if "equivariance" in config and config["equivariance"]: assert ( config["model_type"] in equivariant_models - ), "E(3) equivariance can only be ensured for 
EGNN and SchNet." + ), "E(3) equivariance can only be ensured for EGNN, SchNet, and MACE." elif "equivariance" not in config: config["equivariance"] = False return config @@ -125,11 +148,11 @@ def update_config_equivariance(config): def update_config_edge_dim(config): config["edge_dim"] = None - edge_models = ["PNAPlus", "PNA", "CGCNN", "SchNet", "EGNN", "DimeNet"] + edge_models = ["PNAPlus", "PNA", "CGCNN", "SchNet", "EGNN", "DimeNet", "MACE"] if "edge_features" in config and config["edge_features"]: assert ( config["model_type"] in edge_models - ), "Edge features can only be used with DimeNet EGNN, SchNet, PNA, PNAPlus, and CGCNN." + ), "Edge features can only be used with DimeNet, MACE, EGNN, SchNet, PNA, PNAPlus, and CGCNN." config["edge_dim"] = len(config["edge_features"]) elif config["model_type"] == "CGCNN": # CG always needs an integer edge_dim diff --git a/hydragnn/utils/model/irreps_tools.py b/hydragnn/utils/model/irreps_tools.py new file mode 100644 index 000000000..71d74e6ff --- /dev/null +++ b/hydragnn/utils/model/irreps_tools.py @@ -0,0 +1,102 @@ +########################################################################################### +# Elementary tools for handling irreducible representations +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from typing import List, Tuple + +import torch +from e3nn import o3 +from e3nn.util.jit import compile_mode + + +# Based on mir-group/nequip +def tp_out_irreps_with_instructions( + irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps +) -> Tuple[o3.Irreps, List]: + trainable = True + + # Collect possible irreps and their instructions + irreps_out_list: List[Tuple[int, o3.Irreps]] = [] + instructions = [] + for i, (mul, ir_in) in enumerate(irreps1): + for j, (_, ir_edge) in enumerate(irreps2): + for ir_out in ir_in * ir_edge: # | l1 - l2 | <= l <= l1 + l2 + if ir_out in target_irreps: + k = len(irreps_out_list) # instruction index + irreps_out_list.append((mul, ir_out)) + instructions.append((i, j, k, "uvu", trainable)) + + # We sort the output irreps of the tensor product so that we can simplify them + # when they are provided to the second o3.Linear + irreps_out = o3.Irreps(irreps_out_list) + irreps_out, permut, _ = irreps_out.sort() + + # Permute the output indexes of the instructions to match the sorted irreps: + instructions = [ + (i_in1, i_in2, permut[i_out], mode, train) + for i_in1, i_in2, i_out, mode, train in instructions + ] + + instructions = sorted(instructions, key=lambda x: x[2]) + + return irreps_out, instructions + + +def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps: + # Assuming simplified irreps + irreps_mid = [] + for _, ir_in in irreps: + found = False + + for mul, ir_out in target_irreps: + if ir_in == ir_out: + irreps_mid.append((mul, ir_out)) + found = True + break + + if not found: + raise RuntimeError(f"{ir_in} not in {target_irreps}") + + return o3.Irreps(irreps_mid) + + +@compile_mode("script") +class reshape_irreps(torch.nn.Module): + def __init__(self, irreps: o3.Irreps) -> None: + super().__init__() + self.irreps = o3.Irreps(irreps) + self.dims = [] + self.muls = [] + for mul, ir in self.irreps: + d = ir.dim + self.dims.append(d) + self.muls.append(mul) + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + ix = 0 + out = [] + batch, _ = tensor.shape + for mul, d in zip(self.muls, self.dims): + field = 
tensor[:, ix : ix + mul * d] # [batch, sample, mul * repr] + ix += mul * d + field = field.reshape(batch, mul, d) + out.append(field) + return torch.cat(out, dim=-1) + + +def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max: int): + out = [] + for i in range(num_layers - 1): + out.append( + x[ + :, + i + * (l_max + 1) ** 2 + * num_features : (i * (l_max + 1) ** 2 + 1) + * num_features, + ] + ) + out.append(x[:, -num_features:]) + return torch.cat(out, dim=-1) diff --git a/hydragnn/utils/model/mace_utils/modules/__init__.py b/hydragnn/utils/model/mace_utils/modules/__init__.py new file mode 100644 index 000000000..3f6c5d027 --- /dev/null +++ b/hydragnn/utils/model/mace_utils/modules/__init__.py @@ -0,0 +1,51 @@ +########################################################################################### +# __init__ file for Modules +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### +# Taken From: +# GitHub: https://github.com/ACEsuit/mace +# ArXiV: https://arxiv.org/pdf/2206.07697 +# Date: August 27, 2024 | 12:37 (EST) +########################################################################################### + +from typing import Callable, Dict, Optional, Type + +import torch + +from .blocks import ( + AtomicEnergiesBlock, + EquivariantProductBasisBlock, + InteractionBlock, + LinearNodeEmbeddingBlock, + LinearReadoutBlock, + NonLinearReadoutBlock, + RadialEmbeddingBlock, + RealAgnosticAttResidualInteractionBlock, + ScaleShiftBlock, +) + +from .radial import BesselBasis, GaussianBasis, PolynomialCutoff +from .symmetric_contraction import SymmetricContraction + +interaction_classes: Dict[str, Type[InteractionBlock]] = { + "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock, +} + +__all__ = [ + "AtomicEnergiesBlock", + "RadialEmbeddingBlock", + "LinearNodeEmbeddingBlock", + "LinearReadoutBlock", + "EquivariantProductBasisBlock", + "ScaleShiftBlock", + "LinearDipoleReadoutBlock", + "NonLinearDipoleReadoutBlock", + "InteractionBlock", + "NonLinearReadoutBlock", + "PolynomialCutoff", + "BesselBasis", + "GaussianBasis", + "SymmetricContraction", + "interaction_classes", +] diff --git a/hydragnn/utils/model/mace_utils/modules/blocks.py b/hydragnn/utils/model/mace_utils/modules/blocks.py new file mode 100644 index 000000000..27574828d --- /dev/null +++ b/hydragnn/utils/model/mace_utils/modules/blocks.py @@ -0,0 +1,404 @@ +########################################################################################### +# Elementary Block for Building O(3) Equivariant Higher Order Message Passing Neural Network +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### +# Taken From: +# GitHub: https://github.com/ACEsuit/mace +# ArXiV: https://arxiv.org/pdf/2206.07697 +# Date: August 27, 2024 | 12:37 (EST) +########################################################################################### + +from abc import abstractmethod +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional +from torch_scatter import scatter +from e3nn import nn, o3 +from e3nn.util.jit import compile_mode + +from hydragnn.utils.model.mace_utils.tools.compile import simplify_if_compile +from 
hydragnn.utils.model.irreps_tools import ( + reshape_irreps, + tp_out_irreps_with_instructions, +) + +from .radial import ( + AgnesiTransform, + BesselBasis, + ChebychevBasis, + GaussianBasis, + PolynomialCutoff, + SoftTransform, +) +from .symmetric_contraction import SymmetricContraction + + +@compile_mode("script") +class LinearNodeEmbeddingBlock(torch.nn.Module): + def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps): + super().__init__() + self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irreps_out) + + def forward( + self, + node_attrs: torch.Tensor, + ) -> torch.Tensor: # [n_nodes, irreps] + return self.linear(node_attrs) + + +@compile_mode("script") +class LinearReadoutBlock(torch.nn.Module): + def __init__(self, irreps_in: o3.Irreps): + super().__init__() + self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=o3.Irreps("0e")) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + return self.linear(x) # [n_nodes, 1] + + +@simplify_if_compile +@compile_mode("script") +class NonLinearReadoutBlock(torch.nn.Module): + def __init__( + self, irreps_in: o3.Irreps, MLP_irreps: o3.Irreps, gate: Optional[Callable] + ): + super().__init__() + self.hidden_irreps = MLP_irreps + self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps) + self.non_linearity = nn.Activation( + irreps_in=self.hidden_irreps, acts=[gate] + ) # Need to adjust this to actually use the gate + self.linear_2 = o3.Linear( + irreps_in=self.hidden_irreps, irreps_out=o3.Irreps("0e") + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + x = self.non_linearity(self.linear_1(x)) + return self.linear_2(x) # [n_nodes, 1] + + +@compile_mode("script") +class AtomicEnergiesBlock(torch.nn.Module): + atomic_energies: torch.Tensor + + def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]): + super().__init__() + assert len(atomic_energies.shape) == 1 + + self.register_buffer( + "atomic_energies", + torch.tensor(atomic_energies, dtype=torch.get_default_dtype()), + ) # [n_elements, ] + + def forward( + self, x: torch.Tensor # one-hot of elements [..., n_elements] + ) -> torch.Tensor: # [..., ] + return torch.matmul(x, self.atomic_energies) + + def __repr__(self): + formatted_energies = ", ".join([f"{x:.4f}" for x in self.atomic_energies]) + return f"{self.__class__.__name__}(energies=[{formatted_energies}])" + + +@compile_mode("script") +class AtomicBlock(torch.nn.Module): + def __init__(self, output_dim): + super().__init__() + # Initialize the atomic energies as a trainable parameter + self.atomic_energies = torch.nn.Parameter( + torch.randn(118, output_dim) + ) # There are 118 known elements + + def forward(self, atomic_numbers): + # Perform the linear multiplication (no bias) + return ( + atomic_numbers @ self.atomic_energies + ) # Output will now have shape [batch_size, output_dim] + + +@compile_mode("script") +class RadialEmbeddingBlock(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + radial_type: str = "bessel", + distance_transform: str = "None", + ): + super().__init__() + if radial_type == "bessel": + self.bessel_fn = BesselBasis(r_max=r_max, num_basis=num_bessel) + elif radial_type == "gaussian": + self.bessel_fn = GaussianBasis(r_max=r_max, num_basis=num_bessel) + elif radial_type == "chebyshev": + self.bessel_fn = ChebychevBasis(r_max=r_max, num_basis=num_bessel) + if distance_transform == "Agnesi": + self.distance_transform = 
AgnesiTransform() + elif distance_transform == "Soft": + self.distance_transform = SoftTransform() + self.cutoff_fn = PolynomialCutoff(r_max=r_max, p=num_polynomial_cutoff) + self.out_dim = num_bessel + + def forward( + self, + edge_lengths: torch.Tensor, # [n_edges, 1] + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: torch.Tensor, + ): + cutoff = self.cutoff_fn(edge_lengths) # [n_edges, 1] + if hasattr(self, "distance_transform"): + edge_lengths = self.distance_transform( + edge_lengths, node_attrs, edge_index, atomic_numbers + ) + radial = self.bessel_fn(edge_lengths) # [n_edges, n_basis] + return radial * cutoff # [n_edges, n_basis] + + +@compile_mode("script") +class EquivariantProductBasisBlock(torch.nn.Module): + def __init__( + self, + node_feats_irreps: o3.Irreps, + target_irreps: o3.Irreps, + correlation: int, + use_sc: bool = True, + num_elements: Optional[int] = None, + ) -> None: + super().__init__() + + self.use_sc = use_sc + self.symmetric_contractions = SymmetricContraction( + irreps_in=node_feats_irreps, + irreps_out=target_irreps, + correlation=correlation, + num_elements=num_elements, + ) + # Update linear + self.linear = o3.Linear( + target_irreps, + target_irreps, + internal_weights=True, + shared_weights=True, + ) + + def forward( + self, + node_feats: torch.Tensor, + sc: Optional[torch.Tensor], + node_attrs: torch.Tensor, + ) -> torch.Tensor: + node_feats = self.symmetric_contractions(node_feats, node_attrs) + if self.use_sc and sc is not None: + return self.linear(node_feats) + sc + return self.linear(node_feats) + + +@compile_mode("script") +class InteractionBlock(torch.nn.Module): + def __init__( + self, + node_attrs_irreps: o3.Irreps, + node_feats_irreps: o3.Irreps, + edge_attrs_irreps: o3.Irreps, + edge_feats_irreps: o3.Irreps, + target_irreps: o3.Irreps, + hidden_irreps: o3.Irreps, + avg_num_neighbors: float, + radial_MLP: Optional[List[int]] = None, + ) -> None: + super().__init__() + self.node_attrs_irreps = node_attrs_irreps + self.node_feats_irreps = node_feats_irreps + self.edge_attrs_irreps = edge_attrs_irreps + self.edge_feats_irreps = edge_feats_irreps + self.target_irreps = target_irreps + self.hidden_irreps = hidden_irreps + self.avg_num_neighbors = avg_num_neighbors + if radial_MLP is None: + radial_MLP = [64, 64, 64] + self.radial_MLP = radial_MLP + + self._setup() + + @abstractmethod + def _setup(self) -> None: + raise NotImplementedError + + @abstractmethod + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + +nonlinearities = {1: torch.nn.functional.silu, -1: torch.tanh} + + +@compile_mode("script") +class TensorProductWeightsBlock(torch.nn.Module): + def __init__(self, num_elements: int, num_edge_feats: int, num_feats_out: int): + super().__init__() + + weights = torch.empty( + (num_elements, num_edge_feats, num_feats_out), + dtype=torch.get_default_dtype(), + ) + torch.nn.init.xavier_uniform_(weights) + self.weights = torch.nn.Parameter(weights) + + def forward( + self, + sender_or_receiver_node_attrs: torch.Tensor, # assumes that the node attributes are one-hot encoded + edge_feats: torch.Tensor, + ): + return torch.einsum( + "be, ba, aek -> bk", edge_feats, sender_or_receiver_node_attrs, self.weights + ) + + def __repr__(self): + return ( + f'{self.__class__.__name__}(shape=({", ".join(str(s) for s in self.weights.shape)}), ' + 
f"weights={np.prod(self.weights.shape)})" + ) + + +########################################################################################### +# NOTE: Below is one of the many possible Interaction Blocks in the MACE architecture. +# Since there are adaptations to the original code in order to be integrated with +# the HydraGNN framework, and the changes between blocks are relatively minor, we've +# elected to adapt one general-purpose block here. Users can access the other blocks +# and adapt similarly from the original MACE code (linked with date at the top). +########################################################################################### +@compile_mode("script") +class RealAgnosticAttResidualInteractionBlock(InteractionBlock): + def _setup(self) -> None: + self.node_feats_down_irreps = o3.Irreps( + [(o3.Irreps(self.hidden_irreps).count(o3.Irrep(0, 1)), (0, 1))] + ) + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + self.linear_down = o3.Linear( + self.node_feats_irreps, + self.node_feats_down_irreps, + internal_weights=True, + shared_weights=True, + ) + input_dim = ( + self.edge_feats_irreps.num_irreps # The irreps here should be scalars because they are fed through an activation in self.conv_tp_weights + + 2 * self.node_feats_down_irreps.num_irreps + ) + # The following specifies the network architecture for embedding l=0 (scalar) irreps + ## It is worth double-checking, but I believe this means that type 0 (scalar) irreps + # are being embedded by 3 layers of size self.hidden_dim (scalar irreps) and the + # output dim, then activated. + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + + 3 * [o3.Irreps(self.hidden_irreps).count(o3.Irrep(0, 1))] + + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + irreps_mid = ( + irreps_mid.simplify() + ) # .simplify() essentially combines irreps of the same type so that normalization is done across them all. The site has an in-depth explanation + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + ) + + self.reshape = reshape_irreps(self.irreps_out) + + # Skip connection. 
+ self.skip_linear = o3.Linear( + self.node_feats_irreps, self.hidden_irreps + ) # This will be size (num_nodes, 64*9) when there are 64 channels and irreps, 0, 1, 2 (1+3+5=9) ## This becomes sc + + def forward( + self, + node_feats: torch.Tensor, + node_attrs: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, None]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + sc = self.skip_linear(node_feats) + node_feats_up = self.linear_up(node_feats) + node_feats_down = self.linear_down(node_feats) + augmented_edge_feats = torch.cat( + [ + edge_feats, + node_feats_down[sender], + node_feats_down[receiver], + ], + dim=-1, + ) + tp_weights = self.conv_tp_weights(augmented_edge_feats) + mji = self.conv_tp( + node_feats_up[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter( + src=mji, index=receiver, dim=0, dim_size=num_nodes, reduce="sum" + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + return ( + self.reshape(message), + sc, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class ScaleShiftBlock(torch.nn.Module): + def __init__(self, scale: float, shift: float): + super().__init__() + self.register_buffer( + "scale", torch.tensor(scale, dtype=torch.get_default_dtype()) + ) + self.register_buffer( + "shift", torch.tensor(shift, dtype=torch.get_default_dtype()) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.scale * x + self.shift + + def __repr__(self): + return ( + f"{self.__class__.__name__}(scale={self.scale:.6f}, shift={self.shift:.6f})" + ) diff --git a/hydragnn/utils/model/mace_utils/modules/radial.py b/hydragnn/utils/model/mace_utils/modules/radial.py new file mode 100644 index 000000000..f53896c49 --- /dev/null +++ b/hydragnn/utils/model/mace_utils/modules/radial.py @@ -0,0 +1,248 @@ +########################################################################################### +# Radial basis and cutoff +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### +# Taken From: +# GitHub: https://github.com/ACEsuit/mace +# ArXiV: https://arxiv.org/pdf/2206.07697 +# Date: August 27, 2024 | 12:37 (EST) +########################################################################################### + +import ase + +import numpy as np +import torch +from torch_scatter import scatter +from e3nn.util.jit import compile_mode + +from hydragnn.utils.model.mace_utils.tools.compile import simplify_if_compile + + +@compile_mode("script") +class BesselBasis(torch.nn.Module): + """ + Equation (7) + """ + + def __init__(self, r_max: float, num_basis=8, trainable=False): + super().__init__() + + bessel_weights = ( + np.pi + / r_max + * torch.linspace( + start=1.0, + end=num_basis, + steps=num_basis, + dtype=torch.get_default_dtype(), + ) + ) + if trainable: + self.bessel_weights = torch.nn.Parameter(bessel_weights) + else: + self.register_buffer("bessel_weights", bessel_weights) + + self.register_buffer( + "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) + ) + self.register_buffer( + "prefactor", + torch.tensor(np.sqrt(2.0 / r_max), dtype=torch.get_default_dtype()), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] + numerator = torch.sin(self.bessel_weights * x) # [..., num_basis] + return self.prefactor * (numerator / x) + + 
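As a quick usage illustration (an aside rather than part of radial.py, assuming the BesselBasis class above is defined), each edge length is mapped to num_basis smooth radial features, which RadialEmbeddingBlock later multiplies by the polynomial cutoff:

import torch

basis = BesselBasis(r_max=5.0, num_basis=8)
edge_lengths = torch.linspace(0.1, 5.0, 32).unsqueeze(-1)  # [n_edges, 1]; avoid r = 0
radial = basis(edge_lengths)  # [n_edges, 8]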
def __repr__(self): + return ( + f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={len(self.bessel_weights)}, " + f"trainable={self.bessel_weights.requires_grad})" + ) + + +@compile_mode("script") +class ChebychevBasis(torch.nn.Module): + """ + Equation (7) + """ + + def __init__(self, r_max: float, num_basis=8): + super().__init__() + self.register_buffer( + "n", + torch.arange(1, num_basis + 1, dtype=torch.get_default_dtype()).unsqueeze( + 0 + ), + ) + self.num_basis = num_basis + self.r_max = r_max + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] + x = x.repeat(1, self.num_basis) + n = self.n.repeat(len(x), 1) + return torch.special.chebyshev_polynomial_t(x, n) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={self.num_basis}," + ) + + +@compile_mode("script") +class GaussianBasis(torch.nn.Module): + """ + Gaussian basis functions + """ + + def __init__(self, r_max: float, num_basis=128, trainable=False): + super().__init__() + gaussian_weights = torch.linspace( + start=0.0, end=r_max, steps=num_basis, dtype=torch.get_default_dtype() + ) + if trainable: + self.gaussian_weights = torch.nn.Parameter( + gaussian_weights, requires_grad=True + ) + else: + self.register_buffer("gaussian_weights", gaussian_weights) + self.coeff = -0.5 / (r_max / (num_basis - 1)) ** 2 + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] + x = x - self.gaussian_weights + return torch.exp(self.coeff * torch.pow(x, 2)) + + +@compile_mode("script") +class PolynomialCutoff(torch.nn.Module): + """ + Equation (8) + """ + + p: torch.Tensor + r_max: torch.Tensor + + def __init__(self, r_max: float, p=6): + super().__init__() + self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) + self.register_buffer( + "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # yapf: disable + envelope = ( + 1.0 + - ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / self.r_max, self.p) + + self.p * (self.p + 2.0) * torch.pow(x / self.r_max, self.p + 1) + - (self.p * (self.p + 1.0) / 2) * torch.pow(x / self.r_max, self.p + 2) + ) + # yapf: enable + + # noinspection PyUnresolvedReferences + return envelope * (x < self.r_max) + + def __repr__(self): + return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})" + + +@compile_mode("script") +class AgnesiTransform(torch.nn.Module): + """ + Agnesi transform see ACEpotentials.jl, JCP 2023, p. 
160 + """ + + def __init__( + self, + q: float = 0.9183, + p: float = 4.5791, + a: float = 1.0805, + trainable=False, + ): + super().__init__() + self.register_buffer("q", torch.tensor(q, dtype=torch.get_default_dtype())) + self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) + self.register_buffer("a", torch.tensor(a, dtype=torch.get_default_dtype())) + self.register_buffer( + "covalent_radii", + torch.tensor( + ase.data.covalent_radii, + dtype=torch.get_default_dtype(), + ), + ) + if trainable: + self.a = torch.nn.Parameter(torch.tensor(1.0805, requires_grad=True)) + self.q = torch.nn.Parameter(torch.tensor(0.9183, requires_grad=True)) + self.p = torch.nn.Parameter(torch.tensor(4.5791, requires_grad=True)) + + def forward( + self, + x: torch.Tensor, + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: torch.Tensor, + ) -> torch.Tensor: + sender = edge_index[0] + receiver = edge_index[1] + node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze( + -1 + ) + Z_u = node_atomic_numbers[sender] + Z_v = node_atomic_numbers[receiver] + r_0 = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) + return ( + 1 + (self.a * ((x / r_0) ** self.q) / (1 + (x / r_0) ** (self.q - self.p))) + ) ** (-1) + + def __repr__(self): + return f"{self.__class__.__name__}(a={self.a}, q={self.q}, p={self.p})" + + +@simplify_if_compile +@compile_mode("script") +class SoftTransform(torch.nn.Module): + """ + Soft Transform + """ + + def __init__(self, a: float = 0.2, b: float = 3.0, trainable=False): + super().__init__() + self.register_buffer( + "covalent_radii", + torch.tensor( + ase.data.covalent_radii, + dtype=torch.get_default_dtype(), + ), + ) + if trainable: + self.a = torch.nn.Parameter(torch.tensor(a, requires_grad=True)) + self.b = torch.nn.Parameter(torch.tensor(b, requires_grad=True)) + else: + self.register_buffer("a", torch.tensor(a)) + self.register_buffer("b", torch.tensor(b)) + + def forward( + self, + x: torch.Tensor, + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: torch.Tensor, + ) -> torch.Tensor: + sender = edge_index[0] + receiver = edge_index[1] + node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze( + -1 + ) + Z_u = node_atomic_numbers[sender] + Z_v = node_atomic_numbers[receiver] + r_0 = (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) / 4 + y = ( + x + + (1 / 2) * torch.tanh(-(x / r_0) - self.a * ((x / r_0) ** self.b)) + + 1 / 2 + ) + return y + + def __repr__(self): + return f"{self.__class__.__name__}(a={self.a.item()}, b={self.b.item()})" diff --git a/hydragnn/utils/model/mace_utils/modules/symmetric_contraction.py b/hydragnn/utils/model/mace_utils/modules/symmetric_contraction.py new file mode 100644 index 000000000..465d8fa9e --- /dev/null +++ b/hydragnn/utils/model/mace_utils/modules/symmetric_contraction.py @@ -0,0 +1,238 @@ +########################################################################################### +# Implementation of the symmetric contraction algorithm presented in the MACE paper +# (Batatia et al, MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields , Eq.10 and 11) +# Authors: Ilyes Batatia +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### +# Taken From: +# GitHub: https://github.com/ACEsuit/mace +# ArXiV: https://arxiv.org/pdf/2206.07697 +# Date: August 27, 2024 | 12:37 (EST) 
+########################################################################################### + +from typing import Dict, Optional, Union + +import opt_einsum_fx +import torch +import torch.fx +from e3nn import o3 +from e3nn.util.codegen import CodeGenMixin +from e3nn.util.jit import compile_mode + +from hydragnn.utils.model.mace_utils.tools.cg import U_matrix_real + +BATCH_EXAMPLE = 10 +ALPHABET = ["w", "x", "v", "n", "z", "r", "t", "y", "u", "o", "p", "s"] + + +@compile_mode("script") +class SymmetricContraction(CodeGenMixin, torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + correlation: Union[int, Dict[str, int]], + irrep_normalization: str = "component", + path_normalization: str = "element", + internal_weights: Optional[bool] = None, + shared_weights: Optional[bool] = None, + num_elements: Optional[int] = None, + ) -> None: + super().__init__() + + if irrep_normalization is None: + irrep_normalization = "component" + + if path_normalization is None: + path_normalization = "element" + + assert irrep_normalization in ["component", "norm", "none"] + assert path_normalization in ["element", "path", "none"] + + self.irreps_in = o3.Irreps(irreps_in) + self.irreps_out = o3.Irreps(irreps_out) + + del irreps_in, irreps_out + + if not isinstance(correlation, tuple): + corr = correlation + correlation = {} + for irrep_out in self.irreps_out: + correlation[irrep_out] = corr + + assert shared_weights or not internal_weights + + if internal_weights is None: + internal_weights = True + + self.internal_weights = internal_weights + self.shared_weights = shared_weights + + del internal_weights, shared_weights + + self.contractions = torch.nn.ModuleList() + for irrep_out in self.irreps_out: + self.contractions.append( + Contraction( + irreps_in=self.irreps_in, + irrep_out=o3.Irreps(str(irrep_out.ir)), + correlation=correlation[irrep_out], + internal_weights=self.internal_weights, + num_elements=num_elements, + weights=self.shared_weights, + ) + ) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + outs = [contraction(x, y) for contraction in self.contractions] + return torch.cat(outs, dim=-1) + + +@compile_mode("script") +class Contraction(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + irrep_out: o3.Irreps, + correlation: int, + internal_weights: bool = True, + num_elements: Optional[int] = None, + weights: Optional[torch.Tensor] = None, + ) -> None: + super().__init__() + + self.num_features = irreps_in.count((0, 1)) + self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in]) + self.correlation = correlation + dtype = torch.get_default_dtype() + for nu in range(1, correlation + 1): + U_matrix = U_matrix_real( + irreps_in=self.coupling_irreps, + irreps_out=irrep_out, + correlation=nu, + dtype=dtype, + )[-1] + self.register_buffer(f"U_matrix_{nu}", U_matrix) + + # Tensor contraction equations + self.contractions_weighting = torch.nn.ModuleList() + self.contractions_features = torch.nn.ModuleList() + + # Create weight for product basis + self.weights = torch.nn.ParameterList([]) + + for i in range(correlation, 0, -1): + # Shapes definying + num_params = self.U_tensors(i).size()[-1] + num_equivariance = 2 * irrep_out.lmax + 1 + num_ell = self.U_tensors(i).size()[-2] + + if i == correlation: + parse_subscript_main = ( + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)] + + ["ik,ekc,bci,be -> bc"] + + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)] + ) + graph_module_main = torch.fx.symbolic_trace( + 
lambda x, y, w, z: torch.einsum( + "".join(parse_subscript_main), x, y, w, z + ) + ) + + # Optimizing the contractions + self.graph_opt_main = opt_einsum_fx.optimize_einsums_full( + model=graph_module_main, + example_inputs=( + torch.randn( + [num_equivariance] + [num_ell] * i + [num_params] + ).squeeze(0), + torch.randn((num_elements, num_params, self.num_features)), + torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)), + torch.randn((BATCH_EXAMPLE, num_elements)), + ), + ) + # Parameters for the product basis + w = torch.nn.Parameter( + torch.randn((num_elements, num_params, self.num_features)) + / num_params + ) + self.weights_max = w + else: + # Generate optimized contractions equations + parse_subscript_weighting = ( + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))] + + ["k,ekc,be->bc"] + + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))] + ) + parse_subscript_features = ( + ["bc"] + + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))] + + ["i,bci->bc"] + + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))] + ) + + # Symbolic tracing of contractions + graph_module_weighting = torch.fx.symbolic_trace( + lambda x, y, z: torch.einsum( + "".join(parse_subscript_weighting), x, y, z + ) + ) + graph_module_features = torch.fx.symbolic_trace( + lambda x, y: torch.einsum("".join(parse_subscript_features), x, y) + ) + + # Optimizing the contractions + graph_opt_weighting = opt_einsum_fx.optimize_einsums_full( + model=graph_module_weighting, + example_inputs=( + torch.randn( + [num_equivariance] + [num_ell] * i + [num_params] + ).squeeze(0), + torch.randn((num_elements, num_params, self.num_features)), + torch.randn((BATCH_EXAMPLE, num_elements)), + ), + ) + graph_opt_features = opt_einsum_fx.optimize_einsums_full( + model=graph_module_features, + example_inputs=( + torch.randn( + [BATCH_EXAMPLE, self.num_features, num_equivariance] + + [num_ell] * i + ).squeeze(2), + torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)), + ), + ) + self.contractions_weighting.append(graph_opt_weighting) + self.contractions_features.append(graph_opt_features) + # Parameters for the product basis + w = torch.nn.Parameter( + torch.randn((num_elements, num_params, self.num_features)) + / num_params + ) + self.weights.append(w) + if not internal_weights: + self.weights = weights[:-1] + self.weights_max = weights[-1] + + def forward(self, x: torch.Tensor, y: torch.Tensor): + out = self.graph_opt_main( + self.U_tensors(self.correlation), + self.weights_max, + x, + y, + ) + for i, (weight, contract_weights, contract_features) in enumerate( + zip(self.weights, self.contractions_weighting, self.contractions_features) + ): + c_tensor = contract_weights( + self.U_tensors(self.correlation - i - 1), + weight, + y, + ) + c_tensor = c_tensor + out + out = contract_features(c_tensor, x) + + return out.view(out.shape[0], -1) + + def U_tensors(self, nu: int): + return dict(self.named_buffers())[f"U_matrix_{nu}"] diff --git a/hydragnn/utils/model/mace_utils/tools/__init__.py b/hydragnn/utils/model/mace_utils/tools/__init__.py new file mode 100644 index 000000000..26207ecbe --- /dev/null +++ b/hydragnn/utils/model/mace_utils/tools/__init__.py @@ -0,0 +1,16 @@ +########################################################################################### +# __init__ file for Tools +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### +# Taken From: 
+# GitHub: https://github.com/ACEsuit/mace +# ArXiV: https://arxiv.org/pdf/2206.07697 +# Date: August 27, 2024 | 12:37 (EST) +########################################################################################### + +from .cg import U_matrix_real + +__all__ = [ + "U_matrix_real", +] diff --git a/hydragnn/utils/model/mace_utils/tools/cg.py b/hydragnn/utils/model/mace_utils/tools/cg.py new file mode 100644 index 000000000..f0349998b --- /dev/null +++ b/hydragnn/utils/model/mace_utils/tools/cg.py @@ -0,0 +1,136 @@ +########################################################################################### +# Higher Order Real Clebsch Gordan (based on e3nn by Mario Geiger) +# Authors: Ilyes Batatia +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### +# Taken From: +# GitHub: https://github.com/ACEsuit/mace +# ArXiV: https://arxiv.org/pdf/2206.07697 +# Date: August 27, 2024 | 12:37 (EST) +########################################################################################### + +import collections +from typing import List, Union + +import torch +from e3nn import o3 + +_TP = collections.namedtuple("_TP", "op, args") +_INPUT = collections.namedtuple("_INPUT", "tensor, start, stop") + + +def _wigner_nj( + irrepss: List[o3.Irreps], + normalization: str = "component", + filter_ir_mid=None, + dtype=None, +): + irrepss = [o3.Irreps(irreps) for irreps in irrepss] + if filter_ir_mid is not None: + filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid] + + if len(irrepss) == 1: + (irreps,) = irrepss + ret = [] + e = torch.eye(irreps.dim, dtype=dtype) + i = 0 + for mul, ir in irreps: + for _ in range(mul): + sl = slice(i, i + ir.dim) + ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])] + i += ir.dim + return ret + + *irrepss_left, irreps_right = irrepss + ret = [] + for ir_left, path_left, C_left in _wigner_nj( + irrepss_left, + normalization=normalization, + filter_ir_mid=filter_ir_mid, + dtype=dtype, + ): + i = 0 + for mul, ir in irreps_right: + for ir_out in ir_left * ir: + if filter_ir_mid is not None and ir_out not in filter_ir_mid: + continue + + C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype) + if normalization == "component": + C *= ir_out.dim ** 0.5 + if normalization == "norm": + C *= ir_left.dim ** 0.5 * ir.dim ** 0.5 + + C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C) + C = C.reshape( + ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim + ) + for u in range(mul): + E = torch.zeros( + ir_out.dim, + *(irreps.dim for irreps in irrepss_left), + irreps_right.dim, + dtype=dtype, + ) + sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim) + E[..., sl] = C + ret += [ + ( + ir_out, + _TP( + op=(ir_left, ir, ir_out), + args=( + path_left, + _INPUT(len(irrepss_left), sl.start, sl.stop), + ), + ), + E, + ) + ] + i += mul * ir.dim + return sorted(ret, key=lambda x: x[0]) + + +def U_matrix_real( + irreps_in: Union[str, o3.Irreps], + irreps_out: Union[str, o3.Irreps], + correlation: int, + normalization: str = "component", + filter_ir_mid=None, + dtype=None, +): + irreps_out = o3.Irreps(irreps_out) + irrepss = [o3.Irreps(irreps_in)] * correlation + if correlation == 4: + filter_ir_mid = [ + (0, 1), + (1, -1), + (2, 1), + (3, -1), + (4, 1), + (5, -1), + (6, 1), + (7, -1), + (8, 1), + (9, -1), + (10, 1), + (11, -1), + ] + wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype) + current_ir = wigners[0][0] + out = [] + stack = torch.tensor([]) + + for ir, _, base_o3 in 
wigners:
+        if ir in irreps_out and ir == current_ir:
+            stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1)
+            last_ir = current_ir
+        elif ir in irreps_out and ir != current_ir:
+            if len(stack) != 0:
+                out += [last_ir, stack]
+            stack = base_o3.squeeze().unsqueeze(-1)
+            current_ir, last_ir = ir, ir
+        else:
+            current_ir = ir
+    out += [last_ir, stack]
+    return out
diff --git a/hydragnn/utils/model/mace_utils/tools/compile.py b/hydragnn/utils/model/mace_utils/tools/compile.py
new file mode 100644
index 000000000..7179676c6
--- /dev/null
+++ b/hydragnn/utils/model/mace_utils/tools/compile.py
@@ -0,0 +1,112 @@
+###########################################################################################
+# Compilation utilities for MACE
+# Authors: Ilyes Batatia, Gregor Simm and David Kovacs
+# This code is pulled from the MACE repository, which is distributed under the MIT License
+###########################################################################################
+# Taken From:
+# GitHub: https://github.com/ACEsuit/mace
+# ArXiV: https://arxiv.org/pdf/2206.07697
+# Date: August 27, 2024 | 12:37 (EST)
+###########################################################################################
+# NOTE: This file relates heavily to speedups in the MACE architecture which can be done by
+#       compiling correctly. As a result, much of this code can be commented out and MACE will
+#       still function correctly. Specifically, only simplify_if_compile() and its imports
+#       are required. However, for optimal performance, it is recommended to keep the entire
+#       file as-is.
+###########################################################################################
+
+from contextlib import contextmanager
+from functools import wraps
+from typing import Callable, Tuple
+
+try:
+    import torch._dynamo as dynamo
+except ImportError:
+    dynamo = None
+from e3nn import get_optimization_defaults, set_optimization_defaults
+from torch import autograd, nn
+from torch.fx import symbolic_trace
+
+ModuleFactory = Callable[..., nn.Module]
+TypeTuple = Tuple[type, ...]
+
+
+@contextmanager
+def disable_e3nn_codegen():
+    """Context manager that disables the legacy PyTorch code generation used in e3nn."""
+    init_val = get_optimization_defaults()["jit_script_fx"]
+    set_optimization_defaults(jit_script_fx=False)
+    yield
+    set_optimization_defaults(jit_script_fx=init_val)
+
+
+def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory:
+    """Function transform that prepares a MACE module for torch.compile
+    Args:
+        func (ModuleFactory): A function that creates an nn.Module
+        allow_autograd (bool, optional): Force inductor compiler to inline call to
+            `torch.autograd.grad`. Defaults to True.
+    Returns:
+        ModuleFactory: Decorated function that creates a torch.compile compatible module
+    """
+    if allow_autograd:
+        dynamo.allow_in_graph(autograd.grad)
+    elif dynamo.allowed_functions.is_allowed(autograd.grad):
+        dynamo.disallow_in_graph(autograd.grad)
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        with disable_e3nn_codegen():
+            model = func(*args, **kwargs)
+
+        model = simplify(model)
+        return model
+
+    return wrapper
+
+
+_SIMPLIFY_REGISTRY = set()
+
+
+def simplify_if_compile(module: nn.Module) -> nn.Module:
+    """Decorator to register a module for symbolic simplification
+
+    The decorated module will be simplified using `torch.fx.symbolic_trace`.
+    This constrains the module to not have any dynamic control flow, see:
+
+    https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing
+
+    Args:
+        module (nn.Module): the module to register
+
+    Returns:
+        nn.Module: registered module
+    """
+    _SIMPLIFY_REGISTRY.add(module)
+    return module
+
+
+# This is a different type of simplify() function than provided in e3nn. In e3nn, simplify()
+# combines irreps, while this one simplifies the module for symbolic tracing.
+def simplify(module: nn.Module) -> nn.Module:
+    """Recursively searches for registered modules to simplify with
+    `torch.fx.symbolic_trace` to support compiling with the PyTorch Dynamo compiler.
+
+    Modules are registered with the `simplify_if_compile` decorator and are replaced by their traced equivalents.
+
+    Args:
+        module (nn.Module): the module to simplify
+
+    Returns:
+        nn.Module: the simplified module
+    """
+    simplify_types = tuple(_SIMPLIFY_REGISTRY)
+
+    for name, child in module.named_children():
+        if isinstance(child, simplify_types):
+            traced = symbolic_trace(child)
+            setattr(module, name, traced)
+        else:
+            simplify(child)
+
+    return module
diff --git a/hydragnn/utils/model/model.py b/hydragnn/utils/model/model.py
index 6b6d3eb56..7e3251e08 100644
--- a/hydragnn/utils/model/model.py
+++ b/hydragnn/utils/model/model.py
@@ -122,9 +122,9 @@ def load_existing_model(
     model.load_checkpoint(os.path.join(path, model_name), model_name)
 
 
-## This function may cause OOM if datasets is too large
+## These functions may cause OOM if the dataset is too large
 ## to fit in a single GPU (i.e., with DDP). Use with caution.
-## Recommend to use calculate_PNA_degree_dist
+## Recommend to use calculate_PNA_degree_dist or calculate_avg_deg_dist
 def calculate_PNA_degree(loader, max_neighbours):
     backend = os.getenv("HYDRAGNN_AGGR_BACKEND", "torch")
     if backend == "torch":
@@ -139,6 +139,22 @@ def calculate_PNA_degree(loader, max_neighbours):
     return deg
 
 
+def calculate_avg_deg(loader):
+    backend = os.getenv("HYDRAGNN_AGGR_BACKEND", "torch")
+    if backend == "torch":
+        return calculate_avg_deg_dist(loader)
+    elif backend == "mpi":
+        return calculate_avg_deg_mpi(loader)
+    else:
+        deg = 0
+        counter = 0
+        for data in iterate_tqdm(loader, 2, desc="Calculate avg degree"):
+            d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
+            deg += d.sum()
+            counter += d.size(0)
+        return deg / counter
+
+
 def calculate_PNA_degree_dist(loader, max_neighbours):
     assert dist.is_initialized()
     deg = torch.zeros(max_neighbours + 1, dtype=torch.long)
@@ -151,6 +167,23 @@ def calculate_PNA_degree_dist(loader, max_neighbours):
     return deg
 
 
+def calculate_avg_deg_dist(loader):
+    assert dist.is_initialized()
+    deg = 0
+    counter = 0
+    for data in iterate_tqdm(loader, 2, desc="Calculate avg degree"):
+        d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
+        deg += d.sum()
+        counter += d.size(0)
+    deg = torch.tensor(deg)
+    counter = torch.tensor(counter)
+    dist.all_reduce(deg, op=dist.ReduceOp.SUM)
+    dist.all_reduce(counter, op=dist.ReduceOp.SUM)
+    deg = deg.detach().cpu()
+    counter = counter.detach().cpu()
+    return deg / counter
+
+
 def calculate_PNA_degree_mpi(loader, max_neighbours):
     assert dist.is_initialized()
     deg = torch.zeros(max_neighbours + 1, dtype=torch.long)
@@ -163,6 +196,21 @@ def calculate_PNA_degree_mpi(loader, max_neighbours):
     return torch.tensor(deg)
 
 
+def calculate_avg_deg_mpi(loader):
+    assert dist.is_initialized()
+    deg = 0
+    counter = 0
+    for data in iterate_tqdm(loader, 2, desc="Calculate avg degree"):
+        d = degree(data.edge_index[1],
num_nodes=data.num_nodes, dtype=torch.long) + deg += d.sum() + counter += d.size(0) + from mpi4py import MPI + + deg = MPI.COMM_WORLD.allreduce(deg, op=MPI.SUM) + counter = MPI.COMM_WORLD.allreduce(counter, op=MPI.SUM) + return deg / counter + + def unsorted_segment_mean(data, segment_ids, num_segments): result_shape = (num_segments, data.size(1)) segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1)) @@ -233,7 +281,6 @@ def __init__( self.use_deepspeed = use_deepspeed def __call__(self, model, optimizer, perf_metric): - if (perf_metric > self.min_perf_metric + self.min_delta) or ( self.count < self.warmup ): diff --git a/hydragnn/utils/model/operations.py b/hydragnn/utils/model/operations.py new file mode 100644 index 000000000..5101f70ac --- /dev/null +++ b/hydragnn/utils/model/operations.py @@ -0,0 +1,35 @@ +############################################################################## +# Copyright (c) 2024, Oak Ridge National Laboratory # +# All rights reserved. # +# # +# This file is part of HydraGNN and is distributed under a BSD 3-clause # +# license. For the licensing terms see the LICENSE file in the top-level # +# directory. # +# # +# SPDX-License-Identifier: BSD-3-Clause # +############################################################################## +from typing import List, Tuple +import numpy as np +import torch + + +########################################################################################### +# Function for the computation of edge vectors and lengths (MIT License (see MIT.md)) +# Authors: Ilyes Batatia, Gregor Simm and David Kovacs +########################################################################################### +def get_edge_vectors_and_lengths( + positions: torch.Tensor, # [n_nodes, 3] + edge_index: torch.Tensor, # [2, n_edges] + shifts: torch.Tensor, # [n_edges, 3] + normalize: bool = False, + eps: float = 1e-9, +) -> Tuple[torch.Tensor, torch.Tensor]: + sender = edge_index[0] + receiver = edge_index[1] + vectors = positions[receiver] - positions[sender] + shifts # [n_edges, 3] + lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1] + if normalize: + vectors_normed = vectors / (lengths + eps) + return vectors_normed, lengths + + return vectors, lengths diff --git a/requirements-torch.txt b/requirements-torch.txt index d0673bd15..7bbb88c33 100644 --- a/requirements-torch.txt +++ b/requirements-torch.txt @@ -1,4 +1,6 @@ torch==2.0.1 torchvision torchaudio - +e3nn==0.5.1 +torch-ema==0.3 +torchmetrics==1.4.0 diff --git a/requirements.txt b/requirements.txt index 2e9bfef74..b8107fd54 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ tqdm tensorboard psutil sympy +matscipy diff --git a/tests/inputs/ci.json b/tests/inputs/ci.json index 36613eaa8..14ebd43af 100644 --- a/tests/inputs/ci.json +++ b/tests/inputs/ci.json @@ -28,6 +28,7 @@ "model_type": "PNA", "radius": 2.0, "max_neighbours": 100, + "radial_type": "bessel", "num_gaussians": 50, "envelope_exponent": 5, "int_emb_size": 64, @@ -38,6 +39,8 @@ "num_radial": 6, "num_spherical": 7, "num_filters": 126, + "max_ell": 1, + "node_max_ell": 1, "periodic_boundary_conditions": false, "hidden_dim": 8, "num_conv_layers": 2, diff --git a/tests/inputs/ci_equivariant.json b/tests/inputs/ci_equivariant.json index 1afc93c21..51175e9e7 100644 --- a/tests/inputs/ci_equivariant.json +++ b/tests/inputs/ci_equivariant.json @@ -39,6 +39,8 @@ "num_radial": 6, "num_spherical": 7, "num_filters": 126, + "max_ell": 1, + "node_max_ell": 1, "periodic_boundary_conditions": 
false, "hidden_dim": 8, "num_conv_layers": 2, diff --git a/tests/inputs/ci_multihead.json b/tests/inputs/ci_multihead.json index 4d0a743e4..f408c4aa8 100644 --- a/tests/inputs/ci_multihead.json +++ b/tests/inputs/ci_multihead.json @@ -36,6 +36,8 @@ "num_radial": 6, "num_spherical": 7, "num_filters": 126, + "max_ell": 1, + "node_max_ell": 1, "periodic_boundary_conditions": false, "hidden_dim": 8, "num_conv_layers": 2, diff --git a/tests/inputs/ci_vectoroutput.json b/tests/inputs/ci_vectoroutput.json index 238e1c793..ddb616615 100644 --- a/tests/inputs/ci_vectoroutput.json +++ b/tests/inputs/ci_vectoroutput.json @@ -28,6 +28,8 @@ "max_neighbours": 100, "envelope_exponent": 5, "num_radial": 6, + "max_ell": 1, + "node_max_ell": 1, "periodic_boundary_conditions": false, "hidden_dim": 8, "num_conv_layers": 2, diff --git a/tests/test_forces_equivariant.py b/tests/test_forces_equivariant.py index 4e9d8bb53..d6df7e20d 100644 --- a/tests/test_forces_equivariant.py +++ b/tests/test_forces_equivariant.py @@ -17,7 +17,7 @@ @pytest.mark.parametrize("example", ["LennardJones"]) @pytest.mark.parametrize( - "model_type", ["SchNet", "EGNN", "DimeNet", "PAINN", "PNAPlus"] + "model_type", ["SchNet", "EGNN", "DimeNet", "PAINN", "PNAPlus", "MACE"] ) @pytest.mark.mpi_skip() def pytest_examples(example, model_type): diff --git a/tests/test_graphs.py b/tests/test_graphs.py index 1e187eb16..177cd11c8 100755 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -149,6 +149,7 @@ def unittest_train_model( "EGNN": [0.20, 0.20], "PNAEq": [0.60, 0.60], "PAINN": [0.60, 0.60], + "MACE": [0.60, 0.70], } if use_lengths and ("vector" not in ci_input): thresholds["CGCNN"] = [0.175, 0.175] @@ -210,6 +211,7 @@ def unittest_train_model( "EGNN", "PNAEq", "PAINN", + "MACE", ], ) @pytest.mark.parametrize("ci_input", ["ci.json", "ci_multihead.json"]) @@ -218,19 +220,21 @@ def pytest_train_model(model_type, ci_input, overwrite_data=False): # Test only models -@pytest.mark.parametrize("model_type", ["PNA", "PNAPlus", "CGCNN", "SchNet", "EGNN"]) +@pytest.mark.parametrize( + "model_type", ["PNA", "PNAPlus", "CGCNN", "SchNet", "EGNN", "MACE"] +) def pytest_train_model_lengths(model_type, overwrite_data=False): unittest_train_model(model_type, "ci.json", True, overwrite_data) # Test across equivariant models -@pytest.mark.parametrize("model_type", ["EGNN", "SchNet", "PNAEq", "PAINN"]) +@pytest.mark.parametrize("model_type", ["EGNN", "SchNet", "PNAEq", "PAINN", "MACE"]) def pytest_train_equivariant_model(model_type, overwrite_data=False): unittest_train_model(model_type, "ci_equivariant.json", False, overwrite_data) # Test vector output -@pytest.mark.parametrize("model_type", ["PNA", "PNAPlus"]) +@pytest.mark.parametrize("model_type", ["PNA", "PNAPlus", "MACE"]) def pytest_train_model_vectoroutput(model_type, overwrite_data=False): unittest_train_model(model_type, "ci_vectoroutput.json", True, overwrite_data) @@ -253,7 +257,3 @@ def pytest_train_model_vectoroutput(model_type, overwrite_data=False): ) def pytest_train_model_conv_head(model_type, overwrite_data=False): unittest_train_model(model_type, "ci_conv_head.json", False, overwrite_data) - - -def train_model_conv_head(model_type, overwrite_data=False): - unittest_train_model(model_type, "ci_conv_head.json", False, overwrite_data) diff --git a/tests/test_model_loadpred.py b/tests/test_model_loadpred.py index a8d650b43..8b3617959 100755 --- a/tests/test_model_loadpred.py +++ b/tests/test_model_loadpred.py @@ -12,7 +12,8 @@ import torch import random import hydragnn -from .test_graphs 
import unittest_train_model +from tests.test_graphs import unittest_train_model +from hydragnn.utils.input_config_parsing.config_utils import update_config def unittest_model_prediction(config): @@ -23,6 +24,7 @@ def unittest_model_prediction(config): val_loader, test_loader, ) = hydragnn.preprocess.load_data.dataset_loading_and_splitting(config=config) + config = update_config(config, train_loader, val_loader, test_loader) model = hydragnn.models.create.create_model_config( config=config["NeuralNetwork"], diff --git a/tests/test_radial_transforms.py b/tests/test_radial_transforms.py new file mode 100755 index 000000000..9238288c1 --- /dev/null +++ b/tests/test_radial_transforms.py @@ -0,0 +1,208 @@ +############################################################################## +# Copyright (c) 2024, Oak Ridge National Laboratory # +# All rights reserved. # +# # +# This file is part of HydraGNN and is distributed under a BSD 3-clause # +# license. For the licensing terms see the LICENSE file in the top-level # +# directory. # +# # +# SPDX-License-Identifier: BSD-3-Clause # +############################################################################## + +import sys, os, json +import pytest + +import torch + +torch.manual_seed(97) +import shutil + +import hydragnn, tests +from hydragnn.utils.input_config_parsing.config_utils import merge_config + + +# Main unit test function called by pytest wrappers. +## Adapted from test_graphs.py ... Currently, only the single head model json is tested, although the multihead functionality remains. +def unittest_train_model( + model_type, + radial_type, + distance_transform, + ci_input, + use_lengths=True, + overwrite_data=False, + use_deepspeed=False, + overwrite_config=None, +): + world_size, rank = hydragnn.utils.distributed.get_comm_size_and_rank() + + os.environ["SERIALIZED_DATA_PATH"] = os.getcwd() + + # Read in config settings and override model type. + config_file = os.path.join(os.getcwd(), "tests/inputs", ci_input) + with open(config_file, "r") as f: + config = json.load(f) + config["NeuralNetwork"]["Architecture"]["model_type"] = model_type + config["NeuralNetwork"]["Architecture"]["radial_type"] = radial_type + config["NeuralNetwork"]["Architecture"]["distance_transform"] = distance_transform + + # Overwrite config settings if provided + if overwrite_config: + config = merge_config(config, overwrite_config) + + """ + to test this locally, set ci.json as + "Dataset": { + ... + "path": { + "train": "serialized_dataset/unit_test_singlehead_train.pkl", + "test": "serialized_dataset/unit_test_singlehead_test.pkl", + "validate": "serialized_dataset/unit_test_singlehead_validate.pkl"} + ... + """ + # use pkl files if exist by default + for dataset_name in config["Dataset"]["path"].keys(): + if dataset_name == "total": + pkl_file = ( + os.environ["SERIALIZED_DATA_PATH"] + + "/serialized_dataset/" + + config["Dataset"]["name"] + + ".pkl" + ) + else: + pkl_file = ( + os.environ["SERIALIZED_DATA_PATH"] + + "/serialized_dataset/" + + config["Dataset"]["name"] + + "_" + + dataset_name + + ".pkl" + ) + if os.path.exists(pkl_file): + config["Dataset"]["path"][dataset_name] = pkl_file + + # In the unit test runs, it is found MFC favors graph-level features over node-level features, compared with other models; + # hence here we decrease the loss weight coefficient for graph-level head in MFC. 
+ if model_type == "MFC" and ci_input == "ci_multihead.json": + config["NeuralNetwork"]["Architecture"]["task_weights"][0] = 2 + + # Only run with edge lengths for models that support them. + if use_lengths: + config["NeuralNetwork"]["Architecture"]["edge_features"] = ["lengths"] + + if rank == 0: + num_samples_tot = 500 + # check if serialized pickle files or folders for raw files provided + pkl_input = False + if list(config["Dataset"]["path"].values())[0].endswith(".pkl"): + pkl_input = True + # only generate new datasets, if not pkl + if not pkl_input: + for dataset_name, data_path in config["Dataset"]["path"].items(): + if overwrite_data: + shutil.rmtree(data_path) + if not os.path.exists(data_path): + os.makedirs(data_path) + if dataset_name == "total": + num_samples = num_samples_tot + elif dataset_name == "train": + num_samples = int( + num_samples_tot + * config["NeuralNetwork"]["Training"]["perc_train"] + ) + elif dataset_name == "test": + num_samples = int( + num_samples_tot + * (1 - config["NeuralNetwork"]["Training"]["perc_train"]) + * 0.5 + ) + elif dataset_name == "validate": + num_samples = int( + num_samples_tot + * (1 - config["NeuralNetwork"]["Training"]["perc_train"]) + * 0.5 + ) + if not os.listdir(data_path): + tests.deterministic_graph_data( + data_path, number_configurations=num_samples + ) + + # Run Training + hydragnn.run_training(config, use_deepspeed) + + ( + error, + error_mse_task, + true_values, + predicted_values, + ) = hydragnn.run_prediction(config, use_deepspeed) + + # Set RMSE and sample MAE error thresholds + thresholds = { + "SAGE": [0.20, 0.20], + "PNA": [0.10, 0.10], + "PNAPlus": [0.10, 0.10], + "MFC": [0.20, 0.30], + "GIN": [0.25, 0.20], + "GAT": [0.60, 0.70], + "CGCNN": [0.175, 0.175], + "SchNet": [0.20, 0.20], + "DimeNet": [0.50, 0.50], + "EGNN": [0.20, 0.20], + "MACE": [0.60, 0.70], + } + + verbosity = 2 + + for ihead in range(len(true_values)): + error_head_mse = error_mse_task[ihead] + error_str = ( + str("{:.6f}".format(error_head_mse)) + + " < " + + str(thresholds[model_type][0]) + ) + hydragnn.utils.print.print_distributed(verbosity, "head: " + error_str) + assert ( + error_head_mse < thresholds[model_type][0] + ), "Head RMSE checking failed for " + str(ihead) + + head_true = true_values[ihead] + head_pred = predicted_values[ihead] + # Check individual samples + mae = torch.nn.L1Loss() + sample_mean_abs_error = mae(head_true, head_pred) + error_str = ( + "{:.6f}".format(sample_mean_abs_error) + + " < " + + str(thresholds[model_type][1]) + ) + assert ( + sample_mean_abs_error < thresholds[model_type][1] + ), "MAE sample checking failed!" + + # Check RMSE error + error_str = str("{:.6f}".format(error)) + " < " + str(thresholds[model_type][0]) + hydragnn.utils.print.print_distributed(verbosity, "total: " + error_str) + assert error < thresholds[model_type][0], "Total RMSE checking failed!" + str(error) + + +@pytest.mark.parametrize( + "model_type", + ["MACE"], +) +@pytest.mark.parametrize("basis_function", ["bessel", "gaussian", "chebyshev"]) +@pytest.mark.parametrize("distance_transform", ["None", "Agnesi", "Soft"]) +def pytest_train_model_transforms( + model_type, + basis_function, + distance_transform, + use_lengths=True, + overwrite_data=False, +): + unittest_train_model( + model_type, + basis_function, + distance_transform, + "ci.json", + use_lengths, + overwrite_data, + )
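
Taken together, the radial utilities introduced in this diff turn interatomic distances into smooth edge features: get_edge_vectors_and_lengths() yields per-edge lengths, a radial basis (Bessel by default; Gaussian or Chebyshev via "radial_type") expands them, and PolynomialCutoff() damps everything to zero at r_max. Below is a minimal sketch of that composition; it is illustrative only, the toy positions, edge_index, and shifts tensors are made up, and the imports assume the package layout laid out in this diff.

import torch
from hydragnn.utils.model.operations import get_edge_vectors_and_lengths
from hydragnn.utils.model.mace_utils.modules.radial import BesselBasis, PolynomialCutoff

# Toy graph: 4 nodes on a unit square, one directed edge per neighbor pair, no PBC shifts.
positions = torch.tensor(
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0]]
)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
shifts = torch.zeros(edge_index.shape[1], 3)

vectors, lengths = get_edge_vectors_and_lengths(positions, edge_index, shifts)
# vectors: [n_edges, 3], lengths: [n_edges, 1]

r_max = 2.0
radial = BesselBasis(r_max=r_max, num_basis=8)(lengths)  # [n_edges, 8] Bessel expansion
envelope = PolynomialCutoff(r_max=r_max, p=6)(lengths)  # [n_edges, 1], zero beyond r_max
edge_feats = radial * envelope  # smooth radial embedding, roughly what RadialEmbeddingBlock assembles
print(edge_feats.shape)  # torch.Size([4, 8])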
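The compile helpers follow a register-then-rewrite pattern: classes decorated with simplify_if_compile() are recorded, and simplify() later swaps any matching child module for its torch.fx-traced version so the PyTorch Dynamo compiler can handle it. A small sketch under the same assumptions; the Scale module and the Sequential wrapper here are hypothetical, not part of the diff.

import torch
from torch import nn
from hydragnn.utils.model.mace_utils.tools.compile import simplify, simplify_if_compile


@simplify_if_compile
class Scale(nn.Module):
    # Hypothetical module: no dynamic control flow, so fx tracing is safe.
    def forward(self, x):
        return 2.0 * x


model = nn.Sequential(Scale(), nn.Linear(4, 4))
model = simplify(model)  # the registered Scale child is replaced by its traced version; nn.Linear is untouched
print(model(torch.randn(2, 4)).shape)  # torch.Size([2, 4])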