# gcn_conv.py
import torch
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops
from torch_geometric.nn.inits import glorot, zeros


class GCNConv(MessagePassing):
    r"""The graph convolutional operator from the `"Semi-Supervised
    Classification with Graph Convolutional Networks"
    <https://arxiv.org/abs/1609.02907>`_ paper

    .. math::
        \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
        \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},

    where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the
    adjacency matrix with inserted self-loops and
    :math:`\hat{D}_{ii} = \sum_{j} \hat{A}_{ij}` its diagonal degree matrix.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        improved (bool, optional): If set to :obj:`True`, the layer computes
            :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`.
            (default: :obj:`False`)
        cached (bool, optional): If set to :obj:`True`, the layer will cache
            the computation of :math:`\mathbf{\hat{D}}^{-1/2}
            \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}`.
            (default: :obj:`False`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        edge_norm (bool, optional): Whether to normalize the adjacency
            matrix. (default: :obj:`True`)
        gfn (bool, optional): If set to :obj:`True`, only the linear
            transform (1x1 conv) is applied to every node, skipping message
            passing entirely. (default: :obj:`False`)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 improved=False,
                 cached=False,
                 bias=True,
                 edge_norm=True,
                 gfn=False):
        super(GCNConv, self).__init__('add')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.cached_result = None
        self.edge_norm = edge_norm
        self.gfn = gfn

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)
        self.cached_result = None

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None):
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=dtype,
                                     device=edge_index.device)
        edge_weight = edge_weight.view(-1)
        assert edge_weight.size(0) == edge_index.size(1)

        # Replace any pre-existing self-loops with fresh ones so that every
        # node carries exactly one self-loop with a controlled weight.
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # Add edge_weight for loop edges (weight 2 instead of 1 in
        # "improved" mode, i.e. A + 2I).
        loop_weight = torch.full((num_nodes, ),
                                 1 if not improved else 2,
                                 dtype=edge_weight.dtype,
                                 device=edge_weight.device)
        edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

        # Symmetric normalization D^{-1/2} A_hat D^{-1/2}; isolated nodes
        # produce an infinite inverse square-root degree, which is zeroed.
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
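
    # Worked sketch of the normalization above (illustrative, not from the
    # original file): for a two-node graph with one undirected edge, adding
    # self-loops gives A_hat = A + I, so both nodes have degree 2, and every
    # edge weight (self-loops included) is scaled by
    # (1 / sqrt(2)) * (1 / sqrt(2)) = 0.5.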

    def forward(self, x, edge_index, edge_weight=None):
        x = torch.matmul(x, self.weight)
        if self.gfn:
            # Graph feature network mode: apply only the linear transform,
            # with no neighborhood aggregation.
            return x

        if not self.cached or self.cached_result is None:
            if self.edge_norm:
                edge_index, norm = GCNConv.norm(
                    edge_index, x.size(0), edge_weight, self.improved,
                    x.dtype)
            else:
                norm = None
            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result
        return self.propagate(edge_index, x=x, norm=norm)
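
    # Note: with cached=True, the normalized graph from the first forward
    # call is reused, which assumes the same edge_index is passed on every
    # subsequent call.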

    def message(self, x_j, norm):
        if self.edge_norm:
            return norm.view(-1, 1) * x_j
        else:
            return x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
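

# A minimal usage sketch (an assumption added for illustration, not part of
# the original file), run on a toy two-node graph. It presumes a
# torch_geometric version compatible with the MessagePassing API used above,
# and follows the convention that edge_index has shape [2, num_edges].
if __name__ == '__main__':
    conv = GCNConv(in_channels=3, out_channels=4)
    x = torch.randn(2, 3)                        # 2 nodes, 3 features each
    edge_index = torch.tensor([[0, 1],
                               [1, 0]])          # one undirected edge
    out = conv(x, edge_index)
    print(out.shape)                             # torch.Size([2, 4])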