diff --git a/hydragnn/models/Base.py b/hydragnn/models/Base.py index 186d0209f..f66d2a5a9 100644 --- a/hydragnn/models/Base.py +++ b/hydragnn/models/Base.py @@ -282,12 +282,15 @@ def _multihead(self): head_NN.append(self.convs_node_output[inode_feature]) head_NN.append(self.batch_norms_node_output[inode_feature]) inode_feature += 1 + else: raise ValueError( "Unknown head NN structure for node features" + self.node_NN_type + "; currently only support 'mlp', 'mlp_per_node' or 'conv' (can be set with config['NeuralNetwork']['Architecture']['output_heads']['node']['type'], e.g., ./examples/ci_multihead.json)" ) + elif self.head_type[ihead] == "pos": + head_NN = torch.nn.Identity() else: raise ValueError( "Unknown head type" @@ -304,15 +307,13 @@ def forward(self, data): x = data.x pos = data.pos + # print("data.x IN: ", x) + # print("data.pos IN", pos) + ### encoder part #### conv_args = self._conv_args(data) for conv, feat_layer in zip(self.graph_convs, self.feature_layers): - if not self.conv_checkpointing: - c, pos = conv(x=x, pos=pos, **conv_args) - else: - c, pos = checkpoint( - conv, use_reentrant=False, x=x, pos=pos, **conv_args - ) + c, pos = conv(x=x, pos=pos, **conv_args) x = self.activation_function(feat_layer(c)) #### multi-head decoder part#### @@ -322,29 +323,46 @@ def forward(self, data): else: x_graph = global_mean_pool(x, data.batch.to(x.device)) outputs = [] - outputs_var = [] for head_dim, headloc, type_head in zip( self.head_dims, self.heads_NN, self.head_type ): if type_head == "graph": x_graph_head = self.graph_shared(x_graph) - output_head = headloc(x_graph_head) - outputs.append(output_head[:, :head_dim]) - outputs_var.append(output_head[:, head_dim:] ** 2) - else: + outputs.append(headloc(x_graph_head)) + elif type_head == "node": if self.node_NN_type == "conv": for conv, batch_norm in zip(headloc[0::2], headloc[1::2]): - c, pos = conv(x=x, pos=pos, **conv_args) - c = batch_norm(c) - x = self.activation_function(c) - x_node = x + x_node = self.activation_function( + batch_norm(conv(x=x, edge_index=data.edge_index)) + ) else: x_node = headloc(x=x, batch=data.batch) - outputs.append(x_node[:, :head_dim]) - outputs_var.append(x_node[:, head_dim:] ** 2) - if self.var_output: - return outputs, outputs_var - return outputs + + # print("NODE OUT: ", x_node) + elif type_head == "pos": + # print("POS OUT: ", pos) + if self.equivariance: + x_node = pos - data.pos # following 3.2 The Dynamics in "Equivariant Diffusion for Molecule Generation in 3D" (Hoogeboom et al 2022) + # calculate the center of gravity for each subgraph + sg_num_nodes = [d.num_nodes for d in data.to_data_list()] # TODO - inefficient + com_ten = [] + # std_ten = [] + place = 0 + for sgnn in sg_num_nodes: + sg_x_node = x_node[place:place+sgnn] + com_ten.append(sg_x_node.mean(dim=0, keepdim=True).tile(sgnn, 1)) + # std_ten.append(sg_x_node.std() * torch.ones_like(sg_x_node)) + place += sgnn + com_ten = torch.cat(com_ten, dim=0) + # std_ten = torch.cat(std_ten, dim=0) + x_node = x_node - com_ten # subtract centers of mass + # x_node = x_node / std_ten # normalize output like GroupNorm + else: + x_node = pos + else: + raise NotImplementedError("Head type {} not recognized".format(type_head)) + outputs.append(x_node) + return outputs def loss(self, pred, value, head_index): var = None diff --git a/hydragnn/models/HybridEGCLStack.py b/hydragnn/models/HybridEGCLStack.py new file mode 100644 index 000000000..9a19d8444 --- /dev/null +++ b/hydragnn/models/HybridEGCLStack.py @@ -0,0 +1,69 @@ +############################################################################## +# Copyright (c) 2021, Oak Ridge National Laboratory # +# All rights reserved. # +# # +# This file is part of HydraGNN and is distributed under a BSD 3-clause # +# license. For the licensing terms see the LICENSE file in the top-level # +# directory. # +# # +# SPDX-License-Identifier: BSD-3-Clause # +############################################################################## +from typing import Optional + +import torch +import torch.nn as nn +from torch_geometric.nn import Sequential +import torch.nn.functional as F +from .EGCLStack import EGCLStack + +from hydragnn.utils.model import unsorted_segment_mean + + +class HybridEGCLStack(EGCLStack): + def __init__( + self, + *args, + **kwargs, + ): + # Initialize the parent class + super().__init__(*args, **kwargs) + + # Define new loss functions + self.cross_entropy = torch.nn.CrossEntropyLoss() + self.mse = torch.nn.MSELoss() + + def loss_hpweighted(self, pred, value, head_index, var=None): + """ + Overwrite this method to make split loss between + MSE (atom pos) and Cross Entropy (atom types). + """ + + # weights for different tasks as hyper-parameters + tot_loss = 0 + tasks_loss = [] + for ihead in range(self.num_heads): + head_pred = pred[ihead] + pred_shape = head_pred.shape + head_val = value[head_index[ihead]] + value_shape = head_val.shape + if pred_shape != value_shape: + head_val = torch.reshape(head_val, pred_shape) + + # Calculate loss depending on head + # Calculate cross entropy if atom types + if ihead == 0: + head_loss = self.cross_entropy(head_pred, head_val) + # Calculate MSE if position noise + elif ihead == 1: + head_loss = self.mse(head_pred, head_val) + + # Add loss to total loss and list of tasks loss + tot_loss += ( + head_loss * self.loss_weights[ihead] + ) + tasks_loss.append(head_loss) + + return tot_loss, tasks_loss + + def __str__(self): + return "HybridEGCLStack" \ No newline at end of file diff --git a/hydragnn/models/create.py b/hydragnn/models/create.py index 086bd0692..b88cab5dc 100644 --- a/hydragnn/models/create.py +++ b/hydragnn/models/create.py @@ -24,6 +24,7 @@ from hydragnn.models.SCFStack import SCFStack from hydragnn.models.DIMEStack import DIMEStack from hydragnn.models.EGCLStack import EGCLStack +from hydragnn.models.HybridEGCLStack import HybridEGCLStack from hydragnn.models.PNAEqStack import PNAEqStack from hydragnn.models.PAINNStack import PAINNStack from hydragnn.models.MACEStack import MACEStack @@ -345,7 +346,24 @@ def create_model( num_conv_layers=num_conv_layers, num_nodes=num_nodes, ) - + elif model_type == "HybridEGNN": + model = HybridEGCLStack( + edge_dim, + input_dim, + hidden_dim, + output_dim, + output_type, + output_heads, + activation_function, + loss_function_type, + equivariance, + max_neighbours=max_neighbours, + loss_weights=task_weights, + freeze_conv=freeze_conv, + initial_bias=initial_bias, + num_conv_layers=num_conv_layers, + num_nodes=num_nodes, + ) elif model_type == "PAINN": model = PAINNStack( # edge_dim, # To-do add edge_features diff --git a/hydragnn/preprocess/graph_samples_checks_and_updates.py b/hydragnn/preprocess/graph_samples_checks_and_updates.py index b4162d742..9fa45b181 100644 --- a/hydragnn/preprocess/graph_samples_checks_and_updates.py +++ b/hydragnn/preprocess/graph_samples_checks_and_updates.py @@ -271,6 +271,15 @@ def update_predicted_values( ], (-1, 1), ) + elif type[item] == "pos": + # index_counter_nodal_y = sum(node_feature_dim[: index[item]]) + feat_ = torch.reshape( + data.pos[ + :, + : node_feature_dim[index[item]] + ], + (-1, 1), + ) else: raise ValueError("Unknown output type", type[item]) output_feature.append(feat_) diff --git a/hydragnn/utils/input_config_parsing/config_utils.py b/hydragnn/utils/input_config_parsing/config_utils.py index a165ee7b7..85ccb8e83 100644 --- a/hydragnn/utils/input_config_parsing/config_utils.py +++ b/hydragnn/utils/input_config_parsing/config_utils.py @@ -136,7 +136,7 @@ def update_config(config, train_loader, val_loader, test_loader): def update_config_equivariance(config): - equivariant_models = ["EGNN", "SchNet", "PNAEq", "PAINN", "MACE"] + equivariant_models = ["EGNN", "SchNet", "PNAEq", "PAINN", "MACE", "HybridEGNN"] if "equivariance" in config and config["equivariance"]: assert ( config["model_type"] in equivariant_models @@ -188,7 +188,7 @@ def update_config_NN_outputs(config, data, graph_size_variable): for ihead in range(len(output_type)): if output_type[ihead] == "graph": dim_item = data.y_loc[0, ihead + 1].item() - data.y_loc[0, ihead].item() - elif output_type[ihead] == "node": + elif output_type[ihead] == "node" or output_type[ihead] == "pos": if ( graph_size_variable and config["Architecture"]["output_heads"]["node"]["type"] @@ -206,10 +206,13 @@ def update_config_NN_outputs(config, data, graph_size_variable): else: for ihead in range(len(output_type)): if output_type[ihead] != "graph": - raise ValueError( - "y_loc is needed for outputs that are not at graph levels", - output_type[ihead], - ) + if not "dynamic_target" in config["Variables_of_interest"] or\ + ("dynamic_target" in config["Variables_of_interest"] and\ + not config["Variables_of_interest"]["dynamic_target"]): # raise ValueError if yloc missing on non-graph, with "dynamic_target" set to false or missing + raise ValueError( + "y_loc is needed for outputs that are not at graph levels", + output_type[ihead], + ) dims_list = config["Variables_of_interest"]["output_dim"] config["Architecture"]["output_dim"] = dims_list diff --git a/hydragnn/utils/model/model.py b/hydragnn/utils/model/model.py index 7e3251e08..ed4113f5a 100644 --- a/hydragnn/utils/model/model.py +++ b/hydragnn/utils/model/model.py @@ -31,6 +31,8 @@ def activation_function_selection(activation_function_string: str): return torch.nn.ReLU() elif activation_function_string == "selu": return torch.nn.SELU() + elif activation_function_string == "silu": + return torch.nn.SiLU() elif activation_function_string == "prelu": return torch.nn.PReLU() elif activation_function_string == "elu":