diff --git a/hydragnn/models/Base.py b/hydragnn/models/Base.py
index dfbe5f35e..e51a96516 100644
--- a/hydragnn/models/Base.py
+++ b/hydragnn/models/Base.py
@@ -20,6 +20,7 @@ import sys
 from hydragnn.utils.distributed import get_device
 from hydragnn.utils.print.print_utils import print_master
+from hydragnn.utils.model.operations import get_edge_vectors_and_lengths
 import inspect
@@ -136,6 +137,10 @@ def _init_conv(self):
         self.feature_layers.append(BatchNorm(self.hidden_dim))

     def _embedding(self, data):
+        if not hasattr(data, "edge_shifts"):
+            data.edge_shifts = torch.zeros(
+                (data.edge_index.size(1), 3), device=data.edge_index.device
+            )
         conv_args = {"edge_index": data.edge_index.to(torch.long)}
         if self.use_edge_attr:
             assert (
diff --git a/hydragnn/models/MACEStack.py b/hydragnn/models/MACEStack.py
index 66eed27ee..8746bc47c 100644
--- a/hydragnn/models/MACEStack.py
+++ b/hydragnn/models/MACEStack.py
@@ -32,7 +32,7 @@
 # Torch
 import torch
-from torch.nn import ModuleList, Sequential
+from torch.nn import ModuleList, Sequential, Linear
 from torch.utils.checkpoint import checkpoint
 from torch_scatter import scatter
@@ -49,17 +49,21 @@
     RadialEmbeddingBlock,
     RealAgnosticAttResidualInteractionBlock,
 )
-from hydragnn.utils.model.operations import (
-    get_edge_vectors_and_lengths,
-)
 # E3NN
 from e3nn import nn, o3
 from e3nn.util.jit import compile_mode
-
 # HydraGNN
 from .Base import Base
+from hydragnn.utils.model.operations import get_edge_vectors_and_lengths
+from hydragnn.utils.model.irreps_tools import create_irreps_string
+from hydragnn.utils.model.mace_utils.modules.blocks import (
+    CombineBlock,
+    SplitBlock,
+    NonLinearMultiheadDecoderBlock,
+    LinearMultiheadDecoderBlock,
+)
 # Etc
 import numpy as np
@@ -104,8 +108,10 @@ def __init__(
         ## - I use a hidden_max_ell argument to allow the user to set max ell in the hidden dimensions as well
         """"""
+        ############################ Prior to Inheritance ############################
         # Init Args
         ## Passed
+        self.max_ell = max_ell
         self.node_max_ell = node_max_ell
         num_interactions = kwargs["num_conv_layers"]
         self.edge_dim = edge_dim
@@ -113,9 +119,11 @@ def __init__(
         ## Defined
         self.interaction_cls = RealAgnosticAttResidualInteractionBlock
         self.interaction_cls_first = RealAgnosticAttResidualInteractionBlock
-        atomic_numbers = list(range(1, 119))  # 118 elements in the periodic table
+        atomic_numbers = list(
+            range(1, 119)
+        )  # 118 elements in the periodic table. Simpler not to expose this to the user
         self.num_elements = len(atomic_numbers)
-        # Optional
+        ## Optional
         num_polynomial_cutoff = (
             5 if num_polynomial_cutoff is None else num_polynomial_cutoff
         )
@@ -123,18 +131,21 @@ def __init__(
         radial_type = "bessel" if radial_type is None else radial_type

         # Making Irreps
-        self.sh_irreps = o3.Irreps.spherical_harmonics(
-            max_ell
-        )  # This makes the irreps string
         self.edge_feats_irreps = o3.Irreps(f"{num_bessel}x0e")
+        self.node_attr_irreps = o3.Irreps([(self.num_elements, (0, 1))])
+        ##############################################################################

+        # NOTE The super() call is done at this point because some of the arguments are needed for the initialization of the
+        # Base class. For example, __init__ calls _init_conv, which requires self.edge_attrs_irreps, self.node_attr_irreps, etc.
+        # Other arguments, such as the radial type, could be moved before the super() call purely to streamline the code.
         super().__init__(input_args, conv_args, *args, **kwargs)
+        ############################ Post Inheritance ############################

         self.spherical_harmonics = o3.SphericalHarmonics(
             self.sh_irreps,
             normalize=True,
-            normalization="component",  # This makes the spherical harmonic class to be called with forward
-        )
+            normalization="component",
+        )  # Called to embed the edge_vectors into spherical harmonics

         # Register buffers are made when parameters need to be saved and transferred with the model, but not trained.
         self.register_buffer(
@@ -159,232 +170,189 @@ def __init__(
             irreps_in=self.node_attr_irreps,
             irreps_out=create_irreps_string(
                 self.hidden_dim, 0
-            ),  # Changed this to hidden_dim because no longer had node_feats_irreps
+            ),  # Going from one-hot to hidden_dim
         )
+        ##############################################################################

     def _init_conv(self):
         # Multihead Decoders
         ## This integrates HYDRA multihead nature with MACE's layer-wise readouts
-        ## NOTE Norm techniques (feature_layers in HYDRA) are not advised for use in equivariant models as it can break equivariance
+        ## NOTE Norm techniques (called feature_layers in HYDRA) such as BatchNorm are
+        ## not advised for use in equivariant models as they can break equivariance
        self.multihead_decoders = ModuleList()
-        # attr_irreps for node and edges are created here because we need input_dim, which requires super(base) to be called, which calls _init_conv
-        self.node_attr_irreps = o3.Irreps([(self.num_elements, (0, 1))])
-        # Edge Attributes are by default the spherical harmoncis but should be extended to include HYDRA's edge_attr is desired
+
+        # Making Irreps
+        ## Edge attributes (must be done here in order to have self.use_edge_attr)
+        self.sh_irreps = o3.Irreps.spherical_harmonics(
+            self.max_ell
+        )  # This makes the irreps string
         if self.use_edge_attr:
             self.edge_attrs_irreps = (
                 o3.Irreps(f"{self.edge_dim}x0e") + self.sh_irreps
-            ).simplify()  # Simplify combines irreps of the same type
+            ).simplify()  # Simplify combines irreps of the same type (e.g., 2x0e + 2x0e = 4x0e)
         else:
             self.edge_attrs_irreps = self.sh_irreps
+        ## Node features after convolution
         hidden_irreps = o3.Irreps(
             create_irreps_string(self.hidden_dim, self.node_max_ell)
         )
         final_hidden_irreps = o3.Irreps(
             create_irreps_string(self.hidden_dim, 0)
-        )  # Only scalars are outputted in the last layer
+        )  # Only scalars are output in the last layer

         last_layer = 1 == self.num_conv_layers

+        # Decoder before convolutions based on node_attributes
         self.multihead_decoders.append(
-            MultiheadDecoderBlock(
-                self.node_attr_irreps,
-                self.node_max_ell,
-                self.config_heads,
-                self.head_dims,
-                self.head_type,
-                self.num_heads,
-                self.activation_function,
-                self.num_nodes,
-                nonlinear=True,
+            get_multihead_decoder(
+                nonlinear=last_layer,
+                input_irreps=self.node_attr_irreps,
+                config_heads=self.config_heads,
+                head_dims=self.head_dims,
+                head_type=self.head_type,
+                num_heads=self.num_heads,
+                activation_function=self.activation_function,
+                num_nodes=self.num_nodes,
             )
-        )  # For base-node traits
+        )
+
+        # First Conv and Decoder
         self.graph_convs.append(
-            self.get_conv(self.input_dim, self.hidden_dim, first_layer=True)
+            self.get_conv(
+                self.hidden_dim,
+                self.hidden_dim,
+                first_layer=True,
+                last_layer=last_layer,
+            )  # Node features are already converted to hidden_dim via one-hot embedding
         )
         irreps = hidden_irreps if not last_layer else final_hidden_irreps
         self.multihead_decoders.append(
-            MultiheadDecoderBlock(
-                irreps,
-                self.node_max_ell,
-                self.config_heads,
-
self.head_dims, - self.head_type, - self.num_heads, - self.activation_function, - self.num_nodes, + get_multihead_decoder( nonlinear=last_layer, + input_irreps=irreps, + config_heads=self.config_heads, + head_dims=self.head_dims, + head_type=self.head_type, + num_heads=self.num_heads, + activation_function=self.activation_function, + num_nodes=self.num_nodes, ) ) + + # Variable number of convolutions and decoders for i in range(self.num_conv_layers - 1): last_layer = i == self.num_conv_layers - 2 - conv = self.get_conv( - self.hidden_dim, self.hidden_dim, last_layer=last_layer + self.graph_convs.append( + self.get_conv(self.hidden_dim, self.hidden_dim, last_layer=last_layer) ) - self.graph_convs.append(conv) irreps = hidden_irreps if not last_layer else final_hidden_irreps self.multihead_decoders.append( - MultiheadDecoderBlock( - irreps, - self.node_max_ell, - self.config_heads, - self.head_dims, - self.head_type, - self.num_heads, - self.activation_function, - self.num_nodes, + get_multihead_decoder( nonlinear=last_layer, + input_irreps=irreps, + config_heads=self.config_heads, + head_dims=self.head_dims, + head_type=self.head_type, + num_heads=self.num_heads, + activation_function=self.activation_function, + num_nodes=self.num_nodes, ) ) # Last layer will be nonlinear node decoding def get_conv(self, input_dim, output_dim, first_layer=False, last_layer=False): hidden_dim = output_dim if input_dim == 1 else input_dim - # All of these should be constructed with HYDRA dimensional arguments - ## Radial + # NOTE All of these should be constructed with HYDRA dimensional arguments + + # Radial radial_MLP_dim = math.ceil( float(hidden_dim) / 3 ) # Go based off hidden_dim for radial_MLP radial_MLP = [radial_MLP_dim, radial_MLP_dim, radial_MLP_dim] - ## Input, Hidden, and Output irreps sizing (this is usually just hidden in MACE) - ### Input dimensions are handled implicitly - ### Hidden - hidden_irreps = create_irreps_string(hidden_dim, self.node_max_ell) - hidden_irreps = o3.Irreps(hidden_irreps) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - num_features = hidden_irreps.count( - o3.Irrep(0, 1) - ) # Multiple copies of spherical harmonics for multiple interactions. They are 'combined' in a certain way during .simplify() ## This makes it a requirement that hidden irreps all have the same number of channels + + # Input, Hidden, and Output irreps sizing (this is usually just hidden in MACE) + ## Input + if first_layer: + node_feats_irreps = o3.Irreps(create_irreps_string(input_dim, 0)) + else: + node_feats_irreps = o3.Irreps( + create_irreps_string(input_dim, self.node_max_ell) + ) + ## Hidden + hidden_irreps = o3.Irreps(create_irreps_string(hidden_dim, self.node_max_ell)) + num_features = hidden_dim # Multiple copies of spherical harmonics for multiple interactions. They are 'combined' during .simplify() ## This makes it a requirement that different irrep types in hidden irreps all have the same number of channels. 
interaction_irreps = ( (self.sh_irreps * num_features) .sort()[0] .simplify() # Kept as sh_irreps for the output of reshape irreps, whether or not edge_attr irreps are added from HYDRA functionality ) # .sort() is a tuple, so we need the [0] element for the sorted result - ### Output - output_irreps = create_irreps_string(output_dim, self.node_max_ell) - output_irreps = o3.Irreps(output_irreps) + ## Output + output_irreps = o3.Irreps(create_irreps_string(output_dim, self.node_max_ell)) + + # Combine the inv_node_feat and equiv_node_feat into irreps + combine = CombineBlock() - # Constructing convolutional layers + # Scalars output for last layer + if last_layer: + # Convert to irreps here for countability in the splitblock + hidden_irreps = o3.Irreps(str(hidden_irreps[0])) + output_irreps = o3.Irreps(str(output_irreps[0])) + + # Interaction if first_layer: - hidden_irreps_out = hidden_irreps - combine = CombineBlock() - inter = self.interaction_cls_first( - node_attrs_irreps=self.node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=self.edge_attrs_irreps, - edge_feats_irreps=self.edge_feats_irreps, - target_irreps=interaction_irreps, # Replace with output? - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=self.avg_num_neighbors, - radial_MLP=radial_MLP, - ) - # Use the appropriate self connection at the first layer for proper E0 - use_sc_first = False - if "Residual" in str(self.interaction_cls_first): - use_sc_first = True - prod = EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps, - correlation=self.correlation[0], - num_elements=self.num_elements, - use_sc=use_sc_first, - ) - sizing = o3.Linear( - hidden_irreps_out, output_irreps - ) # Change sizing to output_irreps - split = SplitBlock(hidden_irreps) - elif last_layer: - # Select only scalars output for last layer - hidden_irreps_out = str(hidden_irreps[0]) - output_irreps = str(output_irreps[0]) - combine = CombineBlock() - inter = self.interaction_cls( - node_attrs_irreps=self.node_attr_irreps, - node_feats_irreps=hidden_irreps, - edge_attrs_irreps=self.edge_attrs_irreps, - edge_feats_irreps=self.edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=self.avg_num_neighbors, - radial_MLP=radial_MLP, - ) - prod = EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps_out, - correlation=self.correlation[0], - num_elements=self.num_elements, - use_sc=True, - ) - sizing = o3.Linear( - hidden_irreps_out, output_irreps - ) # Change sizing to output_irreps - split = SplitBlock(hidden_irreps) + interaction_cls = self.interaction_cls_first else: - hidden_irreps_out = hidden_irreps - combine = CombineBlock() - inter = self.interaction_cls( - node_attrs_irreps=self.node_attr_irreps, - node_feats_irreps=hidden_irreps, - edge_attrs_irreps=self.edge_attrs_irreps, - edge_feats_irreps=self.edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=self.avg_num_neighbors, - radial_MLP=radial_MLP, - ) - prod = EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps_out, - correlation=self.correlation[0], # Should this be i+1? 
- num_elements=self.num_elements, - use_sc=True, - ) - sizing = o3.Linear( - hidden_irreps_out, output_irreps - ) # Change sizing to output_irreps - split = SplitBlock(hidden_irreps) - - if not last_layer: - return PyGSequential( - self.input_args, - [ - (combine, "inv_node_feat, equiv_node_feat -> node_features"), - ( - inter, - "node_features, " + self.conv_args + " -> node_features, sc", - ), - (prod, "node_features, sc, node_attributes -> node_features"), - (sizing, "node_features -> node_features"), - ( - lambda node_features, equiv_node_feat: [ - node_features, - equiv_node_feat, - ], - "node_features, equiv_node_feat -> node_features, equiv_node_feat", - ), - (split, "node_features -> inv_node_feat, equiv_node_feat"), - ], - ) + interaction_cls = self.interaction_cls + inter = interaction_cls( + node_attrs_irreps=self.node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=self.edge_attrs_irreps, + edge_feats_irreps=self.edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps, + avg_num_neighbors=self.avg_num_neighbors, + radial_MLP=radial_MLP, + ) + + # Product + if first_layer: + use_sc = "Residual" in str(self.interaction_cls_first) else: - return PyGSequential( - self.input_args, - [ - (combine, "inv_node_feat, equiv_node_feat -> node_features"), - ( - inter, - "node_features, " + self.conv_args + " -> node_features, sc", - ), - (prod, "node_features, sc, node_attributes -> node_features"), - (sizing, "node_features -> node_features"), - ( - lambda node_features, equiv_node_feat: [ - node_features, - equiv_node_feat, - ], - "node_features, equiv_node_feat -> node_features, equiv_node_feat", - ), - (split, "node_features -> inv_node_feat, equiv_node_feat"), - ], - ) + use_sc = True # True for non-first layers + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps, + correlation=self.correlation[0], # Currently uniform across all layers + num_elements=self.num_elements, + use_sc=use_sc, + ) + + # Post-processing for conv layer output size + sizing = o3.Linear(hidden_irreps, output_irreps) + + # Split irreps into inv_node_feat and equiv_node_feat + split = SplitBlock(output_irreps) + + return PyGSequential( + self.input_args, + [ + (combine, "inv_node_feat, equiv_node_feat -> node_features"), + ( + inter, + "node_features, " + self.conv_args + " -> node_features, sc", + ), + (prod, "node_features, sc, node_attributes -> node_features"), + (sizing, "node_features -> node_features"), + ( + lambda node_features, equiv_node_feat: [ + node_features, + equiv_node_feat, + ], + "node_features, equiv_node_feat -> node_features, equiv_node_feat", + ), + (split, "node_features -> inv_node_feat, equiv_node_feat"), + ], + ) # NOTE An if/else is not needed here because MACE's interaction layers already contract down to purely scalars in the last layer def forward(self, data): inv_node_feat, equiv_node_feat, conv_args = self._embedding(data) @@ -424,6 +392,8 @@ def forward(self, data): return outputs def _embedding(self, data): + super()._embedding(data) + assert ( data.pos is not None ), "MACE requires node positions (data.pos) to be set." @@ -432,29 +402,30 @@ def _embedding(self, data): # initialize the spherical harmonics, since the initial spherical harmonic projection # uses the nodal position vector x/||x|| as the input to the spherical harmonics. # If we didn't center at 0, these models wouldn't even be invariant to translation. 
-        mean_pos = scatter(data.pos, data.batch, dim=0, reduce="mean")
-        data.pos = data.pos - mean_pos[data.batch]
+        if data.batch is None:
+            mean_pos = data.pos.mean(dim=0, keepdim=True)
+            data.pos = data.pos - mean_pos
+        else:
+            mean_pos = scatter(data.pos, data.batch, dim=0, reduce="mean")
+            data.pos = data.pos - mean_pos[data.batch]
+
+        # Get edge vectors and distances
+        edge_vec, edge_dist = get_edge_vectors_and_lengths(
+            data.pos, data.edge_index, data.edge_shifts
+        )

         # Create node_attrs from atomic numbers. Later on it may contain more information
         ## Node attrs are intrinsic properties of the atoms. Currently, MACE only supports atomic number node attributes
         ## data.node_attrs is already used in another place, so has been renamed to data.node_attributes from MACE and same with other data variable names
         data.node_attributes = process_node_attributes(data["x"], self.num_elements)
-        data.shifts = torch.zeros(
-            (data.edge_index.shape[1], 3), dtype=data.pos.dtype, device=data.pos.device
-        )  # Shifts takes into account pbc conditions, but I believe we already generate data.pos to take it into account

         # Embeddings
         node_feats = self.node_embedding(data["node_attributes"])
-        vectors, lengths = get_edge_vectors_and_lengths(
-            positions=data["pos"],
-            edge_index=data["edge_index"],
-            shifts=data["shifts"],
-        )
-        edge_attributes = self.spherical_harmonics(vectors)
+        edge_attributes = self.spherical_harmonics(edge_vec)
         if self.use_edge_attr:
             edge_attributes = torch.cat([data.edge_attr, edge_attributes], dim=1)
         edge_features = self.radial_embedding(
-            lengths, data["node_attributes"], data["edge_index"], self.atomic_numbers
+            edge_dist, data["node_attributes"], data["edge_index"], self.atomic_numbers
         )

         # Variable names
@@ -478,20 +449,13 @@ def _embedding(self, data):
     def _multihead(self):
         # NOTE Multihead is skipped as it's an integral part of MACE's architecture to have a decoder after every layer,
         # and a convolutional layer in decoding is not supported. Therefore, this final step is not necessary for MACE.
-        # However, various parts of multihead are applied in the MultiheadLinearBlock and MultiheadNonLinearBlock classes.
+        # However, various parts of multihead are applied in the LinearMultiheadDecoderBlock and NonLinearMultiheadDecoderBlock classes.
pass def __str__(self): return "MACEStack" -def create_irreps_string( - n: int, ell: int -): # Custom function to allow for use of HYDRA arguments in creating irreps - irreps = [f"{n}x{ell}{'e' if ell % 2 == 0 else 'o'}" for ell in range(ell + 1)] - return " + ".join(irreps) - - def process_node_attributes(node_attributes, num_elements): # Check that node attributes are atomic numbers and process accordingly node_attributes = node_attributes.squeeze() # Squeeze all unnecessary dimensions @@ -526,256 +490,33 @@ def process_node_attributes(node_attributes, num_elements): return one_hot -@compile_mode("script") -class CombineBlock(torch.nn.Module): - def __init__(self): - super(CombineBlock, self).__init__() - - def forward(self, inv_node_features, equiv_node_features): - return torch.cat([inv_node_features, equiv_node_features], dim=1) - - -@compile_mode("script") -class SplitBlock(torch.nn.Module): - def __init__(self, irreps): - super(SplitBlock, self).__init__() - self.dim = irreps.count(o3.Irrep(0, 1)) - - def forward(self, node_features): - return node_features[:, : self.dim], node_features[:, self.dim :] - - -@compile_mode("script") -class MultiheadDecoderBlock(torch.nn.Module): - def __init__( - self, - input_irreps, - node_max_ell, - config_heads, - head_dims, - head_type, - num_heads, - activation_function, - num_nodes, - nonlinear=False, - ): - super(MultiheadDecoderBlock, self).__init__() - self.input_irreps = input_irreps - self.node_max_ell = node_max_ell if not nonlinear else 0 - self.config_heads = config_heads - self.head_dims = head_dims - self.head_type = head_type - self.num_heads = num_heads - self.activation_function = activation_function - self.num_nodes = num_nodes - - self.graph_shared = None - self.node_NN_type = None - self.heads = ModuleList() - - # Create shared dense layers for graph-level output if applicable - if "graph" in self.config_heads: - graph_input_irreps = o3.Irreps( - f"{self.input_irreps.count(o3.Irrep(0, 1))}x0e" - ) - dim_sharedlayers = self.config_heads["graph"]["dim_sharedlayers"] - sharedlayers_irreps = o3.Irreps(f"{dim_sharedlayers}x0e") - denselayers = [] - denselayers.append(o3.Linear(graph_input_irreps, sharedlayers_irreps)) - denselayers.append( - nn.Activation( - irreps_in=sharedlayers_irreps, acts=[self.activation_function] - ) - ) - for _ in range(self.config_heads["graph"]["num_sharedlayers"] - 1): - denselayers.append(o3.Linear(sharedlayers_irreps, sharedlayers_irreps)) - denselayers.append( - nn.Activation( - irreps_in=sharedlayers_irreps, acts=[self.activation_function] - ) - ) - self.graph_shared = Sequential(*denselayers) - - # Create layers for each head - for ihead in range(self.num_heads): - if self.head_type[ihead] == "graph": - num_layers_graph = self.config_heads["graph"]["num_headlayers"] - hidden_dim_graph = self.config_heads["graph"]["dim_headlayers"] - denselayers = [] - head_hidden_irreps = o3.Irreps(f"{hidden_dim_graph[0]}x0e") - denselayers.append(o3.Linear(sharedlayers_irreps, head_hidden_irreps)) - denselayers.append( - nn.Activation( - irreps_in=head_hidden_irreps, acts=[self.activation_function] - ) - ) - for ilayer in range(num_layers_graph - 1): - input_irreps = o3.Irreps(f"{hidden_dim_graph[ilayer]}x0e") - output_irreps = o3.Irreps(f"{hidden_dim_graph[ilayer + 1]}x0e") - denselayers.append(o3.Linear(input_irreps, output_irreps)) - denselayers.append( - nn.Activation( - irreps_in=output_irreps, acts=[self.activation_function] - ) - ) - input_irreps = o3.Irreps(f"{hidden_dim_graph[-1]}x0e") - output_irreps = 
o3.Irreps(f"{self.head_dims[ihead]}x0e") - denselayers.append(o3.Linear(input_irreps, output_irreps)) - self.heads.append(Sequential(*denselayers)) - elif self.head_type[ihead] == "node": - self.node_NN_type = self.config_heads["node"]["type"] - head = ModuleList() - if self.node_NN_type == "mlp" or self.node_NN_type == "mlp_per_node": - self.num_mlp = 1 if self.node_NN_type == "mlp" else self.num_nodes - assert ( - self.num_nodes is not None - ), "num_nodes must be a positive integer for MLP" - num_layers_node = self.config_heads["node"]["num_headlayers"] - hidden_dim_node = self.config_heads["node"]["dim_headlayers"] - head = MLPNode( - self.input_irreps, - self.node_max_ell, - self.config_heads, - num_layers_node, - hidden_dim_node, - self.head_dims[ihead], - self.num_mlp, - self.num_nodes, - self.config_heads["node"]["type"], - self.activation_function, - nonlinear=nonlinear, - ) - self.heads.append(head) - else: - raise ValueError( - f"Unknown head NN structure for node features: {self.node_NN_type}" - ) - else: - raise ValueError( - f"Unknown head type: {self.head_type[ihead]}; supported types are 'graph' or 'node'" - ) - - def forward(self, data, node_features): - if data.batch is None: - graph_features = node_features[:, : self.hidden_dim].mean( - dim=0, keepdim=True - ) # Need to take only the type-0 irreps for aggregation - else: - graph_features = global_mean_pool( - node_features[:, : self.input_irreps.count(o3.Irrep(0, 1))], - data.batch.to(node_features.device), - ) - outputs = [] - for headloc, type_head in zip(self.heads, self.head_type): - if type_head == "graph": - x_graph_head = self.graph_shared(graph_features) - outputs.append(headloc(x_graph_head)) - else: # Node-level output - if self.node_NN_type == "conv": - raise ValueError( - "Node-level convolutional layers are not supported in MACE" - ) - else: - x_node = headloc(node_features, data.batch) - outputs.append(x_node) - return outputs - - -@compile_mode("script") -class MLPNode(torch.nn.Module): - def __init__( - self, - input_irreps, - node_max_ell, - config_heads, - num_layers, - hidden_dims, - output_dim, - num_mlp, - num_nodes, - node_type, - activation_function, - nonlinear=False, - ): - super().__init__() - self.input_irreps = input_irreps - self.hidden_dims = hidden_dims - self.output_dim = output_dim - self.node_max_ell = node_max_ell if not nonlinear else 0 - self.config_heads = config_heads - self.num_layers = num_layers - self.node_type = node_type - self.num_mlp = num_mlp - self.num_nodes = num_nodes - self.activation_function = activation_function - - self.mlp = ModuleList() - - # Create dense layers for each MLP based on node_type ("mlp" or "mlp_per_node") - for _ in range(self.num_mlp): - denselayers = [] - - # Input and hidden irreps for each MLP layer - input_irreps = input_irreps - hidden_irreps = o3.Irreps(f"{hidden_dims[0]}x0e") # Hidden irreps - - denselayers.append(o3.Linear(input_irreps, hidden_irreps)) - denselayers.append( - nn.Activation(irreps_in=hidden_irreps, acts=[self.activation_function]) - ) - - # Add intermediate layers - for ilayer in range(self.num_layers - 1): - input_irreps = o3.Irreps(f"{hidden_dims[ilayer]}x0e") - hidden_irreps = o3.Irreps(f"{hidden_dims[ilayer + 1]}x0e") - denselayers.append(o3.Linear(input_irreps, hidden_irreps)) - denselayers.append( - nn.Activation( - irreps_in=hidden_irreps, acts=[self.activation_function] - ) - ) - - # Last layer - hidden_irreps = o3.Irreps(f"{hidden_dims[-1]}x0e") - output_irreps = o3.Irreps( - f"{self.output_dim}x0e" - ) # Assuming 
head_dims has been passed for the final output
-        denselayers.append(o3.Linear(hidden_irreps, output_irreps))
-
-        # Append to MLP
-        self.mlp.append(Sequential(*denselayers))
-
-    def node_features_reshape(self, node_features, batch):
-        """Reshape node_features from [batch_size*num_nodes, num_features] to [batch_size, num_features, num_nodes]"""
-        num_features = node_features.shape[1]
-        batch_size = batch.max() + 1
-        out = torch.zeros(
-            (batch_size, num_features, self.num_nodes),
-            dtype=node_features.dtype,
-            device=node_features.device,
+def get_multihead_decoder(
+    nonlinear: bool,
+    input_irreps,
+    config_heads,
+    head_dims,
+    head_type,
+    num_heads,
+    activation_function,
+    num_nodes,
+):
+    if nonlinear:
+        return NonLinearMultiheadDecoderBlock(
+            input_irreps,
+            config_heads,
+            head_dims,
+            head_type,
+            num_heads,
+            activation_function,
+            num_nodes,
+        )
+    else:
+        return LinearMultiheadDecoderBlock(
+            input_irreps,
+            config_heads,
+            head_dims,
+            head_type,
+            num_heads,
+            activation_function,
+            num_nodes,
         )
-        for inode in range(self.num_nodes):
-            inode_index = [i for i in range(inode, batch.shape[0], self.num_nodes)]
-            out[:, :, inode] = node_features[inode_index, :]
-        return out
-
-    def forward(self, node_features: torch.Tensor, batch: torch.Tensor):
-        if self.node_type == "mlp":
-            outs = self.mlp[0](node_features)
-        else:
-            outs = torch.zeros(
-                (
-                    node_features.shape[0],
-                    self.head_dims[0],
-                ),  # Assuming `head_dims` defines the final output dimension
-                dtype=node_features.dtype,
-                device=node_features.device,
-            )
-            x_nodes = self.node_features_reshape(x, batch)
-            for inode in range(self.num_nodes):
-                inode_index = [i for i in range(inode, batch.shape[0], self.num_nodes)]
-                outs[inode_index, :] = self.mlp[inode](x_nodes[:, :, inode])
-        return outs
-
-    def __str__(self):
-        return "MLPNode"
diff --git a/hydragnn/utils/model/irreps_tools.py b/hydragnn/utils/model/irreps_tools.py
index 71d74e6ff..0fcc633cf 100644
--- a/hydragnn/utils/model/irreps_tools.py
+++ b/hydragnn/utils/model/irreps_tools.py
@@ -100,3 +100,10 @@ def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max
     )
     out.append(x[:, -num_features:])
     return torch.cat(out, dim=-1)
+
+
+# Added by HydraGNN group:
+def create_irreps_string(n: int, ell: int):  # Custom function to create irreps easily
+    # By choice and fairly common convention, parity is even for even ell values and odd for odd ell values
+    irreps = [f"{n}x{ell}{'e' if ell % 2 == 0 else 'o'}" for ell in range(ell + 1)]
+    return " + ".join(irreps)
diff --git a/hydragnn/utils/model/mace_utils/modules/blocks.py b/hydragnn/utils/model/mace_utils/modules/blocks.py
index 27574828d..68178832e 100644
--- a/hydragnn/utils/model/mace_utils/modules/blocks.py
+++ b/hydragnn/utils/model/mace_utils/modules/blocks.py
@@ -15,7 +15,9 @@
 import numpy as np
 import torch
 import torch.nn.functional
+from torch.nn import ModuleList, Sequential, Linear
 from torch_scatter import scatter
+from torch_geometric.nn import global_mean_pool
 from e3nn import nn, o3
 from e3nn.util.jit import compile_mode
@@ -23,6 +25,7 @@
 from hydragnn.utils.model.irreps_tools import (
     reshape_irreps,
     tp_out_irreps_with_instructions,
+    create_irreps_string,
 )
 from .radial import (
@@ -402,3 +405,404 @@ def __repr__(self):
     return (
         f"{self.__class__.__name__}(scale={self.scale:.6f}, shift={self.shift:.6f})"
     )
+
+
+###########################################################################################
+# The following blocks were created by the HydraGNN team, mainly for
compatibility between +# HydraGNN's Multihead architecture and MACE's Decoding architecture. +########################################################################################### + + +@compile_mode("script") +class LinearMultiheadDecoderBlock(torch.nn.Module): + def __init__( + self, + input_irreps, + config_heads, + head_dims, + head_type, + num_heads, + activation_function, + num_nodes, + ): + # NOTE The readouts of MACE take in irreps of higher order than just scalars. This is fed through o3.Linear + # to reduce to scalars. To implement this in HYDRAGNN, the first layer of the node output head + # will be such a layer, then all further layers will operate on scalars. Graph-level output heads, on + # the other hand, will always operate on the scalar part of the irreps, because pooling may break + # equivariance. (To-Do: Check for equivariant pooling methods) + + # NOTE It's a key point of the MACE architecture that all decoders before the last layer are linear. In order + # to avoid numerical instability from many stacked linear layers without activation, the MultiheadDecoderBlock + # class will be split into linear and nonlinear versions. The nonlinear version stacks layers in the same way + # that HYDRAGNN normally would, but the linear version ignores many parameters to have only one layer. + + super(LinearMultiheadDecoderBlock, self).__init__() + self.input_irreps = input_irreps + self.config_heads = config_heads + self.head_dims = head_dims + self.head_type = head_type + self.num_heads = num_heads + self.activation_function = activation_function + self.num_nodes = num_nodes + + self.graph_shared = None + self.node_NN_type = None + self.heads_NN = ModuleList() + + self.input_scalar_dim = input_irreps.count(o3.Irrep(0, 1)) + + for ihead in range(self.num_heads): + # mlp for each head output + if self.head_type[ihead] == "graph": + denselayers = [] + denselayers.append( + Linear( + self.input_scalar_dim, + self.head_dims[ihead], + ) + ) + head_NN = Sequential(*denselayers) + elif self.head_type[ihead] == "node": + self.node_NN_type = self.config_heads["node"]["type"] + head_NN = ModuleList() + if self.node_NN_type == "mlp" or self.node_NN_type == "mlp_per_node": + self.num_mlp = 1 if self.node_NN_type == "mlp" else self.num_nodes + assert ( + self.num_nodes is not None + ), "num_nodes must be positive integer for MLP" + # """if different graphs in the datasets have different size, one MLP is shared across all nodes """ + head_NN = LinearMLPNode( + input_irreps, + self.head_dims[ihead], + self.num_mlp, + self.config_heads["node"]["type"], + self.activation_function, + self.num_nodes, + ) + elif self.node_NN_type == "conv": + raise ValueError( + "Node-level convolutional layers are not supported in MACE" + ) + else: + raise ValueError( + "Unknown head NN structure for node features" + + self.node_NN_type + + "; currently only support 'mlp', 'mlp_per_node' or 'conv' (can be set with config['NeuralNetwork']['Architecture']['output_heads']['node']['type'], e.g., ./examples/ci_multihead.json)" + ) + else: + raise ValueError( + "Unknown head type" + + self.head_type[ihead] + + "; currently only support 'graph' or 'node'" + ) + self.heads_NN.append(head_NN) + + def forward(self, data, node_features): + # Take only the type-0 irreps for graph aggregation + if data.batch is None: + graph_features = node_features[:, : self.input_scalar_dim].mean( + dim=0, keepdim=True + ) + else: + graph_features = global_mean_pool( + node_features[:, : self.input_scalar_dim], + 
data.batch.to(node_features.device), + ) + outputs = [] + for headloc, type_head in zip(self.heads_NN, self.head_type): + if type_head == "graph": + outputs.append(headloc(graph_features)) + else: # Node-level output + if self.node_NN_type == "conv": + raise ValueError( + "Node-level convolutional layers are not supported in MACE" + ) + else: + x_node = headloc(node_features, data.batch) + outputs.append(x_node) + return outputs + + +@compile_mode("script") +class NonLinearMultiheadDecoderBlock(torch.nn.Module): + def __init__( + self, + input_irreps, + config_heads, + head_dims, + head_type, + num_heads, + activation_function, + num_nodes, + ): + # NOTE The readouts of MACE take in irreps of higher order than just scalars. This is fed through o3.Linear + # to reduce to scalars. To implement this in HYDRAGNN, the first layer of the node output head + # will be such a layer, then all further layers will operate on scalars. Graph-level output heads, on + # the other hand, will always operate on the scalar part of the irreps, because pooling may break + # equivariance. (To-Do: Check for equivariant pooling methods) + + # NOTE It's a key point of the MACE architecture that all decoders before the last layer are linear. In order + # to avoid numerical instability from many stacked linear layers without activation, the MultiheadDecoderBlock + # class will be split into linear and nonlinear versions. The nonlinear version stacks layers in the same way + # that HYDRAGNN normally would, but the linear version ignores many parameters to have only one layer. + + super(NonLinearMultiheadDecoderBlock, self).__init__() + self.input_irreps = input_irreps + self.config_heads = config_heads + self.head_dims = head_dims + self.head_type = head_type + self.num_heads = num_heads + self.activation_function = activation_function + self.num_nodes = num_nodes + + self.graph_shared = None + self.node_NN_type = None + self.heads_NN = ModuleList() + + self.input_scalar_dim = input_irreps.count(o3.Irrep(0, 1)) + + # Create shared dense layers for graph-level output if applicable + if "graph" in self.config_heads: + denselayers = [] + dim_sharedlayers = self.config_heads["graph"]["dim_sharedlayers"] + denselayers.append( + Linear(self.input_scalar_dim, dim_sharedlayers) + ) # Count scalar irreps for input + denselayers.append(self.activation_function) + for ishare in range(self.config_heads["graph"]["num_sharedlayers"] - 1): + denselayers.append(Linear(dim_sharedlayers, dim_sharedlayers)) + denselayers.append(self.activation_function) + self.graph_shared = Sequential(*denselayers) + + for ihead in range(self.num_heads): + # mlp for each head output + if self.head_type[ihead] == "graph": + num_head_hidden = self.config_heads["graph"]["num_headlayers"] + dim_head_hidden = self.config_heads["graph"]["dim_headlayers"] + denselayers = [] + denselayers.append(Linear(dim_sharedlayers, dim_head_hidden[0])) + denselayers.append(self.activation_function) + for ilayer in range(num_head_hidden - 1): + denselayers.append( + Linear(dim_head_hidden[ilayer], dim_head_hidden[ilayer + 1]) + ) + denselayers.append(self.activation_function) + denselayers.append( + Linear( + dim_head_hidden[-1], + self.head_dims[ihead], + ) + ) + head_NN = Sequential(*denselayers) + elif self.head_type[ihead] == "node": + self.node_NN_type = self.config_heads["node"]["type"] + head_NN = ModuleList() + if self.node_NN_type == "mlp" or self.node_NN_type == "mlp_per_node": + self.num_mlp = 1 if self.node_NN_type == "mlp" else self.num_nodes + assert ( + 
self.num_nodes is not None + ), "num_nodes must be positive integer for MLP" + # """if different graphs in the datasets have different size, one MLP is shared across all nodes """ + hidden_dim_node = self.config_heads["node"]["dim_headlayers"] + head_NN = NonLinearMLPNode( + input_irreps, + self.head_dims[ihead], + self.num_mlp, + hidden_dim_node, + self.config_heads["node"]["type"], + self.activation_function, + self.num_nodes, + ) + elif self.node_NN_type == "conv": + raise ValueError( + "Node-level convolutional layers are not supported in MACE" + ) + else: + raise ValueError( + "Unknown head NN structure for node features" + + self.node_NN_type + + "; currently only support 'mlp', 'mlp_per_node' or 'conv' (can be set with config['NeuralNetwork']['Architecture']['output_heads']['node']['type'], e.g., ./examples/ci_multihead.json)" + ) + else: + raise ValueError( + "Unknown head type" + + self.head_type[ihead] + + "; currently only support 'graph' or 'node'" + ) + self.heads_NN.append(head_NN) + + def forward(self, data, node_features): + # Take only the type-0 irreps for graph aggregation + if data.batch is None: + graph_features = node_features[:, : self.input_scalar_dim].mean( + dim=0, keepdim=True + ) + else: + graph_features = global_mean_pool( + node_features[:, : self.input_scalar_dim], + data.batch.to(node_features.device), + ) + outputs = [] + for headloc, type_head in zip(self.heads_NN, self.head_type): + if type_head == "graph": + x_graph_head = self.graph_shared(graph_features) + outputs.append(headloc(x_graph_head)) + else: # Node-level output + if self.node_NN_type == "conv": + raise ValueError( + "Node-level convolutional layers are not supported in MACE" + ) + else: + x_node = headloc(node_features, data.batch) + outputs.append(x_node) + return outputs + + +@compile_mode("script") +class LinearMLPNode(torch.nn.Module): + def __init__( + self, + input_irreps, + output_dim, + # No longer need hidden_dim_node because there is only one layer + num_mlp, + node_type, + activation_function, + num_nodes, + ): + super().__init__() + self.input_irreps = input_irreps + self.output_dim = output_dim + self.num_mlp = num_mlp + self.node_type = node_type + self.activation_function = activation_function + self.num_nodes = num_nodes + + self.mlp = ModuleList() + for _ in range(self.num_mlp): + denselayers = [] + output_irreps = o3.Irreps(create_irreps_string(output_dim, 0)) + denselayers.append( + o3.Linear(input_irreps, output_irreps) + ) # First layer is o3.Linear and takes all irreps down to scalars + self.mlp.append(Sequential(*denselayers)) + + def node_features_reshape(self, x, batch): + """reshape x from [batch_size*num_nodes, num_features] to [batch_size, num_features, num_nodes]""" + num_features = x.shape[1] + batch_size = batch.max() + 1 + out = torch.zeros( + (batch_size, num_features, self.num_nodes), + dtype=x.dtype, + device=x.device, + ) + for inode in range(self.num_nodes): + inode_index = [i for i in range(inode, batch.shape[0], self.num_nodes)] + out[:, :, inode] = x[inode_index, :] + return out + + def forward(self, x: torch.Tensor, batch: torch.Tensor): + if self.node_type == "mlp": + outs = self.mlp[0](x) + else: + outs = torch.zeros( + (x.shape[0], self.output_dim), + dtype=x.dtype, + device=x.device, + ) + x_nodes = self.node_features_reshape(x, batch) + for inode in range(self.num_nodes): + inode_index = [i for i in range(inode, batch.shape[0], self.num_nodes)] + outs[inode_index, :] = self.mlp[inode](x_nodes[:, :, inode]) + return outs + + def __str__(self): + 
return "MLPNode" + + +@compile_mode("script") +class NonLinearMLPNode(torch.nn.Module): + def __init__( + self, + input_irreps, + output_dim, + num_mlp, + hidden_dim_node, + node_type, + activation_function, + num_nodes, + ): + super().__init__() + self.input_irreps = input_irreps + self.output_dim = output_dim + self.num_mlp = num_mlp + self.node_type = node_type + self.activation_function = activation_function + self.num_nodes = num_nodes + + self.mlp = ModuleList() + for _ in range(self.num_mlp): + denselayers = [] + hidden_irreps = o3.Irreps(create_irreps_string(hidden_dim_node[0], 0)) + denselayers.append( + o3.Linear(input_irreps, hidden_irreps) + ) # First layer is o3.Linear and takes all irreps down to scalars + denselayers.append(self.activation_function) + for ilayer in range(len(hidden_dim_node) - 1): + denselayers.append( + Linear(hidden_dim_node[ilayer], hidden_dim_node[ilayer + 1]) + ) + denselayers.append(self.activation_function) + denselayers.append(Linear(hidden_dim_node[-1], output_dim)) + self.mlp.append(Sequential(*denselayers)) + + def node_features_reshape(self, x, batch): + """reshape x from [batch_size*num_nodes, num_features] to [batch_size, num_features, num_nodes]""" + num_features = x.shape[1] + batch_size = batch.max() + 1 + out = torch.zeros( + (batch_size, num_features, self.num_nodes), + dtype=x.dtype, + device=x.device, + ) + for inode in range(self.num_nodes): + inode_index = [i for i in range(inode, batch.shape[0], self.num_nodes)] + out[:, :, inode] = x[inode_index, :] + return out + + def forward(self, x: torch.Tensor, batch: torch.Tensor): + if self.node_type == "mlp": + outs = self.mlp[0](x) + else: + outs = torch.zeros( + (x.shape[0], self.output_dim), + dtype=x.dtype, + device=x.device, + ) + x_nodes = self.node_features_reshape(x, batch) + for inode in range(self.num_nodes): + inode_index = [i for i in range(inode, batch.shape[0], self.num_nodes)] + outs[inode_index, :] = self.mlp[inode](x_nodes[:, :, inode]) + return outs + + def __str__(self): + return "MLPNode" + + +@compile_mode("script") +class CombineBlock(torch.nn.Module): + def __init__(self): + super(CombineBlock, self).__init__() + + def forward(self, inv_node_features, equiv_node_features): + return torch.cat([inv_node_features, equiv_node_features], dim=1) + + +@compile_mode("script") +class SplitBlock(torch.nn.Module): + def __init__(self, irreps): + super(SplitBlock, self).__init__() + self.dim = irreps.count(o3.Irrep(0, 1)) + + def forward(self, node_features): + return node_features[:, : self.dim], node_features[:, self.dim :]
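
A few self-contained sketches follow for reviewers; they illustrate what the pieces introduced in this patch compute and are not part of the patch itself. First, what `create_irreps_string` (added to `irreps_tools.py` above) produces; the only assumptions are the example values `hidden_dim=64` and `node_max_ell=2`.

```python
# Standalone copy of create_irreps_string from this diff, for illustration only.
def create_irreps_string(n: int, ell: int):
    # One n-channel irrep per ell; parity alternates even/odd with ell.
    irreps = [f"{n}x{ell}{'e' if ell % 2 == 0 else 'o'}" for ell in range(ell + 1)]
    return " + ".join(irreps)

print(create_irreps_string(64, 2))  # "64x0e + 64x1o + 64x2e" (hidden irreps)
print(create_irreps_string(64, 0))  # "64x0e" (scalar-only irreps after the last layer)
```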
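In the same spirit, `interaction_irreps` in `get_conv` is just `num_features` copies of the spherical-harmonic irreps, sorted and simplified. A sketch with e3nn, again assuming `max_ell=2` and `hidden_dim=64`:

```python
from e3nn import o3

sh_irreps = o3.Irreps.spherical_harmonics(2)  # 1x0e+1x1o+1x2e
# num_features copies, sorted and merged, exactly the expression used in get_conv.
interaction_irreps = (sh_irreps * 64).sort()[0].simplify()
print(interaction_irreps)  # 64x0e+64x1o+64x2e
# .simplify() merges identical irreps, which is why every irrep type in the
# hidden irreps must carry the same number of channels (see the get_conv comment).
```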
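The `CombineBlock`/`SplitBlock` pair is a plain concatenate-then-slice, with the split point at the number of scalar (0e) channels; this relies on the scalars coming first, which irreps built by `create_irreps_string` guarantee. A round-trip check with arbitrary shapes:

```python
import torch
from e3nn import o3

irreps = o3.Irreps("8x0e + 8x1o")               # 8 scalars + 8 vectors, total dim 32
n_scalar = irreps.count(o3.Irrep(0, 1))         # SplitBlock.dim -> 8
inv = torch.randn(5, n_scalar)                  # invariant node features
equiv = torch.randn(5, irreps.dim - n_scalar)   # equivariant node features
node_features = torch.cat([inv, equiv], dim=1)  # what CombineBlock.forward does
inv2, equiv2 = node_features[:, :n_scalar], node_features[:, n_scalar:]  # SplitBlock.forward
assert torch.equal(inv2, inv) and torch.equal(equiv2, equiv)
```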
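The internals of `process_node_attributes` are not shown in this diff; the sketch below only illustrates the one-hot encoding it is built around (mapping atomic number Z to index Z-1 over 118 elements is an assumption here, as is the helper name), which is what makes `node_attr_irreps = 118x0e` the matching input irreps.

```python
import torch

# Hypothetical minimal version: the real function also squeezes and validates its input.
def one_hot_atomic_numbers(z: torch.Tensor, num_elements: int = 118) -> torch.Tensor:
    return torch.nn.functional.one_hot(z.long() - 1, num_classes=num_elements).float()

print(one_hot_atomic_numbers(torch.tensor([1, 6, 8])).shape)  # torch.Size([3, 118])
```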
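The `Base._embedding` change gives data without `edge_shifts` an all-zero tensor, so the single `get_edge_vectors_and_lengths` call covers both open and periodic boundaries. The `vec = pos[dst] - pos[src] + shift` convention below is the usual MACE one and is an assumption here; the actual ordering lives in `hydragnn.utils.model.operations`.

```python
import torch

pos = torch.tensor([[0.0, 0.0, 0.0], [0.9, 0.0, 0.0]])
edge_index = torch.tensor([[0], [1]])               # a single edge 0 -> 1
zero_shifts = torch.zeros((edge_index.size(1), 3))  # the default Base._embedding now sets
pbc_shifts = torch.tensor([[-1.0, 0.0, 0.0]])       # edge wrapping across a unit cell in x

vec_open = pos[edge_index[1]] - pos[edge_index[0]] + zero_shifts  # (0.9, 0, 0): plain difference
vec_pbc = pos[edge_index[1]] - pos[edge_index[0]] + pbc_shifts    # (-0.1, 0, 0): periodic image
print(vec_open.norm(dim=-1), vec_pbc.norm(dim=-1))                # tensor([0.9000]) tensor([0.1000])
```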
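Finally, the new `data.batch is None` branch in `MACEStack._embedding` performs the same per-graph mean-centering in both paths; centering matters because the first spherical-harmonic projection consumes positions directly. A quick check of the batched path:

```python
import torch
from torch_scatter import scatter

pos = torch.tensor([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0],
                    [10.0, 0.0, 0.0], [12.0, 0.0, 0.0]])
batch = torch.tensor([0, 0, 1, 1])                    # two graphs, two nodes each
mean_pos = scatter(pos, batch, dim=0, reduce="mean")  # per-graph centroids
centered = pos - mean_pos[batch]                      # each graph now centered at the origin
print(centered[:, 0])                                 # tensor([-1., 1., -1., 1.])
```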