Commit 498ea18: Refactoring
Parent: a80412d

7 files changed: +18 -22 lines

README.md (+8 -6)

@@ -3,33 +3,35 @@
 This is the official implementation of the paper: [On the Bottleneck of Graph Neural Networks and its Practical Implications](https://openreview.net/pdf?id=i80OPhOCVH2) (ICLR'2021).

 By [Uri Alon](http://urialon.cswp.cs.technion.ac.il/) and [Eran Yahav](http://www.cs.technion.ac.il/~yahave/).
-See also the [[video]](https://youtu.be/vrLsEwzZTCQ) and the [[slides]](https://urialon.cswp.cs.technion.ac.il/wp-content/uploads/sites/83/2020/07/bottleneck_slides.pdf).
+See also the [[video]](https://youtu.be/vrLsEwzZTCQ), [[poster]](https://urialon.cswp.cs.technion.ac.il/wp-content/uploads/sites/83/2021/03/bottleneck_poster.pdf) and [[slides]](https://urialon.cswp.cs.technion.ac.il/wp-content/uploads/sites/83/2020/07/bottleneck_slides.pdf).

 This repository is divided into three sub-projects:

 1. The subdirectory `tf-gnn-samples` is a clone of
 [https://github.com/microsoft/tf-gnn-samples](https://github.com/microsoft/tf-gnn-samples) by Brockschmidt (ICML'2020).
+This project can be used to reproduce the QM9 and VarMisuse experiments of Sections 4.2 and 4.4 in the paper.
 This sub-project depends on TensorFlow 1.13.
 The instructions for our clone are the same as their original code, except that reproducing our experiments
 (the QM9 dataset and VarMisuse) can be done by running the
 script `tf-gnn-samples/run_qm9_benchs_fa.py` or `tf-gnn-samples/run_varmisuse_benchs_fa.py` instead of their original scripts.
 For additional dependencies and instructions, see their original README:
 [https://github.com/microsoft/tf-gnn-samples/blob/master/README.md](https://github.com/microsoft/tf-gnn-samples/blob/master/README.md).
 The main modification that we performed is using a Fully-Adjacent layer as the last
-GNN layer and we describe in our paper (Section 4).
+GNN layer, as we describe in our paper.
 2. The subdirectory `gnn-comparison` is a clone of [https://github.com/diningphil/gnn-comparison](https://github.com/diningphil/gnn-comparison)
-by Errica et al. (ICLR'2020). This sub-project depends on PyTorch 1.4 and Pytorch-Geometric.
+by Errica et al. (ICLR'2020).
+This project can be used to reproduce the biological experiments (Section 4.3, the ENZYMES and NCI1 datasets).
+This sub-project depends on PyTorch 1.4 and Pytorch-Geometric.
 For additional dependencies and instructions, see their original README:
 [https://github.com/diningphil/gnn-comparison/blob/master/README.md](https://github.com/diningphil/gnn-comparison/blob/master/README.md).
-The instructions for our clone are the same, except that we added an additional flag to every `config_*.yml` file, called `last_layer_fully_adjacent`,
-which is set to `True` by default, and reproduces our experiments (Section 4.3, the ENZYMES and NCI1 datasets).
+The instructions for our clone are the same, except that we added an additional flag to every `config_*.yml` file, called `last_layer_fa`,
+which is set to `True` by default and reproduces our experiments.
 The main modification that we performed is using a Fully-Adjacent layer as the last
 GNN layer.
 3. The main directory (in which this file resides) can be used to reproduce the experiments of
 Section 4.1 in the paper, for the "Tree-NeighborsMatch" problem. The rest of this README file includes the
 instructions for this main directory.
 This repository can be used to reproduce the experiments of
-Section 4.1 in the paper, for the "Tree-NeighborsMatch" problem.

 This project was designed to be useful in experimenting with new GNN architectures and new solutions for the over-squashing problem.

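The README hunk above keeps the repository's central modification: using a Fully-Adjacent (FA) layer as the last GNN layer, i.e. the final message-passing step treats every pair of nodes as connected. The following is a minimal sketch of that idea only, assuming PyTorch Geometric and a single, unbatched graph; the class, function names, and the choice of GCNConv as the base layer are illustrative and are not taken from this repository.

# Illustrative sketch (not code from this repository): a GNN whose last layer is fully adjacent.
import torch
from torch_geometric.nn import GCNConv


def fully_adjacent_edge_index(num_nodes):
    # Edge index of a complete directed graph: every node sends a message to every node.
    row = torch.arange(num_nodes).repeat_interleave(num_nodes)
    col = torch.arange(num_nodes).repeat(num_nodes)
    return torch.stack([row, col], dim=0)


class GNNWithLastLayerFA(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, num_layers):
        super().__init__()
        dims = [in_dim] + [hidden_dim] * num_layers
        self.layers = torch.nn.ModuleList(
            [GCNConv(dims[i], dims[i + 1]) for i in range(num_layers)])

    def forward(self, x, edge_index):
        # All layers except the last propagate over the input graph topology;
        # the last layer propagates over a complete graph, so information can
        # flow directly between distant nodes.
        for layer in self.layers[:-1]:
            x = torch.relu(layer(x, edge_index))
        fa_edge_index = fully_adjacent_edge_index(x.size(0))
        return self.layers[-1](x, fa_edge_index)

In the commit's sub-projects, the equivalent change lives inside the cloned tf-gnn-samples and gnn-comparison code (e.g. behind the `last_layer_fa` config flag), not in a standalone module like the sketch above.
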
common.py (+3 -9)

@@ -7,13 +7,7 @@


 class Task(Enum):
-    PARITY = auto()
-    LEAF_PARITY = auto()
-    NOISY_LEAF_PARITY = auto()
-    NOISY_SUM = auto()
-    DICTIONARY = auto()
-    NOISY_DICTIONARY = auto()
-    DUMMY = auto()
+    NEIGHBORS_MATCH = auto()

     @staticmethod
     def from_string(s):
@@ -23,7 +17,7 @@ def from_string(s):
         raise ValueError()

     def get_dataset(self, depth, train_fraction):
-        if self is Task.DICTIONARY:
+        if self is Task.NEIGHBORS_MATCH:
             dataset = DictionaryLookupDataset(depth)
         else:
             dataset = None
@@ -55,7 +49,7 @@ def get_layer(self, in_dim, out_dim):
             return GINConv(nn.Sequential(nn.Linear(in_dim, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU(),
                                          nn.Linear(out_dim, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU()))
         elif self is GNN_TYPE.GAT:
-            # 4-heads, although the paper by Velickovic et al. had used 6-8 heads.
+            # 4-heads, although the paper by Velickovic et al. had used 6-8 heads.
             # The output will be the concatenation of the heads, yielding a vector of size out_dim
             num_heads = 4
             return GATConv(in_dim, out_dim // num_heads, heads=num_heads)

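The substantive change in this file is the collapse of the old task enum into `Task.NEIGHBORS_MATCH`; the GAT comment hunk appears to be whitespace-only. For readers unfamiliar with the head arithmetic in `get_layer`, here is a small shape check, a sketch assuming PyTorch Geometric's `GATConv`; the 3-node example graph is made up and not part of the repository.

# Sketch of the shape reasoning behind the GAT branch above: each of the 4 heads outputs
# out_dim // 4 features, and GATConv concatenates the heads by default, so the layer
# output has size out_dim again.
import torch
from torch_geometric.nn import GATConv

in_dim, out_dim, num_heads = 32, 32, 4
conv = GATConv(in_dim, out_dim // num_heads, heads=num_heads)  # concat=True is the default

x = torch.randn(3, in_dim)                         # 3 nodes with in_dim features each
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # a small directed cycle
out = conv(x, edge_index)
print(out.shape)  # torch.Size([3, 32]) -> [num_nodes, num_heads * (out_dim // num_heads)]
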
main.py (+3 -3)

@@ -6,7 +6,7 @@

 if __name__ == '__main__':
     parser = ArgumentParser()
-    parser.add_argument("--task", dest="task", default=Task.PARITY, type=Task.from_string, choices=list(Task),
+    parser.add_argument("--task", dest="task", default=Task.NEIGHBORS_MATCH, type=Task.from_string, choices=list(Task),
                         required=False)
     parser.add_argument("--type", dest="type", default=GNN_TYPE.GCN, type=GNN_TYPE.from_string, choices=list(GNN_TYPE),
                         required=False)
@@ -18,7 +18,7 @@
     parser.add_argument("--eval_every", dest="eval_every", default=100, type=int, required=False)
     parser.add_argument("--batch_size", dest="batch_size", default=1024, type=int, required=False)
     parser.add_argument("--accum_grad", dest="accum_grad", default=1, type=int, required=False)
-    parser.add_argument("--stop", dest="stop", default=STOP.TRAIN, type=STOP.from_string, choices=list(STOP),
+    parser.add_argument("--stop", dest="stop", default=STOP.TRAIN, type=STOP.from_string, choices=list(STOP),
                         required=False)
     parser.add_argument("--patience", dest="patience", default=20, type=int, required=False)
     parser.add_argument("--loader_workers", dest="loader_workers", default=0, type=int, required=False)
@@ -33,7 +33,7 @@


 def get_fake_args(
-        task=Task.PARITY,
+        task=Task.NEIGHBORS_MATCH,
         type=GNN_TYPE.GCN,
         dim=32,
         depth=3,

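main.py parses the enums straight from the command line through `type=Task.from_string` together with `choices=list(Task)`. The self-contained sketch below shows that pattern; the bodies of `from_string` and `__str__` are assumptions for illustration, not copied from common.py.

# Sketch of the enum-as-argparse-type pattern used in main.py: the type callable converts
# the command-line string into an enum member, and choices restricts the accepted values.
from argparse import ArgumentParser
from enum import Enum, auto


class Task(Enum):
    NEIGHBORS_MATCH = auto()

    @staticmethod
    def from_string(s):
        try:
            return Task[s.upper()]
        except KeyError:
            raise ValueError()

    def __str__(self):
        return self.name.lower()


parser = ArgumentParser()
parser.add_argument("--task", dest="task", default=Task.NEIGHBORS_MATCH,
                    type=Task.from_string, choices=list(Task), required=False)
args = parser.parse_args(["--task", "neighbors_match"])
print(args.task)  # neighbors_match
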
run-gat-2-8.py (+1 -1)

@@ -24,7 +24,7 @@ def __init__(self, train_acc, test_acc, epoch):

 if __name__ == '__main__':

-    task = Task.DICTIONARY
+    task = Task.NEIGHBORS_MATCH
     gnn_type = GNN_TYPE.GAT
     stopping_criterion = STOP.TRAIN
     min_depth = 2

run-gcn-2-8.py (+1 -1)

@@ -24,7 +24,7 @@ def __init__(self, train_acc, test_acc, epoch):

 if __name__ == '__main__':

-    task = Task.DICTIONARY
+    task = Task.NEIGHBORS_MATCH
     gnn_type = GNN_TYPE.GCN
     stopping_criterion = STOP.TRAIN
     min_depth = 2

run-ggnn-2-8.py (+1 -1)

@@ -24,7 +24,7 @@ def __init__(self, train_acc, test_acc, epoch):

 if __name__ == '__main__':

-    task = Task.DICTIONARY
+    task = Task.NEIGHBORS_MATCH
     gnn_type = GNN_TYPE.GGNN
     stopping_criterion = STOP.TRAIN
     min_depth = 2

run-gin-2-8.py (+1 -1)

@@ -24,7 +24,7 @@ def __init__(self, train_acc, test_acc, epoch):

 if __name__ == '__main__':

-    task = Task.DICTIONARY
+    task = Task.NEIGHBORS_MATCH
     gnn_type = GNN_TYPE.GIN
     stopping_criterion = STOP.TRAIN
     min_depth = 2

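Each run-*-2-8.py script pins one GNN type and, judging by the file names and the `min_depth = 2` context line, sweeps the Tree-NeighborsMatch tree depth from 2 to 8. Below is a heavily hedged sketch of that pattern; `run_experiment`, `max_depth`, and the import path are placeholders, not the repository's actual API.

# Hedged sketch of the run-script pattern (placeholders only, not the real training code).
from common import Task, GNN_TYPE, STOP  # import path is an assumption


def run_experiment(task, gnn_type, depth, stop):
    # Placeholder: the real scripts train a model for the given tree depth and record
    # per-depth results such as (train_acc, test_acc, epoch), as the hunk header suggests.
    raise NotImplementedError


if __name__ == '__main__':
    task = Task.NEIGHBORS_MATCH
    gnn_type = GNN_TYPE.GIN
    stopping_criterion = STOP.TRAIN
    min_depth, max_depth = 2, 8

    results = {}
    for depth in range(min_depth, max_depth + 1):
        results[depth] = run_experiment(task, gnn_type, depth, stopping_criterion)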