Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use class-resolver for readouts and enable choosing MaxReadout #68

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Cleanup
cthoyt committed Feb 8, 2022
commit 50adc0215c98b40840ed7d8383cba1960ba48791
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -7,4 +7,5 @@ matplotlib
tqdm
networkx
ninja
jinja2
jinja2
class-resolver
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -39,6 +39,7 @@
"networkx",
"ninja",
"jinja2",
"class-resolver",
],
python_requires=">=3.7,<3.9",
classifiers=[
2 changes: 1 addition & 1 deletion torchdrug/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@
"MessagePassingBase", "GraphConv", "GraphAttentionConv", "RelationalGraphConv", "GraphIsomorphismConv",
"NeuralFingerprintConv", "ContinuousFilterConv", "MessagePassing", "ChebyshevConv",
"DiffPool", "MinCutPool",
"MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort",
"MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort", "readout_resolver", "Readout",
"ConditionalFlow",
"NodeSampler", "EdgeSampler",
"distribution", "functional",
13 changes: 5 additions & 8 deletions torchdrug/models/chebnet.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections.abc import Sequence

import torch
from class_resolver import Hint
from torch import nn

from torchdrug import core, layers
from torchdrug.core import Registry as R
from torchdrug.layers import Readout, readout_resolver


@R.register("models.ChebNet")
@@ -25,11 +27,11 @@ class ChebyshevConvolutionalNetwork(nn.Module, core.Configurable):
batch_norm (bool, optional): apply batch normalization or not
activation (str or function, optional): activation function
concat_hidden (bool, optional): concat hidden representations from all layers as output
readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
readout: readout function. Available functions are ``sum`` and ``mean``.
"""

def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=False, batch_norm=False,
activation="relu", concat_hidden=False, readout="sum"):
activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
super(ChebyshevConvolutionalNetwork, self).__init__()

if not isinstance(hidden_dims, Sequence):
@@ -45,12 +47,7 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=F
self.layers.append(layers.ChebyshevConv(self.dims[i], self.dims[i + 1], edge_input_dim, k,
batch_norm, activation))

if readout == "sum":
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)
self.readout = readout_resolver.make(readout)

def forward(self, graph, input, all_loss=None, metric=None):
"""
2 changes: 2 additions & 0 deletions torchdrug/models/gat.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections.abc import Sequence

import torch
from class_resolver import Hint
from torch import nn

from torchdrug import core, layers
from torchdrug.core import Registry as R
from torchdrug.layers import Readout, readout_resolver


@R.register("models.GAT")
15 changes: 4 additions & 11 deletions torchdrug/models/gcn.py
Original file line number Diff line number Diff line change
@@ -5,8 +5,8 @@
from torch import nn

from torchdrug import core, layers
from torchdrug.layers import readout_resolver, Readout
from torchdrug.core import Registry as R
from torchdrug.layers import Readout, readout_resolver


@R.register("models.GCN")
@@ -99,11 +99,11 @@ class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable):
batch_norm (bool, optional): apply batch normalization or not
activation (str or function, optional): activation function
concat_hidden (bool, optional): concat hidden representations from all layers as output
readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``.
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
"""

def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, short_cut=False, batch_norm=False,
activation="relu", concat_hidden=False, readout="sum"):
activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
super(RelationalGraphConvolutionalNetwork, self).__init__()

if not isinstance(hidden_dims, Sequence):
@@ -120,14 +120,7 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, sh
self.layers.append(layers.RelationalGraphConv(self.dims[i], self.dims[i + 1], num_relation, edge_input_dim,
batch_norm, activation))

if readout == "sum":
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
elif readout == "max":
self.readout = layers.MaxReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)
self.readout = readout_resolver.make(readout)

def forward(self, graph, input, all_loss=None, metric=None):
"""
6 changes: 4 additions & 2 deletions torchdrug/models/gin.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections.abc import Sequence

import torch
from class_resolver import Hint
from torch import nn

from torchdrug import core, layers
from torchdrug.core import Registry as R
from torchdrug.layers import Readout, readout_resolver


@R.register("models.GIN")
@@ -26,12 +28,12 @@ class GraphIsomorphismNetwork(nn.Module, core.Configurable):
batch_norm (bool, optional): apply batch normalization or not
activation (str or function, optional): activation function
concat_hidden (bool, optional): concat hidden representations from all layers as output
readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``.
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
"""

def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False,
short_cut=False, batch_norm=False, activation="relu", concat_hidden=False,
readout="sum"):
readout: Hint[Readout] = "sum"):
super(GraphIsomorphismNetwork, self).__init__()

if not isinstance(hidden_dims, Sequence):
6 changes: 4 additions & 2 deletions torchdrug/models/neuralfp.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from collections.abc import Sequence

import torch
from class_resolver import Hint
from torch import nn
from torch.nn import functional as F

from torchdrug import core, layers
from torchdrug.core import Registry as R
from torchdrug.layers import Readout, readout_resolver


@R.register("models.NeuralFP")
@@ -25,11 +27,11 @@ class NeuralFingerprint(nn.Module, core.Configurable):
batch_norm (bool, optional): apply batch normalization or not
activation (str or function, optional): activation function
concat_hidden (bool, optional): concat hidden representations from all layers as output
readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
"""

def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False,
activation="relu", concat_hidden=False, readout="sum"):
activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
super(NeuralFingerprint, self).__init__()

if not isinstance(hidden_dims, Sequence):
3 changes: 3 additions & 0 deletions torchdrug/models/schnet.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections.abc import Sequence

import torch
from class_resolver import Hint
from torch import nn

from torchdrug import core, layers
from torchdrug.core import Registry as R
from torchdrug.layers import Readout, readout_resolver


@R.register("models.SchNet")
@@ -25,6 +27,7 @@ class SchNet(nn.Module, core.Configurable):
batch_norm (bool, optional): apply batch normalization or not
activation (str or function, optional): activation function
concat_hidden (bool, optional): concat hidden representations from all layers as output
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
"""

def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_gaussian=100, short_cut=True,
14 changes: 6 additions & 8 deletions torchdrug/tasks/pretrain.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import copy

import torch
from class_resolver import Hint
from torch import nn
from torch.nn import functional as F
from torch_scatter import scatter_max, scatter_min

from torchdrug import core, tasks, layers
from torchdrug.data import constant
from torchdrug.layers import functional
from torchdrug.layers import functional, readout_resolver, Readout
from torchdrug.core import Registry as R


@@ -169,9 +170,10 @@ class ContextPrediction(tasks.Task, core.Configurable):
r2 (int, optional): outer radius for context graphs
readout (nn.Module, optional): readout function over context anchor nodes
num_negative (int, optional): number of negative samples per positive sample
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
"""

def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", num_negative=1):
def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout: Hint[Readout] = "mean", num_negative=1):
super(ContextPrediction, self).__init__()
self.model = model
self.k = k
@@ -184,12 +186,8 @@ def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", n
self.context_model = copy.deepcopy(model)
else:
self.context_model = context_model
if readout == "sum":
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)

self.readout = readout_resolver.make(readout)

def substruct_and_context(self, graph):
center_index = (torch.rand(len(graph), device=self.device) * graph.num_nodes).long()