Light refactor to improve development #305

Merged: 6 commits, Nov 5, 2024
Changes from 2 commits
50 changes: 36 additions & 14 deletions hydragnn/models/Base.py
@@ -27,6 +27,8 @@
class Base(Module):
def __init__(
self,
input_args: str,
conv_args: str,
input_dim: int,
hidden_dim: int,
output_dim: list,
@@ -46,6 +48,8 @@ def __init__(
):
super().__init__()
self.device = get_device()
self.input_args = input_args
self.conv_args = conv_args
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.dropout = dropout
@@ -104,6 +108,10 @@ def __init__(
and self.edge_dim > 0
):
self.use_edge_attr = True
if "edge_attr" not in self.input_args:
self.input_args += ", edge_attr"
if "edge_attr" not in self.conv_args:
self.conv_args += ", edge_attr"

# Option to only train final property layers.
self.freeze_conv = freeze_conv
@@ -127,14 +135,14 @@ def _init_conv(self):
self.graph_convs.append(conv)
self.feature_layers.append(BatchNorm(self.hidden_dim))

def _conv_args(self, data):
def _embedding(self, data):
conv_args = {"edge_index": data.edge_index.to(torch.long)}
if self.use_edge_attr:
assert (
data.edge_attr is not None
), "Data must have edge attributes if use_edge_attributes is set."
conv_args.update({"edge_attr": data.edge_attr})
return conv_args
return data.x, data.pos, conv_args

def _freeze_conv(self):
for module in [self.graph_convs, self.feature_layers]:
@@ -301,19 +309,27 @@ def enable_conv_checkpointing(self):
self.conv_checkpointing = True

def forward(self, data):
x = data.x
pos = data.pos

### encoder part ####
conv_args = self._conv_args(data)
inv_node_feat, equiv_node_feat, conv_args = self._embedding(data)

for conv, feat_layer in zip(self.graph_convs, self.feature_layers):
if not self.conv_checkpointing:
c, pos = conv(x=x, pos=pos, **conv_args)
inv_node_feat, equiv_node_feat = conv(
inv_node_feat=inv_node_feat,
equiv_node_feat=equiv_node_feat,
**conv_args
)
else:
c, pos = checkpoint(
conv, use_reentrant=False, x=x, pos=pos, **conv_args
inv_node_feat, equiv_node_feat = checkpoint(
conv,
use_reentrant=False,
inv_node_feat=inv_node_feat,
equiv_node_feat=equiv_node_feat,
**conv_args
)
x = self.activation_function(feat_layer(c))
inv_node_feat = self.activation_function(feat_layer(inv_node_feat))

x = inv_node_feat

#### multi-head decoder part####
# shared dense layers for graph level output
@@ -333,11 +349,17 @@ def forward(self, data):
outputs_var.append(output_head[:, head_dim:] ** 2)
else:
if self.node_NN_type == "conv":
inv_node_feat = x
for conv, batch_norm in zip(headloc[0::2], headloc[1::2]):
c, pos = conv(x=x, pos=pos, **conv_args)
c = batch_norm(c)
x = self.activation_function(c)
x_node = x
inv_node_feat, equiv_node_feat = conv(
inv_node_feat=inv_node_feat,
equiv_node_feat=equiv_node_feat,
**conv_args
)
inv_node_feat = batch_norm(inv_node_feat)
inv_node_feat = self.activation_function(inv_node_feat)
x_node = inv_node_feat
x = inv_node_feat
else:
x_node = headloc(x=x, batch=data.batch)
outputs.append(x_node[:, :head_dim])
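For reviewers skimming the Base.py changes: the substance of the refactor is that `_conv_args` becomes `_embedding`, which now also hands back the invariant and equivariant node features, and the conv loop trades the old `(x, pos)` pair for `(inv_node_feat, equiv_node_feat)`. Below is a minimal sketch of the new contract, with everything outside the renamed pieces simplified away (illustrative only, not part of the diff).

```python
# Illustrative sketch of the refactored encoder contract; the real class carries
# many more members, this only shows the pieces renamed in this PR.
import torch


class SketchStack:
    def _embedding(self, data):
        # Now returns the node features too, not just the kwargs for the convs.
        conv_args = {"edge_index": data.edge_index.to(torch.long)}
        if self.use_edge_attr:
            conv_args["edge_attr"] = data.edge_attr
        return data.x, data.pos, conv_args  # invariant, equivariant, conv kwargs

    def forward(self, data):
        inv_node_feat, equiv_node_feat, conv_args = self._embedding(data)
        for conv, feat_layer in zip(self.graph_convs, self.feature_layers):
            # Each conv maps (invariant, equivariant) back to an updated pair.
            inv_node_feat, equiv_node_feat = conv(
                inv_node_feat=inv_node_feat,
                equiv_node_feat=equiv_node_feat,
                **conv_args,
            )
            inv_node_feat = self.activation_function(feat_layer(inv_node_feat))
        return inv_node_feat  # the decoder heads keep consuming the invariant part
```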
33 changes: 23 additions & 10 deletions hydragnn/models/CGCNNStack.py
@@ -19,6 +19,8 @@
class CGCNNStack(Base):
def __init__(
self,
input_args,
conv_args,
edge_dim: int,
input_dim,
output_dim,
@@ -32,13 +34,25 @@ def __init__(
# also as hidden dimension (second argument of base constructor)
# We therefore pass all required args explicitly.
super().__init__(
input_args,
conv_args,
input_dim,
input_dim,
output_dim,
*args,
**kwargs,
)

if self.use_edge_attr:
assert (
self.input_args
== "inv_node_feat, equiv_node_feat, edge_index, edge_attr"
)
assert self.conv_args == "inv_node_feat, edge_index, edge_attr"

Review thread on the asserts above:

Collaborator:
@JustinBakerMath
Why do some stacks have this assert and others not?

Collaborator (Author):
We can add them everywhere or remove them all. I just added this for security.

Collaborator:
@JustinBakerMath @ArCho48
I am partial to removing them all since they are set just before, but could be convinced either way.

else:
assert self.input_args == "inv_node_feat, equiv_node_feat, edge_index"
assert self.conv_args == "inv_node_feat, edge_index"

def get_conv(self, input_dim, _):
cgcnn = CGConv(
channels=input_dim,
@@ -48,18 +62,17 @@ def get_conv(self, input_dim, _):
bias=True,
)

input_args = "x, pos, edge_index"
conv_args = "x, edge_index"

if self.use_edge_attr:
input_args += ", edge_attr"
conv_args += ", edge_attr"

return Sequential(
input_args,
self.input_args,
[
(cgcnn, conv_args + " -> x"),
(lambda x, pos: [x, pos], "x, pos -> x, pos"),
(cgcnn, self.conv_args + " -> inv_node_feat"),
(
lambda inv_node_feat, equiv_node_feat: [
inv_node_feat,
equiv_node_feat,
],
"inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
),
],
)

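The `Sequential` used in `get_conv` above is `torch_geometric.nn.Sequential`, which takes a signature string naming the block's inputs plus a list of `(module, "inputs -> outputs")` mappings; this PR swaps the previously hardcoded strings for the `self.input_args`/`self.conv_args` now stored on `Base`. A self-contained illustration of that pattern with a toy layer follows (the `GCNConv` choice and the sizes are arbitrary stand-ins, not HydraGNN code).

```python
# Standalone illustration of the signature-string pattern used by get_conv();
# the layer and dimensions are arbitrary stand-ins, not the HydraGNN stacks.
import torch
from torch_geometric.nn import GCNConv, Sequential

input_args = "inv_node_feat, equiv_node_feat, edge_index"
conv_args = "inv_node_feat, edge_index"

block = Sequential(
    input_args,
    [
        # The conv itself only consumes the invariant features and the edges.
        (GCNConv(8, 16), conv_args + " -> inv_node_feat"),
        # Pass the equivariant features (here: positions) through untouched so the
        # block still returns the (invariant, equivariant) pair forward() expects.
        (
            lambda inv_node_feat, equiv_node_feat: [inv_node_feat, equiv_node_feat],
            "inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
        ),
    ],
)

x = torch.randn(4, 8)                                   # invariant node features
pos = torch.randn(4, 3)                                 # equivariant node features
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
out_inv, out_equiv = block(x, pos, edge_index)
```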
37 changes: 24 additions & 13 deletions hydragnn/models/DIMEStack.py
@@ -36,6 +36,8 @@ class DIMEStack(Base):

def __init__(
self,
input_args,
conv_args,
basis_emb_size,
envelope_exponent,
int_emb_size,
@@ -60,7 +62,7 @@ def __init__(
self.edge_dim = edge_dim
self.radius = radius

super().__init__(*args, **kwargs)
super().__init__(input_args, conv_args, *args, **kwargs)

self.rbf = BesselBasisLayer(num_radial, radius, envelope_exponent)
self.sbf = SphericalBasisLayer(
@@ -111,28 +113,37 @@ def get_conv(self, input_dim, output_dim):

if self.use_edge_attr:
return Sequential(
"x, pos, rbf, edge_attr, sbf, i, j, idx_kj, idx_ji",
self.input_args,
[
(lin, "x -> x"),
(emb, "x, rbf, i, j, edge_attr -> x1"),
(lin, "inv_node_feat -> inv_node_feat"),
(emb, "inv_node_feat, rbf, i, j, edge_attr -> x1"),
(inter, "x1, rbf, sbf, idx_kj, idx_ji -> x2"),
(dec, "x2, rbf, i -> c"),
(lambda x, pos: [x, pos], "c, pos -> c, pos"),
(dec, "x2, rbf, i -> inv_node_feat"),
(
lambda inv_node_feat, equiv_node_feat: [
inv_node_feat,
equiv_node_feat,
],
"inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
),
],
)
else:
return Sequential(
"x, pos, rbf, sbf, i, j, idx_kj, idx_ji",
self.input_args,
[
(lin, "x -> x"),
(emb, "x, rbf, i, j -> x1"),
(lin, "inv_node_feat -> inv_node_feat"),
(emb, "inv_node_feat, rbf, i, j -> x1"),
(inter, "x1, rbf, sbf, idx_kj, idx_ji -> x2"),
(dec, "x2, rbf, i -> c"),
(lambda x, pos: [x, pos], "c, pos -> c, pos"),
(dec, "x2, rbf, i -> inv_node_feat"),
(
lambda x, pos: [x, pos],
"inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
),
],
)

def _conv_args(self, data):
def _embedding(self, data):
assert (
data.pos is not None
), "DimeNet requires node positions (data.pos) to be set."
@@ -166,7 +177,7 @@ def _conv_args(self, data):
), "Data must have edge attributes if use_edge_attributes is set."
conv_args.update({"edge_attr": data.edge_attr})

return conv_args
return data.x, data.pos, conv_args


"""
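One subtlety for this stack: the DimeNet convs also consume the Bessel/spherical bases and the triplet indices that `_embedding` builds, so whatever `input_args` string is passed into `DIMEStack` has to name those quantities as well; the diff only shows the string being consumed, not its value. A hypothetical value, inferred from the `Sequential` signature that was hardcoded here before this change:

```python
# Hypothetical input_args for DIMEStack, inferred from the previously hardcoded
# Sequential signature; the actual value is set where the model is constructed
# and is not shown in this diff.
input_args = "inv_node_feat, equiv_node_feat, rbf, sbf, i, j, idx_kj, idx_ji"

# With use_edge_attr enabled, Base.__init__ appends ", edge_attr" to the string,
# which lines up with the edge_attr branch of get_conv() above.
```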
34 changes: 26 additions & 8 deletions hydragnn/models/EGCLStack.py
@@ -21,6 +21,8 @@
class EGCLStack(Base):
def __init__(
self,
input_args,
conv_args,
edge_attr_dim: int,
*args,
max_neighbours: Optional[int] = None,
@@ -30,7 +32,11 @@ def __init__(
self.edge_dim = (
0 if edge_attr_dim is None else edge_attr_dim
) # Must be named edge_dim to trigger use by Base
super().__init__(*args, **kwargs)
super().__init__(input_args, conv_args, *args, **kwargs)

assert (
self.input_args == "inv_node_feat, equiv_node_feat, edge_index, edge_attr"
)
pass

def _init_conv(self):
@@ -56,21 +62,33 @@ def get_conv(self, input_dim, output_dim, last_layer=False):

if self.equivariance and not last_layer:
return Sequential(
"x, pos, edge_index, edge_attr",
self.input_args,
[
(egcl, "x, pos, edge_index, edge_attr -> x, pos"),
(
egcl,
"inv_node_feat, equiv_node_feat, edge_index, edge_attr -> inv_node_feat, equiv_node_feat",
),
],
)
else:
return Sequential(
"x, pos, edge_index, edge_attr",
self.input_args,
[
(egcl, "x, pos, edge_index, edge_attr -> x"),
(lambda x, pos: [x, pos], "x, pos -> x, pos"),
(
egcl,
"inv_node_feat, equiv_node_feat, edge_index, edge_attr -> inv_node_feat",
),
(
lambda inv_node_feat, equiv_node_feat: [
inv_node_feat,
equiv_node_feat,
],
"inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
),
],
)

def _conv_args(self, data):
def _embedding(self, data):
if self.edge_dim > 0:
conv_args = {
"edge_index": data.edge_index,
@@ -82,7 +100,7 @@ def _conv_args(self, data):
"edge_attr": None,
}

return conv_args
return data.x, data.pos, conv_args

def __str__(self):
return "EGCLStack"
23 changes: 12 additions & 11 deletions hydragnn/models/GATStack.py
@@ -21,6 +21,8 @@
class GATStack(Base):
def __init__(
self,
input_args,
conv_args,
heads: int,
negative_slope: float,
*args,
@@ -30,7 +32,7 @@ def __init__(
self.heads = heads
self.negative_slope = negative_slope

super().__init__(*args, **kwargs)
super().__init__(input_args, conv_args, *args, **kwargs)

def _init_conv(self):
"""Here this function overwrites _init_conv() in Base since it has different implementation
@@ -99,18 +101,17 @@ def get_conv(self, input_dim, output_dim, concat):
concat=concat,
)

input_args = "x, pos, edge_index"
conv_args = "x, edge_index"

if self.use_edge_attr:
input_args += ", edge_attr"
conv_args += ", edge_attr"

return Sequential(
input_args,
self.input_args,
[
(gat, conv_args + " -> x"),
(lambda x, pos: [x, pos], "x, pos -> x, pos"),
(gat, self.conv_args + " -> inv_node_feat"),
(
lambda inv_node_feat, equiv_node_feat: [
inv_node_feat,
equiv_node_feat,
],
"inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
),
],
)

12 changes: 6 additions & 6 deletions hydragnn/models/GINStack.py
@@ -33,14 +33,14 @@ def get_conv(self, input_dim, output_dim):
train_eps=True,
)

input_args = "x, pos, edge_index"
conv_args = "x, edge_index"

return Sequential(
input_args,
self.input_args,
[
(gin, conv_args + " -> x"),
(lambda x, pos: [x, pos], "x, pos -> x, pos"),
(gin, self.conv_args + " -> inv_node_feat"),
(
lambda x, equiv_node_feat: [x, equiv_node_feat],
"inv_node_feat, equiv_node_feat -> inv_node_feat, equiv_node_feat",
),
],
)
