diff --git a/hydragnn/models/Base.py b/hydragnn/models/Base.py index 186d0209..dfbe5f35 100644 --- a/hydragnn/models/Base.py +++ b/hydragnn/models/Base.py @@ -27,6 +27,8 @@ class Base(Module): def __init__( self, + input_args: str, + conv_args: str, input_dim: int, hidden_dim: int, output_dim: list, @@ -46,6 +48,8 @@ def __init__( ): super().__init__() self.device = get_device() + self.input_args = input_args + self.conv_args = conv_args self.input_dim = input_dim self.hidden_dim = hidden_dim self.dropout = dropout @@ -104,6 +108,10 @@ def __init__( and self.edge_dim > 0 ): self.use_edge_attr = True + if "edge_attr" not in self.input_args: + self.input_args += ", edge_attr" + if "edge_attr" not in self.conv_args: + self.conv_args += ", edge_attr" # Option to only train final property layers. self.freeze_conv = freeze_conv @@ -127,14 +135,14 @@ def _init_conv(self): self.graph_convs.append(conv) self.feature_layers.append(BatchNorm(self.hidden_dim)) - def _conv_args(self, data): + def _embedding(self, data): conv_args = {"edge_index": data.edge_index.to(torch.long)} if self.use_edge_attr: assert ( data.edge_attr is not None ), "Data must have edge attributes if use_edge_attributes is set." conv_args.update({"edge_attr": data.edge_attr}) - return conv_args + return data.x, data.pos, conv_args def _freeze_conv(self): for module in [self.graph_convs, self.feature_layers]: @@ -301,19 +309,27 @@ def enable_conv_checkpointing(self): self.conv_checkpointing = True def forward(self, data): - x = data.x - pos = data.pos - ### encoder part #### - conv_args = self._conv_args(data) + inv_node_feat, equiv_node_feat, conv_args = self._embedding(data) + for conv, feat_layer in zip(self.graph_convs, self.feature_layers): if not self.conv_checkpointing: - c, pos = conv(x=x, pos=pos, **conv_args) + inv_node_feat, equiv_node_feat = conv( + inv_node_feat=inv_node_feat, + equiv_node_feat=equiv_node_feat, + **conv_args + ) else: - c, pos = checkpoint( - conv, use_reentrant=False, x=x, pos=pos, **conv_args + inv_node_feat, equiv_node_feat = checkpoint( + conv, + use_reentrant=False, + inv_node_feat=inv_node_feat, + equiv_node_feat=equiv_node_feat, + **conv_args ) - x = self.activation_function(feat_layer(c)) + inv_node_feat = self.activation_function(feat_layer(inv_node_feat)) + + x = inv_node_feat #### multi-head decoder part#### # shared dense layers for graph level output @@ -333,11 +349,17 @@ def forward(self, data): outputs_var.append(output_head[:, head_dim:] ** 2) else: if self.node_NN_type == "conv": + inv_node_feat = x for conv, batch_norm in zip(headloc[0::2], headloc[1::2]): - c, pos = conv(x=x, pos=pos, **conv_args) - c = batch_norm(c) - x = self.activation_function(c) - x_node = x + inv_node_feat, equiv_node_feat = conv( + inv_node_feat=inv_node_feat, + equiv_node_feat=equiv_node_feat, + **conv_args + ) + inv_node_feat = batch_norm(inv_node_feat) + inv_node_feat = self.activation_function(inv_node_feat) + x_node = inv_node_feat + x = inv_node_feat else: x_node = headloc(x=x, batch=data.batch) outputs.append(x_node[:, :head_dim]) diff --git a/hydragnn/models/CGCNNStack.py b/hydragnn/models/CGCNNStack.py index 69e08f21..35f3d106 100644 --- a/hydragnn/models/CGCNNStack.py +++ b/hydragnn/models/CGCNNStack.py @@ -19,6 +19,8 @@ class CGCNNStack(Base): def __init__( self, + input_args, + conv_args, edge_dim: int, input_dim, output_dim, @@ -32,6 +34,8 @@ def __init__( # also as hidden dimension (second argument of base constructor) # We therefore pass all required args explicitly. 
        super().__init__(
+            input_args,
+            conv_args,
            input_dim,
            input_dim,
            output_dim,
@@ -39,6 +43,16 @@ def __init__(
            **kwargs,
        )

+        if self.use_edge_attr:
+            assert (
+                self.input_args
+                == "inv_node_feat, equiv_node_feat, edge_index, edge_attr"
+            )
+            assert self.conv_args == "inv_node_feat, edge_index, edge_attr"
+        else:
+            assert self.input_args == "inv_node_feat, equiv_node_feat, edge_index"
+            assert self.conv_args == "inv_node_feat, edge_index"
+
    def get_conv(self, input_dim, _):
        cgcnn = CGConv(
            channels=input_dim,
@@ -48,18 +62,17 @@ def get_conv(self, input_dim, _):
            bias=True,
        )

-        input_args = "x, pos, edge_index"
-        conv_args = "x, edge_index"
-
-        if self.use_edge_attr:
-            input_args += ", edge_attr"
-            conv_args += ", edge_attr"
-
        return Sequential(
-            input_args,
+            self.input_args,
            [
-                (cgcnn, conv_args + " -> x"),
-                (lambda x, pos: [x, pos], "x, pos -> x, pos"),
+                (cgcnn, self.conv_args + " -> inv_node_feat"),
+                (
+                    lambda inv_node_feat, equiv_node_feat: [
+                        inv_node_feat,
+                        equiv_node_feat,
+                    ],
+                    "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
+                ),
            ],
        )
diff --git a/hydragnn/models/DIMEStack.py b/hydragnn/models/DIMEStack.py
index ab7f9edb..715dd565 100644
--- a/hydragnn/models/DIMEStack.py
+++ b/hydragnn/models/DIMEStack.py
@@ -36,6 +36,8 @@ class DIMEStack(Base):

    def __init__(
        self,
+        input_args,
+        conv_args,
        basis_emb_size,
        envelope_exponent,
        int_emb_size,
@@ -60,7 +62,7 @@ def __init__(
        self.edge_dim = edge_dim
        self.radius = radius

-        super().__init__(*args, **kwargs)
+        super().__init__(input_args, conv_args, *args, **kwargs)

        self.rbf = BesselBasisLayer(num_radial, radius, envelope_exponent)
        self.sbf = SphericalBasisLayer(
@@ -111,28 +113,37 @@ def get_conv(self, input_dim, output_dim):

        if self.use_edge_attr:
            return Sequential(
-                "x, pos, rbf, edge_attr, sbf, i, j, idx_kj, idx_ji",
+                self.input_args,
                [
-                    (lin, "x -> x"),
-                    (emb, "x, rbf, i, j, edge_attr -> x1"),
+                    (lin, "inv_node_feat -> inv_node_feat"),
+                    (emb, "inv_node_feat, rbf, i, j, edge_attr -> x1"),
                    (inter, "x1, rbf, sbf, idx_kj, idx_ji -> x2"),
-                    (dec, "x2, rbf, i -> c"),
-                    (lambda x, pos: [x, pos], "c, pos -> c, pos"),
+                    (dec, "x2, rbf, i -> inv_node_feat"),
+                    (
+                        lambda inv_node_feat, equiv_node_feat: [
+                            inv_node_feat,
+                            equiv_node_feat,
+                        ],
+                        "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
+                    ),
                ],
            )
        else:
            return Sequential(
-                "x, pos, rbf, sbf, i, j, idx_kj, idx_ji",
+                self.input_args,
                [
-                    (lin, "x -> x"),
-                    (emb, "x, rbf, i, j -> x1"),
+                    (lin, "inv_node_feat -> inv_node_feat"),
+                    (emb, "inv_node_feat, rbf, i, j -> x1"),
                    (inter, "x1, rbf, sbf, idx_kj, idx_ji -> x2"),
-                    (dec, "x2, rbf, i -> c"),
-                    (lambda x, pos: [x, pos], "c, pos -> c, pos"),
+                    (dec, "x2, rbf, i -> inv_node_feat"),
+                    (
+                        lambda inv_node_feat, equiv_node_feat: [inv_node_feat, equiv_node_feat],
+                        "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
+                    ),
                ],
            )

-    def _conv_args(self, data):
+    def _embedding(self, data):
        assert (
            data.pos is not None
        ), "DimeNet requires node positions (data.pos) to be set."
@@ -166,7 +177,7 @@ def _conv_args(self, data):
            ), "Data must have edge attributes if use_edge_attributes is set."
conv_args.update({"edge_attr": data.edge_attr}) - return conv_args + return data.x, data.pos, conv_args """ diff --git a/hydragnn/models/EGCLStack.py b/hydragnn/models/EGCLStack.py index 7109d0fc..3a9a7492 100644 --- a/hydragnn/models/EGCLStack.py +++ b/hydragnn/models/EGCLStack.py @@ -21,6 +21,8 @@ class EGCLStack(Base): def __init__( self, + input_args, + conv_args, edge_attr_dim: int, *args, max_neighbours: Optional[int] = None, @@ -30,7 +32,11 @@ def __init__( self.edge_dim = ( 0 if edge_attr_dim is None else edge_attr_dim ) # Must be named edge_dim to trigger use by Base - super().__init__(*args, **kwargs) + super().__init__(input_args, conv_args, *args, **kwargs) + + assert ( + self.input_args == "inv_node_feat, equiv_node_feat, edge_index, edge_attr" + ) pass def _init_conv(self): @@ -56,21 +62,33 @@ def get_conv(self, input_dim, output_dim, last_layer=False): if self.equivariance and not last_layer: return Sequential( - "x, pos, edge_index, edge_attr", + self.input_args, [ - (egcl, "x, pos, edge_index, edge_attr -> x, pos"), + ( + egcl, + "inv_node_feat, equiv_node_feat, edge_index, edge_attr -> inv_node_feat, equiv_node_feat", + ), ], ) else: return Sequential( - "x, pos, edge_index, edge_attr", + self.input_args, [ - (egcl, "x, pos, edge_index, edge_attr -> x"), - (lambda x, pos: [x, pos], "x, pos -> x, pos"), + ( + egcl, + "inv_node_feat, equiv_node_feat, edge_index, edge_attr -> inv_node_feat", + ), + ( + lambda inv_node_feat, equiv_node_feat: [ + inv_node_feat, + equiv_node_feat, + ], + "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat", + ), ], ) - def _conv_args(self, data): + def _embedding(self, data): if self.edge_dim > 0: conv_args = { "edge_index": data.edge_index, @@ -82,7 +100,7 @@ def _conv_args(self, data): "edge_attr": None, } - return conv_args + return data.x, data.pos, conv_args def __str__(self): return "EGCLStack" diff --git a/hydragnn/models/GATStack.py b/hydragnn/models/GATStack.py index f29696b2..adeb0ad7 100644 --- a/hydragnn/models/GATStack.py +++ b/hydragnn/models/GATStack.py @@ -21,6 +21,8 @@ class GATStack(Base): def __init__( self, + input_args, + conv_args, heads: int, negative_slope: float, *args, @@ -30,7 +32,7 @@ def __init__( self.heads = heads self.negative_slope = negative_slope - super().__init__(*args, **kwargs) + super().__init__(input_args, conv_args, *args, **kwargs) def _init_conv(self): """Here this function overwrites _init_conv() in Base since it has different implementation @@ -99,18 +101,17 @@ def get_conv(self, input_dim, output_dim, concat): concat=concat, ) - input_args = "x, pos, edge_index" - conv_args = "x, edge_index" - - if self.use_edge_attr: - input_args += ", edge_attr" - conv_args += ", edge_attr" - return Sequential( - input_args, + self.input_args, [ - (gat, conv_args + " -> x"), - (lambda x, pos: [x, pos], "x, pos -> x, pos"), + (gat, self.conv_args + " -> inv_node_feat"), + ( + lambda inv_node_feat, equiv_node_feat: [ + inv_node_feat, + equiv_node_feat, + ], + "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat", + ), ], ) diff --git a/hydragnn/models/GINStack.py b/hydragnn/models/GINStack.py index fda18bad..193f46c0 100644 --- a/hydragnn/models/GINStack.py +++ b/hydragnn/models/GINStack.py @@ -33,14 +33,14 @@ def get_conv(self, input_dim, output_dim): train_eps=True, ) - input_args = "x, pos, edge_index" - conv_args = "x, edge_index" - return Sequential( - input_args, + self.input_args, [ - (gin, conv_args + " -> x"), - (lambda x, pos: [x, pos], "x, pos -> x, pos"), + (gin, 
self.conv_args + " -> inv_node_feat"),
+            (
+                lambda inv_node_feat, equiv_node_feat: [inv_node_feat, equiv_node_feat],
+                "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
+            ),
        ],
    )
diff --git a/hydragnn/models/MACEStack.py b/hydragnn/models/MACEStack.py
index d61696a6..66eed27e 100644
--- a/hydragnn/models/MACEStack.py
+++ b/hydragnn/models/MACEStack.py
@@ -70,6 +70,8 @@ class MACEStack(Base):

    def __init__(
        self,
+        input_args,
+        conv_args,
        r_max: float,  # The cutoff radius for the radial basis functions and edge_index
        radial_type: str,  # The type of radial basis function to use
        distance_transform: str,  # The distance transform to use
@@ -126,7 +128,7 @@ def __init__(
        )  # This makes the irreps string
        self.edge_feats_irreps = o3.Irreps(f"{num_bessel}x0e")

-        super().__init__(*args, **kwargs)
+        super().__init__(input_args, conv_args, *args, **kwargs)

        self.spherical_harmonics = o3.SphericalHarmonics(
            self.sh_irreps,
@@ -264,6 +266,7 @@ def get_conv(self, input_dim, output_dim, first_layer=False, last_layer=False):
        # Constructing convolutional layers
        if first_layer:
            hidden_irreps_out = hidden_irreps
+            combine = CombineBlock()
            inter = self.interaction_cls_first(
                node_attrs_irreps=self.node_attr_irreps,
                node_feats_irreps=node_feats_irreps,
@@ -288,10 +291,12 @@
            sizing = o3.Linear(
                hidden_irreps_out, output_irreps
            )  # Change sizing to output_irreps
+            split = SplitBlock(hidden_irreps)
        elif last_layer:
            # Select only scalars output for last layer
            hidden_irreps_out = str(hidden_irreps[0])
            output_irreps = str(output_irreps[0])
+            combine = CombineBlock()
            inter = self.interaction_cls(
                node_attrs_irreps=self.node_attr_irreps,
                node_feats_irreps=hidden_irreps,
@@ -312,8 +317,10 @@
            sizing = o3.Linear(
                hidden_irreps_out, output_irreps
            )  # Change sizing to output_irreps
+            split = SplitBlock(hidden_irreps)
        else:
            hidden_irreps_out = hidden_irreps
+            combine = CombineBlock()
            inter = self.interaction_cls(
                node_attrs_irreps=self.node_attr_irreps,
                node_feats_irreps=hidden_irreps,
@@ -334,71 +341,81 @@
            sizing = o3.Linear(
                hidden_irreps_out, output_irreps
            )  # Change sizing to output_irreps
-
-        input_args = "node_attributes, pos, node_features, edge_attributes, edge_features, edge_index"
-        conv_args = "node_attributes, edge_attributes, edge_features, edge_index"  # node_features is not used here because it's passed through in the forward
+            split = SplitBlock(hidden_irreps)

        if not last_layer:
            return PyGSequential(
-                input_args,
+                self.input_args,
                [
-                    (inter, "node_features, " + conv_args + " -> node_features, sc"),
+                    (combine, "inv_node_feat, equiv_node_feat -> node_features"),
+                    (
+                        inter,
+                        "node_features, " + self.conv_args + " -> node_features, sc",
+                    ),
                    (prod, "node_features, sc, node_attributes -> node_features"),
                    (sizing, "node_features -> node_features"),
                    (
-                        lambda node_features, pos: [node_features, pos],
-                        "node_features, pos -> node_features, pos",
+                        lambda node_features, equiv_node_feat: [
+                            node_features,
+                            equiv_node_feat,
+                        ],
+                        "node_features, equiv_node_feat -> node_features, equiv_node_feat",
                    ),
+                    (split, "node_features -> inv_node_feat, equiv_node_feat"),
                ],
            )
        else:
            return PyGSequential(
-                input_args,
+                self.input_args,
                [
-                    (inter, "node_features, " + conv_args + " -> node_features, sc"),
+                    (combine, "inv_node_feat, equiv_node_feat -> node_features"),
+                    (
+                        inter,
"node_features, " + self.conv_args + " -> node_features, sc", + ), (prod, "node_features, sc, node_attributes -> node_features"), (sizing, "node_features -> node_features"), ( - lambda node_features, pos: [node_features, pos], - "node_features, pos -> node_features, pos", + lambda node_features, equiv_node_feat: [ + node_features, + equiv_node_feat, + ], + "node_features, equiv_node_feat -> node_features, equiv_node_feat", ), + (split, "node_features -> inv_node_feat, equiv_node_feat"), ], ) def forward(self, data): - data, conv_args = self._conv_args(data) - node_features = data.node_features - node_attributes = data.node_attributes - pos = data.pos + inv_node_feat, equiv_node_feat, conv_args = self._embedding(data) - ### encoder / decoder part #### - ## NOTE Norm techniques (feature_layers in HYDRA) are not advised for use in equivariant models as it can break equivariance - - ### There is a readout before the first convolution layer ### - outputs = [] + ### MACE has a readout block before convolutions ### output = self.multihead_decoders[0]( - data, node_attributes + data, data.node_attributes ) # [index][n_output, size_output] - # Create outputs first outputs = output ### Do conv --> readout --> repeat for each convolution layer ### for conv, readout in zip(self.graph_convs, self.multihead_decoders[1:]): if not self.conv_checkpointing: - node_features, pos = conv( - node_features=node_features, pos=pos, **conv_args + inv_node_feat, equiv_node_feat = conv( + inv_node_feat=inv_node_feat, + equiv_node_feat=equiv_node_feat, + **conv_args, ) - output = readout(data, node_features) # [index][n_output, size_output] + output = readout( + data, torch.cat([inv_node_feat, equiv_node_feat], dim=1) + ) # [index][n_output, size_output] else: - node_features, pos = checkpoint( + inv_node_feat, equiv_node_feat = checkpoint( conv, use_reentrant=False, - node_features=node_features, - pos=pos, + inv_node_feat=inv_node_feat, + equiv_node_feat=equiv_node_feat, **conv_args, ) output = readout( - data, node_features + data, torch.cat([inv_node_feat, equiv_node_feat], dim=1) ) # output is a list of tensors with [index][n_output, size_output] # Sum predictions for each index, taking care of size differences for idx, prediction in enumerate(output): @@ -406,7 +423,7 @@ def forward(self, data): return outputs - def _conv_args(self, data): + def _embedding(self, data): assert ( data.pos is not None ), "MACE requires node positions (data.pos) to be set." 
@@ -452,7 +469,11 @@ def _conv_args(self, data):
            "edge_index": data.edge_index,
        }

-        return data, conv_args
+        return (
+            data.node_features[:, : self.hidden_dim],
+            data.node_features[:, self.hidden_dim :],
+            conv_args,
+        )

    def _multihead(self):
        # NOTE Multihead is skipped as it's an integral part of MACE's architecture to have a decoder after every layer,
@@ -505,6 +526,25 @@ def process_node_attributes(node_attributes, num_elements):
    return one_hot


+@compile_mode("script")
+class CombineBlock(torch.nn.Module):
+    def __init__(self):
+        super(CombineBlock, self).__init__()
+
+    def forward(self, inv_node_features, equiv_node_features):
+        return torch.cat([inv_node_features, equiv_node_features], dim=1)
+
+
+@compile_mode("script")
+class SplitBlock(torch.nn.Module):
+    def __init__(self, irreps):
+        super(SplitBlock, self).__init__()
+        self.dim = irreps.count(o3.Irrep(0, 1))
+
+    def forward(self, node_features):
+        return node_features[:, : self.dim], node_features[:, self.dim :]
+
+
@compile_mode("script")
class MultiheadDecoderBlock(torch.nn.Module):
    def __init__(
diff --git a/hydragnn/models/MFCStack.py b/hydragnn/models/MFCStack.py
index dc2b78c6..963010b8 100644
--- a/hydragnn/models/MFCStack.py
+++ b/hydragnn/models/MFCStack.py
@@ -21,13 +21,15 @@ class MFCStack(Base):

    def __init__(
        self,
+        input_args,
+        conv_args,
        max_degree: int,
        *args,
        **kwargs,
    ):
        self.max_degree = max_degree

-        super().__init__(*args, **kwargs)
+        super().__init__(input_args, conv_args, *args, **kwargs)

    def get_conv(self, input_dim, output_dim):
        mfc = MFConv(
@@ -36,14 +38,14 @@ def get_conv(self, input_dim, output_dim):
            max_degree=self.max_degree,
        )

-        input_args = "x, pos, edge_index"
-        conv_args = "x, edge_index"
-
        return Sequential(
-            input_args,
+            self.input_args,
            [
-                (mfc, conv_args + " -> x"),
-                (lambda x, pos: [x, pos], "x, pos -> x, pos"),
+                (mfc, self.conv_args + " -> inv_node_feat"),
+                (
+                    lambda inv_node_feat, equiv_node_feat: [inv_node_feat, equiv_node_feat],
+                    "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
+                ),
            ],
        )
diff --git a/hydragnn/models/PAINNStack.py b/hydragnn/models/PAINNStack.py
index 11f06246..01cb4381 100644
--- a/hydragnn/models/PAINNStack.py
+++ b/hydragnn/models/PAINNStack.py
@@ -31,6 +31,8 @@ class PAINNStack(Base):

    def __init__(
        self,
        # edge_dim: int,  # To-Do: Add edge_features
+        input_args,
+        conv_args,
        num_radial: int,
        radius: float,
        *args,
@@ -40,13 +42,11 @@ def __init__(
        self.num_radial = num_radial
        self.radius = radius

-        super().__init__(*args, **kwargs)
+        super().__init__(input_args, conv_args, *args, **kwargs)

    def _init_conv(self):
        last_layer = 1 == self.num_conv_layers
-        self.graph_convs.append(
-            self.get_conv(self.input_dim, self.hidden_dim, last_layer)
-        )
+        self.graph_convs.append(self.get_conv(self.input_dim, self.hidden_dim))
        self.feature_layers.append(nn.Identity())
        for i in range(self.num_conv_layers - 1):
            last_layer = i == self.num_conv_layers - 2
@@ -64,8 +64,8 @@ def get_conv(self, input_dim, output_dim, last_layer=False):
        )
        cross_inter = PainnUpdate(node_size=input_dim, last_layer=last_layer)
        """
-        The following linear layers are to get the correct sizing of embeddings. This is 
-        necessary to use the hidden_dim, output_dim of HYDRAGNN's stacked conv layers correctly 
+        The following linear layers are to get the correct sizing of embeddings. This is
+        necessary to use the hidden_dim, output_dim of HYDRAGNN's stacked conv layers correctly
        because node_scalar and node-vector are updated through a sum.
""" node_embed_out = nn.Sequential( @@ -77,82 +77,54 @@ def get_conv(self, input_dim, output_dim, last_layer=False): if not last_layer: return geom_nn.Sequential( - "x, v, pos, edge_index, diff, dist", + self.input_args, [ - (self_inter, "x, v, edge_index, diff, dist -> x, v"), - (cross_inter, "x, v -> x, v"), - (node_embed_out, "x -> x"), - (vec_embed_out, "v -> v"), - (lambda x, v, pos: [x, v, pos], "x, v, pos -> x, v, pos"), + ( + self_inter, + "inv_node_feat, equiv_node_feat, edge_index, diff, dist -> inv_node_feat, equiv_node_feat", + ), + ( + cross_inter, + "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat", + ), + (node_embed_out, "inv_node_feat -> inv_node_feat"), + (vec_embed_out, "equiv_node_feat -> equiv_node_feat"), + ( + lambda inv_node_feat, equiv_node_feat: [ + inv_node_feat, + equiv_node_feat, + ], + "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat", + ), ], ) else: return geom_nn.Sequential( - "x, v, pos, edge_index, diff, dist", + self.input_args, [ - (self_inter, "x, v, edge_index, diff, dist -> x, v"), + ( + self_inter, + "inv_node_feat, equiv_node_feat, edge_index, diff, dist -> inv_node_feat, equiv_node_feat", + ), ( cross_inter, - "x, v -> x", + "inv_node_feat, equiv_node_feat -> inv_node_feat", ), # v is not updated in the last layer to avoid hanging gradients ( node_embed_out, - "x -> x", + "inv_node_feat -> inv_node_feat", ), # No need to embed down v because it's not used anymore - (lambda x, v, pos: [x, v, pos], "x, v, pos -> x, v, pos"), + ( + lambda inv_node_feat, equiv_node_feat: [ + inv_node_feat, + equiv_node_feat, + ], + "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat", + ), ], ) - def forward(self, data): - data, conv_args = self._conv_args( - data - ) # Added v to data here (necessary for PAINN Stack) - x = data.x - v = data.v - pos = data.pos - - ### encoder part #### - for conv, feat_layer in zip(self.graph_convs, self.feature_layers): - if not self.conv_checkpointing: - c, v, pos = conv(x=x, v=v, pos=pos, **conv_args) # Added v here - else: - c, v, pos = checkpoint( # Added v here - conv, use_reentrant=False, x=x, v=v, pos=pos, **conv_args - ) - x = self.activation_function(feat_layer(c)) - - #### multi-head decoder part#### - # shared dense layers for graph level output - if data.batch is None: - x_graph = x.mean(dim=0, keepdim=True) - else: - x_graph = geom_nn.global_mean_pool(x, data.batch.to(x.device)) - outputs = [] - outputs_var = [] - for head_dim, headloc, type_head in zip( - self.head_dims, self.heads_NN, self.head_type - ): - if type_head == "graph": - x_graph_head = self.graph_shared(x_graph) - output_head = headloc(x_graph_head) - outputs.append(output_head[:, :head_dim]) - outputs_var.append(output_head[:, head_dim:] ** 2) - else: - if self.node_NN_type == "conv": - for conv, batch_norm in zip(headloc[0::2], headloc[1::2]): - c, v, pos = conv(x=x, v=v, pos=pos, **conv_args) - c = batch_norm(c) - x = self.activation_function(c) - x_node = x - else: - x_node = headloc(x=x, batch=data.batch) - outputs.append(x_node[:, :head_dim]) - outputs_var.append(x_node[:, head_dim:] ** 2) - if self.var_output: - return outputs, outputs_var - return outputs - - def _conv_args(self, data): + def _embedding(self, data): assert ( data.pos is not None ), "PAINNNet requires node positions (data.pos) to be set." 
@@ -173,7 +145,7 @@ def _conv_args(self, data): "dist": dist, } - return data, conv_args + return data.x, data.v, conv_args class PainnMessage(nn.Module): diff --git a/hydragnn/models/PNAEqStack.py b/hydragnn/models/PNAEqStack.py index 3eee774f..22946e14 100644 --- a/hydragnn/models/PNAEqStack.py +++ b/hydragnn/models/PNAEqStack.py @@ -42,7 +42,15 @@ class PNAEqStack(Base): """ def __init__( - self, deg: list, edge_dim: int, num_radial: int, radius: float, *args, **kwargs + self, + input_args, + conv_args, + deg: list, + edge_dim: int, + num_radial: int, + radius: float, + *args, + **kwargs, ): self.x_aggregators = ["mean", "min", "max", "std"] @@ -58,7 +66,7 @@ def __init__( self.num_radial = num_radial self.radius = radius - super().__init__(*args, **kwargs) + super().__init__(input_args, conv_args, *args, **kwargs) self.rbf = rbf_BasisLayer(self.num_radial, self.radius) @@ -90,8 +98,8 @@ def get_conv(self, input_dim, output_dim, last_layer=False): ) update = PainnUpdate(node_size=input_dim, last_layer=last_layer) """ - The following linear layers are to get the correct sizing of embeddings. This is - necessary to use the hidden_dim, output_dim of HYDRAGNN's stacked conv layers correctly + The following linear layers are to get the correct sizing of embeddings. This is + necessary to use the hidden_dim, output_dim of HYDRAGNN's stacked conv layers correctly because node_scalar and node-vector are updated through an additive skip connection. """ # Embed down to output size @@ -104,91 +112,50 @@ def get_conv(self, input_dim, output_dim, last_layer=False): geom_nn.Linear(input_dim, output_dim) if not last_layer else None ) - input_args = "x, v, pos, edge_index, edge_rbf, edge_vec" - conv_args = "x, v, edge_index, edge_rbf, edge_vec" - - if self.use_edge_attr: - input_args += ", edge_attr" - conv_args += ", edge_attr" - if not last_layer: return geom_nn.Sequential( - input_args, + self.input_args, [ - (message, conv_args + " -> x, v"), - (update, "x, v -> x, v"), - (node_embed_out, "x -> x"), - (vec_embed_out, "v -> v"), - (lambda x, v, pos: [x, v, pos], "x, v, pos -> x, v, pos"), + (message, self.conv_args + " -> inv_node_feat, equiv_node_feat"), + ( + update, + "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat", + ), + (node_embed_out, "inv_node_feat -> inv_node_feat"), + (vec_embed_out, "equiv_node_feat -> equiv_node_feat"), + ( + lambda inv_node_feat, equiv_node_feat: [ + inv_node_feat, + equiv_node_feat, + ], + "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat", + ), ], ) else: return geom_nn.Sequential( - input_args, + self.input_args, [ - (message, conv_args + " -> x, v"), + (message, self.conv_args + " -> inv_node_feat, equiv_node_feat"), ( update, - "x, v -> x", + "inv_node_feat, equiv_node_feat -> inv_node_feat", ), # v is not updated in the last layer to avoid hanging gradients ( node_embed_out, - "x -> x", + "inv_node_feat -> inv_node_feat", ), # No need to embed down v because it's not used anymore - (lambda x, v, pos: [x, v, pos], "x, v, pos -> x, v, pos"), + ( + lambda inv_node_feat, equiv_node_feat: [ + inv_node_feat, + equiv_node_feat, + ], + "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat", + ), ], ) - def forward(self, data): - data, conv_args = self._conv_args( - data - ) # Added v to data here (necessary for PNAEq Stack) - x = data.x - v = data.v - pos = data.pos - - ### encoder part #### - for conv, feat_layer in zip(self.graph_convs, self.feature_layers): - if not self.conv_checkpointing: - c, v, pos = conv(x=x, v=v, 
pos=pos, **conv_args) # Added v here - else: - c, v, pos = checkpoint( # Added v here - conv, use_reentrant=False, x=x, v=v, pos=pos, **conv_args - ) - x = self.activation_function(feat_layer(c)) - - #### multi-head decoder part#### - # shared dense layers for graph level output - if data.batch is None: - x_graph = x.mean(dim=0, keepdim=True) - else: - x_graph = geom_nn.global_mean_pool(x, data.batch.to(x.device)) - outputs = [] - outputs_var = [] - for head_dim, headloc, type_head in zip( - self.head_dims, self.heads_NN, self.head_type - ): - if type_head == "graph": - x_graph_head = self.graph_shared(x_graph) - output_head = headloc(x_graph_head) - outputs.append(output_head[:, :head_dim]) - outputs_var.append(output_head[:, head_dim:] ** 2) - else: - if self.node_NN_type == "conv": - for conv, batch_norm in zip(headloc[0::2], headloc[1::2]): - c, v, pos = conv(x=x, v=v, pos=pos, **conv_args) - c = batch_norm(c) - x = self.activation_function(c) - x_node = x - else: - x_node = headloc(x=x, batch=data.batch) - outputs.append(x_node[:, :head_dim]) - outputs_var.append(x_node[:, head_dim:] ** 2) - if self.var_output: - return outputs, outputs_var - return outputs - - def _conv_args(self, data): + def _embedding(self, data): assert ( data.pos is not None ), "PNAEq requires node positions (data.pos) to be set." @@ -210,7 +177,7 @@ def _conv_args(self, data): "edge_vec": norm_diff, } - return data, conv_args + return data.x, data.v, conv_args class PainnMessage(MessagePassing): diff --git a/hydragnn/models/PNAPlusStack.py b/hydragnn/models/PNAPlusStack.py index 06561d6d..0a104ba0 100644 --- a/hydragnn/models/PNAPlusStack.py +++ b/hydragnn/models/PNAPlusStack.py @@ -39,6 +39,8 @@ class PNAPlusStack(Base): def __init__( self, + input_args, + conv_args, deg: list, edge_dim: int, envelope_exponent: int, @@ -61,7 +63,7 @@ def __init__( self.num_radial = num_radial self.radius = radius - super().__init__(*args, **kwargs) + super().__init__(input_args, conv_args, *args, **kwargs) self.rbf = BesselBasisLayer( self.num_radial, self.radius, self.envelope_exponent @@ -81,22 +83,21 @@ def get_conv(self, input_dim, output_dim): divide_input=False, ) - input_args = "x, pos, edge_index, rbf" - conv_args = "x, edge_index, rbf" - - if self.use_edge_attr: - input_args += ", edge_attr" - conv_args += ", edge_attr" - return PyGSequential( - input_args, + self.input_args, [ - (pna, conv_args + " -> x"), - (lambda x, pos: [x, pos], "x, pos -> x, pos"), + (pna, self.conv_args + " -> inv_node_feat"), + ( + lambda inv_node_feat, equiv_node_feat: [ + inv_node_feat, + equiv_node_feat, + ], + "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat", + ), ], ) - def _conv_args(self, data): + def _embedding(self, data): assert ( data.pos is not None ), "PNA+ requires node positions (data.pos) to be set." @@ -113,7 +114,7 @@ def _conv_args(self, data): ), "Data must have edge attributes if use_edge_attributes is set." 
conv_args.update({"edge_attr": data.edge_attr}) - return conv_args + return data.x, data.pos, conv_args def __str__(self): return "PNAStack" diff --git a/hydragnn/models/PNAStack.py b/hydragnn/models/PNAStack.py index 8f8e98b3..fac7570a 100644 --- a/hydragnn/models/PNAStack.py +++ b/hydragnn/models/PNAStack.py @@ -19,6 +19,8 @@ class PNAStack(Base): def __init__( self, + input_args, + conv_args, deg: list, edge_dim: int, *args, @@ -35,7 +37,7 @@ def __init__( self.deg = torch.Tensor(deg) self.edge_dim = edge_dim - super().__init__(*args, **kwargs) + super().__init__(input_args, conv_args, *args, **kwargs) def get_conv(self, input_dim, output_dim): pna = PNAConv( @@ -50,18 +52,17 @@ def get_conv(self, input_dim, output_dim): divide_input=False, ) - input_args = "x, pos, edge_index" - conv_args = "x, edge_index" - - if self.use_edge_attr: - input_args += ", edge_attr" - conv_args += ", edge_attr" - return Sequential( - input_args, + self.input_args, [ - (pna, conv_args + " -> x"), - (lambda x, pos: [x, pos], "x, pos -> x, pos"), + (pna, self.conv_args + " -> inv_node_feat"), + ( + lambda inv_node_feat, equiv_node_feat: [ + inv_node_feat, + equiv_node_feat, + ], + "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat", + ), ], ) diff --git a/hydragnn/models/SAGEStack.py b/hydragnn/models/SAGEStack.py index 79a99a43..40c150ed 100644 --- a/hydragnn/models/SAGEStack.py +++ b/hydragnn/models/SAGEStack.py @@ -28,14 +28,17 @@ def get_conv(self, input_dim, output_dim): out_channels=output_dim, ) - input_args = "x, pos, edge_index" - conv_args = "x, edge_index" - return Sequential( - input_args, + self.input_args, [ - (sage, conv_args + " -> x"), - (lambda x, pos: [x, pos], "x, pos -> x, pos"), + (sage, self.conv_args + " -> inv_node_feat"), + ( + lambda inv_node_feat, equiv_node_feat: [ + inv_node_feat, + equiv_node_feat, + ], + "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat", + ), ], ) diff --git a/hydragnn/models/SCFStack.py b/hydragnn/models/SCFStack.py index 4f66ae6b..c132c8ec 100644 --- a/hydragnn/models/SCFStack.py +++ b/hydragnn/models/SCFStack.py @@ -32,6 +32,8 @@ class SCFStack(Base): def __init__( self, + input_args, + conv_args, num_filters: int, num_gaussians: list, radius: float, @@ -44,7 +46,7 @@ def __init__( self.num_filters = num_filters self.num_gaussians = num_gaussians - super().__init__(*args, **kwargs) + super().__init__(input_args, conv_args, *args, **kwargs) pass @@ -83,39 +85,56 @@ def get_conv(self, input_dim, output_dim, last_layer): equivariant=self.equivariance and not last_layer, ) - conv_args = "x, edge_index, edge_weight, edge_attr, pos" if self.use_edge_attr: - input_args = "x, pos, edge_index, edge_weight, edge_attr" return PyGSeq( - input_args, + self.input_args, [ - (interaction, conv_args + " -> x"), - (lambda x, pos: [x, pos], "x, pos -> x, pos"), + (interaction, self.conv_args + " -> inv_node_feat"), + ( + lambda inv_node_feat, equiv_node_feat: [ + inv_node_feat, + equiv_node_feat, + ], + "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat", + ), ], ) elif self.equivariance and not last_layer: - input_args = "x, pos, batch" return PyGSeq( - input_args, + self.input_args, [ - (self.interaction_graph, "pos, batch -> edge_index, edge_weight"), + ( + self.interaction_graph, + "equiv_node_feat, batch -> edge_index, edge_weight", + ), (self.distance_expansion, "edge_weight -> edge_attr"), - (interaction, conv_args + " -> x, pos"), + ( + interaction, + self.conv_args + " -> inv_node_feat, equiv_node_feat", + ), ], ) else: - 
input_args = "x, pos, batch" return PyGSeq( - input_args, + self.input_args, [ - (self.interaction_graph, "pos, batch -> edge_index, edge_weight"), + ( + self.interaction_graph, + "equiv_node_feat, batch -> edge_index, edge_weight", + ), (self.distance_expansion, "edge_weight -> edge_attr"), - (interaction, conv_args + " -> x"), - (lambda x, pos: [x, pos], "x, pos -> x, pos"), + (interaction, self.conv_args + " -> inv_node_feat"), + ( + lambda inv_node_feat, equiv_node_feat: [ + inv_node_feat, + equiv_node_feat, + ], + "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat", + ), ], ) - def _conv_args(self, data): + def _embedding(self, data): if (self.use_edge_attr) and (self.equivariance): raise Exception( "For SchNet if using edge attributes, then E(3)-equivariance cannot be ensured. Please disable equivariance or edge attributes." @@ -134,7 +153,7 @@ def _conv_args(self, data): "batch": data.batch, } - return conv_args + return data.x, data.pos, conv_args def __str__(self): return "SCFStack" @@ -188,10 +207,10 @@ def reset_parameters(self): def forward( self, x: Tensor, + pos: Tensor, edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor, - pos: Tensor, ) -> Tensor: C = 0.5 * (torch.cos(edge_weight * PI / self.cutoff) + 1.0) W = self.nn(edge_attr) * C.view(-1, 1) diff --git a/hydragnn/models/create.py b/hydragnn/models/create.py index 086bd069..7c6949d5 100644 --- a/hydragnn/models/create.py +++ b/hydragnn/models/create.py @@ -127,6 +127,8 @@ def create_model( # Note: model-specific inputs must come first. if model_type == "GIN": model = GINStack( + "inv_node_feat, equiv_node_feat, edge_index", + "inv_node_feat, edge_index", input_dim, hidden_dim, output_dim, @@ -145,6 +147,8 @@ def create_model( elif model_type == "PNA": assert pna_deg is not None, "PNA requires degree input." model = PNAStack( + "inv_node_feat, equiv_node_feat, edge_index", + "inv_node_feat, edge_index", pna_deg, edge_dim, input_dim, @@ -170,6 +174,8 @@ def create_model( assert num_radial is not None, "PNAPlus requires num_radial input." assert radius is not None, "PNAPlus requires radius input." model = PNAPlusStack( + "inv_node_feat, equiv_node_feat, edge_index, rbf", + "inv_node_feat, edge_index, rbf", pna_deg, edge_dim, envelope_exponent, @@ -195,6 +201,8 @@ def create_model( heads = 6 negative_slope = 0.05 model = GATStack( + "inv_node_feat, equiv_node_feat, edge_index", + "inv_node_feat, edge_index", heads, negative_slope, input_dim, @@ -215,6 +223,8 @@ def create_model( elif model_type == "MFC": assert max_neighbours is not None, "MFC requires max_neighbours input." model = MFCStack( + "inv_node_feat, equiv_node_feat, edge_index", + "inv_node_feat, edge_index", max_neighbours, input_dim, hidden_dim, @@ -233,6 +243,8 @@ def create_model( elif model_type == "CGCNN": model = CGCNNStack( + "inv_node_feat, equiv_node_feat, edge_index", # input_args + "inv_node_feat, edge_index", # conv_args edge_dim, input_dim, output_dim, @@ -250,6 +262,8 @@ def create_model( elif model_type == "SAGE": model = SAGEStack( + "inv_node_feat, equiv_node_feat, edge_index", # input_args + "inv_node_feat, edge_index", # conv_args input_dim, hidden_dim, output_dim, @@ -269,6 +283,8 @@ def create_model( assert num_filters is not None, "SchNet requires num_filters input." assert radius is not None, "SchNet requires radius input." 
model = SCFStack( + "inv_node_feat, equiv_node_feat, batch", + "inv_node_feat, equiv_node_feat, edge_index, edge_weight, edge_attr", num_gaussians, num_filters, radius, @@ -301,6 +317,8 @@ def create_model( assert num_spherical is not None, "DimeNet requires num_spherical input." assert radius is not None, "DimeNet requires radius input." model = DIMEStack( + "inv_node_feat, equiv_node_feat, rbf, sbf, i, j, idx_kj, idx_ji", # input_args + "", # conv_args basis_emb_size, envelope_exponent, int_emb_size, @@ -329,6 +347,8 @@ def create_model( elif model_type == "EGNN": model = EGCLStack( + "inv_node_feat, equiv_node_feat, edge_index, edge_attr", # input_args + "", # conv_args edge_dim, input_dim, hidden_dim, @@ -349,6 +369,8 @@ def create_model( elif model_type == "PAINN": model = PAINNStack( # edge_dim, # To-do add edge_features + "inv_node_feat, equiv_node_feat, edge_index, diff, dist", + "", num_radial, radius, input_dim, @@ -368,6 +390,8 @@ def create_model( elif model_type == "PNAEq": assert pna_deg is not None, "PNAEq requires degree input." model = PNAEqStack( + "inv_node_feat, equiv_node_feat, edge_index, edge_rbf, edge_vec", + "inv_node_feat, equiv_node_feat, edge_index, edge_rbf, edge_vec", pna_deg, edge_dim, num_radial, @@ -394,6 +418,8 @@ def create_model( assert max_ell >= 1, "MACE requires max_ell >= 1." assert node_max_ell >= 1, "MACE requires node_max_ell >= 1." model = MACEStack( + "node_attributes, equiv_node_feat, inv_node_feat, edge_attributes, edge_features, edge_index", + "node_attributes, edge_attributes, edge_features, edge_index", radius, radial_type, distance_transform,
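
# --- Reviewer note (illustrative sketch, not part of the patch) --------------
# The input_args / conv_args strings hard-coded above are
# torch_geometric.nn.Sequential signatures: input_args names everything the
# stacked layer receives, conv_args names the subset the wrapped conv consumes.
# The convention with a stock PyG conv (GCNConv is a stand-in, not one of the
# stacks in this patch):
#
#     import torch
#     from torch_geometric.nn import GCNConv, Sequential
#
#     input_args = "inv_node_feat, equiv_node_feat, edge_index"
#     conv_args = "inv_node_feat, edge_index"
#
#     layer = Sequential(
#         input_args,
#         [
#             (GCNConv(8, 8), conv_args + " -> inv_node_feat"),
#             # pass-through keeps the (invariant, equivariant) pair intact
#             (
#                 lambda inv_node_feat, equiv_node_feat: [
#                     inv_node_feat,
#                     equiv_node_feat,
#                 ],
#                 "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
#             ),
#         ],
#     )
#
#     inv = torch.randn(4, 8)
#     equiv = torch.randn(4, 3)  # e.g., positions
#     edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
#     inv_out, equiv_out = layer(inv, equiv, edge_index)
# ------------------------------------------------------------------------------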