
Commit b58c053
added Bareesh and Eric's changes to HydraGNN
zachfox committed Oct 25, 2024
1 parent 0fafea7 commit b58c053
Showing 6 changed files with 146 additions and 27 deletions.
58 changes: 38 additions & 20 deletions hydragnn/models/Base.py
@@ -282,12 +282,15 @@ def _multihead(self):
                        head_NN.append(self.convs_node_output[inode_feature])
                        head_NN.append(self.batch_norms_node_output[inode_feature])
                        inode_feature += 1
                else:
                    raise ValueError(
                        "Unknown head NN structure for node features"
                        + self.node_NN_type
                        + "; currently only support 'mlp', 'mlp_per_node' or 'conv' (can be set with config['NeuralNetwork']['Architecture']['output_heads']['node']['type'], e.g., ./examples/ci_multihead.json)"
                    )
            elif self.head_type[ihead] == "pos":
                head_NN = torch.nn.Identity()
            else:
                raise ValueError(
                    "Unknown head type"
@@ -304,15 +307,13 @@ def forward(self, data):
        x = data.x
        pos = data.pos

        ### encoder part ####
        conv_args = self._conv_args(data)
        for conv, feat_layer in zip(self.graph_convs, self.feature_layers):
            if not self.conv_checkpointing:
                c, pos = conv(x=x, pos=pos, **conv_args)
            else:
                c, pos = checkpoint(
                    conv, use_reentrant=False, x=x, pos=pos, **conv_args
                )
            x = self.activation_function(feat_layer(c))

        #### multi-head decoder part####
@@ -322,29 +323,46 @@ def forward(self, data):
        else:
            x_graph = global_mean_pool(x, data.batch.to(x.device))
        outputs = []
        outputs_var = []
        for head_dim, headloc, type_head in zip(
            self.head_dims, self.heads_NN, self.head_type
        ):
            if type_head == "graph":
                x_graph_head = self.graph_shared(x_graph)
                output_head = headloc(x_graph_head)
                outputs.append(output_head[:, :head_dim])
                outputs_var.append(output_head[:, head_dim:] ** 2)
            elif type_head == "node":
                if self.node_NN_type == "conv":
                    for conv, batch_norm in zip(headloc[0::2], headloc[1::2]):
                        c, pos = conv(x=x, pos=pos, **conv_args)
                        c = batch_norm(c)
                        x = self.activation_function(c)
                    x_node = x
                else:
                    x_node = headloc(x=x, batch=data.batch)
                outputs.append(x_node[:, :head_dim])
                outputs_var.append(x_node[:, head_dim:] ** 2)
            elif type_head == "pos":
                if self.equivariance:
                    # following 3.2 "The Dynamics" in "Equivariant Diffusion
                    # for Molecule Generation in 3D" (Hoogeboom et al., 2022)
                    x_node = pos - data.pos
                    # calculate the center of gravity for each subgraph
                    sg_num_nodes = [
                        d.num_nodes for d in data.to_data_list()
                    ]  # TODO - inefficient
                    com_ten = []
                    place = 0
                    for sgnn in sg_num_nodes:
                        sg_x_node = x_node[place : place + sgnn]
                        com_ten.append(
                            sg_x_node.mean(dim=0, keepdim=True).tile(sgnn, 1)
                        )
                        place += sgnn
                    com_ten = torch.cat(com_ten, dim=0)
                    # subtract centers of mass (a commented-out variant also
                    # divided by each subgraph's std, GroupNorm-style)
                    x_node = x_node - com_ten
                else:
                    x_node = pos
                outputs.append(x_node)
            else:
                raise NotImplementedError(
                    "Head type {} not recognized".format(type_head)
                )
        if self.var_output:
            return outputs, outputs_var
        return outputs

    def loss(self, pred, value, head_index):
        var = None
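Note on the new "pos" head: the center-of-mass subtraction above loops over `data.to_data_list()`, which the inline `TODO - inefficient` comment acknowledges. A minimal vectorized sketch of the same operation, assuming `data.batch` is populated (this helper is illustrative and not part of the commit):

```python
import torch
from torch_geometric.nn import global_mean_pool


def subtract_center_of_mass(x_node: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    # global_mean_pool averages rows per graph; indexing the result with
    # `batch` broadcasts each graph's mean back onto its own nodes.
    com = global_mean_pool(x_node, batch)  # [num_graphs, 3]
    return x_node - com[batch]             # [num_nodes, 3]
```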
69 changes: 69 additions & 0 deletions hydragnn/models/HybridEGCLStack.py
@@ -0,0 +1,69 @@
##############################################################################
# Copyright (c) 2021, Oak Ridge National Laboratory #
# All rights reserved. #
# #
# This file is part of HydraGNN and is distributed under a BSD 3-clause #
# license. For the licensing terms see the LICENSE file in the top-level #
# directory. #
# #
# SPDX-License-Identifier: BSD-3-Clause #
##############################################################################
from typing import Optional

import torch
import torch.nn as nn
from torch_geometric.nn import Sequential
import torch.nn.functional as F
from .EGCLStack import EGCLStack

from hydragnn.utils.model import unsorted_segment_mean


class HybridEGCLStack(EGCLStack):
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        # Initialize the parent class
        super().__init__(*args, **kwargs)

        # Define new loss functions
        self.cross_entropy = torch.nn.CrossEntropyLoss()
        self.mse = torch.nn.MSELoss()

    def loss_hpweighted(self, pred, value, head_index, var=None):
        """
        Override this method to split the loss between cross entropy
        (atom types, head 0) and MSE (position noise, head 1).
        """

        # weights for different tasks as hyper-parameters
        tot_loss = 0
        tasks_loss = []
        for ihead in range(self.num_heads):
            head_pred = pred[ihead]
            pred_shape = head_pred.shape
            head_val = value[head_index[ihead]]
            value_shape = head_val.shape
            if pred_shape != value_shape:
                head_val = torch.reshape(head_val, pred_shape)

            # Calculate loss depending on head:
            # cross entropy for atom types, MSE for position noise
            if ihead == 0:
                head_loss = self.cross_entropy(head_pred, head_val)
            elif ihead == 1:
                head_loss = self.mse(head_pred, head_val)

            # Add loss to total loss and list of tasks loss
            tot_loss += head_loss * self.loss_weights[ihead]
            tasks_loss.append(head_loss)

        return tot_loss, tasks_loss

    def __str__(self):
        return "HybridEGCLStack"
20 changes: 19 additions & 1 deletion hydragnn/models/create.py
@@ -24,6 +24,7 @@
from hydragnn.models.SCFStack import SCFStack
from hydragnn.models.DIMEStack import DIMEStack
from hydragnn.models.EGCLStack import EGCLStack
from hydragnn.models.HybridEGCLStack import HybridEGCLStack
from hydragnn.models.PNAEqStack import PNAEqStack
from hydragnn.models.PAINNStack import PAINNStack
from hydragnn.models.MACEStack import MACEStack
@@ -345,7 +346,24 @@ def create_model(
            num_conv_layers=num_conv_layers,
            num_nodes=num_nodes,
        )
    elif model_type == "HybridEGNN":
        model = HybridEGCLStack(
            edge_dim,
            input_dim,
            hidden_dim,
            output_dim,
            output_type,
            output_heads,
            activation_function,
            loss_function_type,
            equivariance,
            max_neighbours=max_neighbours,
            loss_weights=task_weights,
            freeze_conv=freeze_conv,
            initial_bias=initial_bias,
            num_conv_layers=num_conv_layers,
            num_nodes=num_nodes,
        )
    elif model_type == "PAINN":
        model = PAINNStack(
            # edge_dim,  # To-do add edge_features
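With this dispatch in place, the new stack should be selectable like any other architecture via the `model_type` field. A hedged sketch of the relevant config fragment (key layout follows the usual HydraGNN JSON structure; surrounding fields omitted):

```python
# Illustrative only: selecting the new model in a HydraGNN-style config dict.
config_fragment = {
    "NeuralNetwork": {
        "Architecture": {
            "model_type": "HybridEGNN",  # dispatches to HybridEGCLStack
            # ... remaining Architecture fields as in an ordinary EGNN run ...
        }
    }
}
```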
9 changes: 9 additions & 0 deletions hydragnn/preprocess/graph_samples_checks_and_updates.py
@@ -271,6 +271,15 @@ def update_predicted_values(
                    ],
                    (-1, 1),
                )
            elif type[item] == "pos":
                # index_counter_nodal_y = sum(node_feature_dim[: index[item]])
                feat_ = torch.reshape(
                    data.pos[
                        :,
                        : node_feature_dim[index[item]],
                    ],
                    (-1, 1),
                )
            else:
                raise ValueError("Unknown output type", type[item])
            output_feature.append(feat_)
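The new "pos" branch flattens the first `node_feature_dim[index[item]]` coordinate columns into a single column, matching how the other output types are stacked. A small illustration with made-up values:

```python
import torch

pos = torch.arange(12.0).reshape(4, 3)   # 4 nodes with (x, y, z) coordinates
dim = 3                                  # stands in for node_feature_dim[index[item]]
feat_ = torch.reshape(pos[:, :dim], (-1, 1))
print(feat_.shape)  # torch.Size([12, 1]); rows ordered x, y, z per node
```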
15 changes: 9 additions & 6 deletions hydragnn/utils/input_config_parsing/config_utils.py
@@ -136,7 +136,7 @@ def update_config(config, train_loader, val_loader, test_loader):


def update_config_equivariance(config):
    equivariant_models = ["EGNN", "SchNet", "PNAEq", "PAINN", "MACE", "HybridEGNN"]
    if "equivariance" in config and config["equivariance"]:
        assert (
            config["model_type"] in equivariant_models
@@ -188,7 +188,7 @@ def update_config_NN_outputs(config, data, graph_size_variable):
        for ihead in range(len(output_type)):
            if output_type[ihead] == "graph":
                dim_item = data.y_loc[0, ihead + 1].item() - data.y_loc[0, ihead].item()
            elif output_type[ihead] == "node" or output_type[ihead] == "pos":
                if (
                    graph_size_variable
                    and config["Architecture"]["output_heads"]["node"]["type"]
@@ -206,10 +206,13 @@
    else:
        for ihead in range(len(output_type)):
            if output_type[ihead] != "graph":
                # y_loc is only optional when "dynamic_target" is set to true
                # in Variables_of_interest; otherwise non-graph outputs need it
                if not config["Variables_of_interest"].get("dynamic_target", False):
                    raise ValueError(
                        "y_loc is needed for outputs that are not at graph levels",
                        output_type[ihead],
                    )
    dims_list = config["Variables_of_interest"]["output_dim"]

    config["Architecture"]["output_dim"] = dims_list
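The new guard in `update_config_NN_outputs` only suppresses the `y_loc` error when `dynamic_target` is present and true. A small sketch of the condition in isolation (hypothetical helper, not in the codebase):

```python
def y_loc_required(variables_of_interest: dict) -> bool:
    # Mirrors the guard above: y_loc is required for non-graph outputs
    # unless "dynamic_target" is present and true.
    return not variables_of_interest.get("dynamic_target", False)


assert y_loc_required({}) is True
assert y_loc_required({"dynamic_target": False}) is True
assert y_loc_required({"dynamic_target": True}) is False
```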
2 changes: 2 additions & 0 deletions hydragnn/utils/model/model.py
@@ -31,6 +31,8 @@ def activation_function_selection(activation_function_string: str):
        return torch.nn.ReLU()
    elif activation_function_string == "selu":
        return torch.nn.SELU()
    elif activation_function_string == "silu":
        return torch.nn.SiLU()
    elif activation_function_string == "prelu":
        return torch.nn.PReLU()
    elif activation_function_string == "elu":
