Light refactor to improve development (ORNL#305)
* light refactor to improve development

* removed printing

* testing my commit ability

* reformat MACE

* remove test change

* fix error relating to readout block dimension input

---------

Co-authored-by: Justin <[email protected]>
Co-authored-by: Rylie Weaver <[email protected]>
3 people committed Nov 12, 2024
1 parent ad20eaf commit 8b47801
Showing 15 changed files with 386 additions and 290 deletions.
50 changes: 36 additions & 14 deletions hydragnn/models/Base.py
@@ -27,6 +27,8 @@
class Base(Module):
def __init__(
self,
input_args: str,
conv_args: str,
input_dim: int,
hidden_dim: int,
output_dim: list,
@@ -46,6 +48,8 @@ def __init__(
):
super().__init__()
self.device = get_device()
self.input_args = input_args
self.conv_args = conv_args
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.dropout = dropout
@@ -104,6 +108,10 @@ def __init__(
and self.edge_dim > 0
):
self.use_edge_attr = True
if "edge_attr" not in self.input_args:
self.input_args += ", edge_attr"
if "edge_attr" not in self.conv_args:
self.conv_args += ", edge_attr"

# Option to only train final property layers.
self.freeze_conv = freeze_conv
@@ -127,14 +135,14 @@ def _init_conv(self):
self.graph_convs.append(conv)
self.feature_layers.append(BatchNorm(self.hidden_dim))

def _conv_args(self, data):
def _embedding(self, data):
conv_args = {"edge_index": data.edge_index.to(torch.long)}
if self.use_edge_attr:
assert (
data.edge_attr is not None
), "Data must have edge attributes if use_edge_attributes is set."
conv_args.update({"edge_attr": data.edge_attr})
return conv_args
return data.x, data.pos, conv_args

def _freeze_conv(self):
for module in [self.graph_convs, self.feature_layers]:
@@ -301,19 +309,27 @@ def enable_conv_checkpointing(self):
self.conv_checkpointing = True

def forward(self, data):
x = data.x
pos = data.pos

### encoder part ####
conv_args = self._conv_args(data)
inv_node_feat, equiv_node_feat, conv_args = self._embedding(data)

for conv, feat_layer in zip(self.graph_convs, self.feature_layers):
if not self.conv_checkpointing:
c, pos = conv(x=x, pos=pos, **conv_args)
inv_node_feat, equiv_node_feat = conv(
inv_node_feat=inv_node_feat,
equiv_node_feat=equiv_node_feat,
**conv_args
)
else:
c, pos = checkpoint(
conv, use_reentrant=False, x=x, pos=pos, **conv_args
inv_node_feat, equiv_node_feat = checkpoint(
conv,
use_reentrant=False,
inv_node_feat=inv_node_feat,
equiv_node_feat=equiv_node_feat,
**conv_args
)
x = self.activation_function(feat_layer(c))
inv_node_feat = self.activation_function(feat_layer(inv_node_feat))

x = inv_node_feat

#### multi-head decoder part####
# shared dense layers for graph level output
@@ -333,11 +349,17 @@ def forward(self, data):
outputs_var.append(output_head[:, head_dim:] ** 2)
else:
if self.node_NN_type == "conv":
inv_node_feat = x
for conv, batch_norm in zip(headloc[0::2], headloc[1::2]):
c, pos = conv(x=x, pos=pos, **conv_args)
c = batch_norm(c)
x = self.activation_function(c)
x_node = x
inv_node_feat, equiv_node_feat = conv(
inv_node_feat=inv_node_feat,
equiv_node_feat=equiv_node_feat,
**conv_args
)
inv_node_feat = batch_norm(inv_node_feat)
inv_node_feat = self.activation_function(inv_node_feat)
x_node = inv_node_feat
x = inv_node_feat
else:
x_node = headloc(x=x, batch=data.batch)
outputs.append(x_node[:, :head_dim])
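Taken together, the Base.py changes rename the two data streams explicitly: `inv_node_feat` (invariant node features, formerly `x`) and `equiv_node_feat` (equivariant features, formerly `pos`). `_embedding` replaces `_conv_args` and returns both streams along with the shared conv kwargs, and Base appends ", edge_attr" to both signature strings when edge attributes are in use, so subclasses no longer rebuild them. A minimal standalone sketch of that contract (hypothetical toy names, not HydraGNN code):

import torch

# Illustrative sketch of the refactored contract, not part of the commit.
def toy_conv(inv_node_feat, equiv_node_feat, edge_index):
    # Stand-in for a wrapped conv: updates the invariant stream and
    # passes the equivariant stream (positions) through unchanged.
    return inv_node_feat + 1.0, equiv_node_feat

data = {
    "x": torch.randn(4, 8),
    "pos": torch.randn(4, 3),
    "edge_index": torch.tensor([[0, 1, 2], [1, 2, 3]]),
}

# Mirrors the new Base._embedding: both feature streams plus conv kwargs.
inv_node_feat, equiv_node_feat = data["x"], data["pos"]
conv_args = {"edge_index": data["edge_index"].to(torch.long)}

# Mirrors the loop in Base.forward: each conv consumes and returns the pair.
for _ in range(2):
    inv_node_feat, equiv_node_feat = toy_conv(
        inv_node_feat, equiv_node_feat, **conv_args
    )
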
33 changes: 23 additions & 10 deletions hydragnn/models/CGCNNStack.py
@@ -19,6 +19,8 @@
class CGCNNStack(Base):
def __init__(
self,
input_args,
conv_args,
edge_dim: int,
input_dim,
output_dim,
@@ -32,13 +34,25 @@ def __init__(
# also as hidden dimension (second argument of base constructor)
# We therefore pass all required args explicitly.
super().__init__(
input_args,
conv_args,
input_dim,
input_dim,
output_dim,
*args,
**kwargs,
)

if self.use_edge_attr:
assert (
self.input_args
== "inv_node_feat, equiv_node_feat, edge_index, edge_attr"
)
assert self.conv_args == "inv_node_feat, edge_index, edge_attr"
else:
assert self.input_args == "inv_node_feat, equiv_node_feat, edge_index"
assert self.conv_args == "inv_node_feat, edge_index"

def get_conv(self, input_dim, _):
cgcnn = CGConv(
channels=input_dim,
@@ -48,18 +62,17 @@ def get_conv(self, input_dim, _):
bias=True,
)

input_args = "x, pos, edge_index"
conv_args = "x, edge_index"

if self.use_edge_attr:
input_args += ", edge_attr"
conv_args += ", edge_attr"

return Sequential(
input_args,
self.input_args,
[
(cgcnn, conv_args + " -> x"),
(lambda x, pos: [x, pos], "x, pos -> x, pos"),
(cgcnn, self.conv_args + " -> inv_node_feat"),
(
lambda inv_node_feat, equiv_node_feat: [
inv_node_feat,
equiv_node_feat,
],
"inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
),
],
)

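The assertions in the constructor above pin down the exact signature strings that Base assembles. For readers unfamiliar with torch_geometric.nn.Sequential, here is a standalone sketch (arbitrary dims, an illustration rather than HydraGNN code) of how those strings route tensors; the trailing lambda only repacks both streams so the block's output matches the (inv_node_feat, equiv_node_feat) pair that Base.forward expects back:

import torch
from torch_geometric.nn import CGConv, Sequential

# Illustration only, not part of the commit.
conv = Sequential(
    "inv_node_feat, equiv_node_feat, edge_index",
    [
        (CGConv(channels=16), "inv_node_feat, edge_index -> inv_node_feat"),
        # Pass-through so the block returns both streams, matching the
        # (inv_node_feat, equiv_node_feat) contract of Base.forward.
        (
            lambda inv_node_feat, equiv_node_feat: [inv_node_feat, equiv_node_feat],
            "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
        ),
    ],
)

out_inv, out_equiv = conv(
    torch.randn(4, 16),                    # invariant node features
    torch.randn(4, 3),                     # equivariant features (positions)
    torch.tensor([[0, 1, 2], [1, 2, 3]]),  # edge_index
)
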
37 changes: 24 additions & 13 deletions hydragnn/models/DIMEStack.py
@@ -36,6 +36,8 @@ class DIMEStack(Base):

def __init__(
self,
input_args,
conv_args,
basis_emb_size,
envelope_exponent,
int_emb_size,
@@ -60,7 +62,7 @@ def __init__(
self.edge_dim = edge_dim
self.radius = radius

super().__init__(*args, **kwargs)
super().__init__(input_args, conv_args, *args, **kwargs)

self.rbf = BesselBasisLayer(num_radial, radius, envelope_exponent)
self.sbf = SphericalBasisLayer(
@@ -111,28 +113,37 @@ def get_conv(self, input_dim, output_dim):

if self.use_edge_attr:
return Sequential(
"x, pos, rbf, edge_attr, sbf, i, j, idx_kj, idx_ji",
self.input_args,
[
(lin, "x -> x"),
(emb, "x, rbf, i, j, edge_attr -> x1"),
(lin, "inv_node_feat -> inv_node_feat"),
(emb, "inv_node_feat, rbf, i, j, edge_attr -> x1"),
(inter, "x1, rbf, sbf, idx_kj, idx_ji -> x2"),
(dec, "x2, rbf, i -> c"),
(lambda x, pos: [x, pos], "c, pos -> c, pos"),
(dec, "x2, rbf, i -> inv_node_feat"),
(
lambda inv_node_feat, equiv_node_feat: [
inv_node_feat,
equiv_node_feat,
],
"inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
),
],
)
else:
return Sequential(
"x, pos, rbf, sbf, i, j, idx_kj, idx_ji",
self.input_args,
[
(lin, "x -> x"),
(emb, "x, rbf, i, j -> x1"),
(lin, "inv_node_feat -> inv_node_feat"),
(emb, "inv_node_feat, rbf, i, j -> x1"),
(inter, "x1, rbf, sbf, idx_kj, idx_ji -> x2"),
(dec, "x2, rbf, i -> c"),
(lambda x, pos: [x, pos], "c, pos -> c, pos"),
(dec, "x2, rbf, i -> inv_node_feat"),
(
lambda x, pos: [x, pos],
"inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
),
],
)

def _conv_args(self, data):
def _embedding(self, data):
assert (
data.pos is not None
), "DimeNet requires node positions (data.pos) to be set."
@@ -166,7 +177,7 @@ def _conv_args(self, data):
), "Data must have edge attributes if use_edge_attributes is set."
conv_args.update({"edge_attr": data.edge_attr})

return conv_args
return data.x, data.pos, conv_args


"""
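One detail worth noting in the else branch above: the pass-through callable is written `lambda x, pos: [x, pos]` while its routing string names `inv_node_feat, equiv_node_feat`. This still works because, as far as I can tell, torch_geometric.nn.Sequential binds arguments positionally according to the mapping string, so the callable's own parameter names are immaterial. A tiny standalone check of that assumption (not HydraGNN code):

import torch
from torch_geometric.nn import Sequential

# Assumption check, not part of the commit: parameter names (x, pos)
# deliberately differ from the routing names used in the strings.
passthrough = Sequential(
    "inv_node_feat, equiv_node_feat",
    [
        (
            lambda x, pos: [x, pos],
            "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
        ),
    ],
)

a, b = passthrough(torch.ones(2), torch.zeros(2))
assert torch.equal(a, torch.ones(2)) and torch.equal(b, torch.zeros(2))
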
34 changes: 26 additions & 8 deletions hydragnn/models/EGCLStack.py
@@ -21,6 +21,8 @@
class EGCLStack(Base):
def __init__(
self,
input_args,
conv_args,
edge_attr_dim: int,
*args,
max_neighbours: Optional[int] = None,
@@ -30,7 +32,11 @@ def __init__(
self.edge_dim = (
0 if edge_attr_dim is None else edge_attr_dim
) # Must be named edge_dim to trigger use by Base
super().__init__(*args, **kwargs)
super().__init__(input_args, conv_args, *args, **kwargs)

assert (
self.input_args == "inv_node_feat, equiv_node_feat, edge_index, edge_attr"
)
pass

def _init_conv(self):
@@ -56,21 +62,33 @@ def get_conv(self, input_dim, output_dim, last_layer=False):

if self.equivariance and not last_layer:
return Sequential(
"x, pos, edge_index, edge_attr",
self.input_args,
[
(egcl, "x, pos, edge_index, edge_attr -> x, pos"),
(
egcl,
"inv_node_feat, equiv_node_feat, edge_index, edge_attr -> inv_node_feat, equiv_node_feat",
),
],
)
else:
return Sequential(
"x, pos, edge_index, edge_attr",
self.input_args,
[
(egcl, "x, pos, edge_index, edge_attr -> x"),
(lambda x, pos: [x, pos], "x, pos -> x, pos"),
(
egcl,
"inv_node_feat, equiv_node_feat, edge_index, edge_attr -> inv_node_feat",
),
(
lambda inv_node_feat, equiv_node_feat: [
inv_node_feat,
equiv_node_feat,
],
"inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
),
],
)

def _conv_args(self, data):
def _embedding(self, data):
if self.edge_dim > 0:
conv_args = {
"edge_index": data.edge_index,
Expand All @@ -82,7 +100,7 @@ def _conv_args(self, data):
"edge_attr": None,
}

return conv_args
return data.x, data.pos, conv_args

def __str__(self):
return "EGCLStack"
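The two Sequential variants above encode EGCL's equivariance switch: intermediate equivariant layers update node coordinates alongside features, while the final layer (or a non-equivariant stack) updates features only and forwards the coordinates untouched. A compressed sketch of that branch logic with hypothetical stand-ins (not HydraGNN code):

# Hypothetical stand-ins, not part of the commit.
def egcl_equivariant(h, pos, edge_index, edge_attr):
    # Intermediate equivariant layer: updates both streams.
    return h * 2.0, pos + 0.1

def egcl_last(h, pos, edge_index, edge_attr):
    # Final layer: returns invariant features only.
    return h * 2.0

def apply_layer(h, pos, conv_args, equivariant):
    if equivariant:
        h, pos = egcl_equivariant(h, pos, **conv_args)
    else:
        h = egcl_last(h, pos, **conv_args)  # pos passes through unchanged
    return h, pos

h, pos = apply_layer(1.0, 0.0, {"edge_index": None, "edge_attr": None}, True)
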
23 changes: 12 additions & 11 deletions hydragnn/models/GATStack.py
@@ -21,6 +21,8 @@
class GATStack(Base):
def __init__(
self,
input_args,
conv_args,
heads: int,
negative_slope: float,
*args,
@@ -30,7 +32,7 @@ def __init__(
self.heads = heads
self.negative_slope = negative_slope

super().__init__(*args, **kwargs)
super().__init__(input_args, conv_args, *args, **kwargs)

def _init_conv(self):
"""Here this function overwrites _init_conv() in Base since it has different implementation
@@ -99,18 +101,17 @@ def get_conv(self, input_dim, output_dim, concat):
concat=concat,
)

input_args = "x, pos, edge_index"
conv_args = "x, edge_index"

if self.use_edge_attr:
input_args += ", edge_attr"
conv_args += ", edge_attr"

return Sequential(
input_args,
self.input_args,
[
(gat, conv_args + " -> x"),
(lambda x, pos: [x, pos], "x, pos -> x, pos"),
(gat, self.conv_args + " -> inv_node_feat"),
(
lambda inv_node_feat, equiv_node_feat: [
inv_node_feat,
equiv_node_feat,
],
"inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
),
],
)

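GATStack keeps its own `_init_conv` override (referenced in an earlier hunk) because multi-head attention changes feature widths: with concat=True a GATConv emits heads * out_channels features, while concat=False averages the heads. A small standalone illustration with assumed sizes (not HydraGNN code):

import torch
from torch_geometric.nn import GATConv

# Assumed sizes, for illustration only; not part of the commit.
x = torch.randn(5, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])

mid = GATConv(16, 16, heads=4, concat=True)    # concatenates -> [5, 64]
last = GATConv(64, 16, heads=4, concat=False)  # averages heads -> [5, 16]

h = mid(x, edge_index)
out = last(h, edge_index)
print(h.shape, out.shape)  # torch.Size([5, 64]) torch.Size([5, 16])
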
12 changes: 6 additions & 6 deletions hydragnn/models/GINStack.py
@@ -33,14 +33,14 @@ def get_conv(self, input_dim, output_dim):
train_eps=True,
)

input_args = "x, pos, edge_index"
conv_args = "x, edge_index"

return Sequential(
input_args,
self.input_args,
[
(gin, conv_args + " -> x"),
(lambda x, pos: [x, pos], "x, pos -> x, pos"),
(gin, self.conv_args + " -> inv_node_feat"),
(
lambda x, equiv_node_feat: [x, equiv_node_feat],
"inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
),
],
)

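GINConv's message passing does not consume edge attributes, which is why the conv string here stays `inv_node_feat, edge_index`. A minimal standalone analogue of the block above (assumed dims, not HydraGNN code):

import torch
from torch.nn import Linear, ReLU
from torch_geometric.nn import GINConv, Sequential

# Illustration only, not part of the commit.
gin = GINConv(
    torch.nn.Sequential(Linear(8, 8), ReLU(), Linear(8, 8)),
    train_eps=True,
)

block = Sequential(
    "inv_node_feat, equiv_node_feat, edge_index",
    [
        (gin, "inv_node_feat, edge_index -> inv_node_feat"),
        # Repack both streams to keep the pair contract of Base.forward.
        (
            lambda inv_node_feat, equiv_node_feat: [inv_node_feat, equiv_node_feat],
            "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
        ),
    ],
)

out, pos = block(
    torch.randn(4, 8), torch.randn(4, 3), torch.tensor([[0, 1], [1, 2]])
)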