Skip to content

Commit d47d9cd

Browse files
authored
Remove size parameter for GATv2 and HEAT (#3744)
* refactor heat_conv and test * refactor gatv2_conv and test
1 parent 95ef04f commit d47d9cd

File tree

4 files changed: +20 −25 lines changed

test/nn/conv/test_gatv2_conv.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,14 @@ def test_gatv2_conv():
1414
assert conv.__repr__() == 'GATv2Conv(8, 32, heads=2)'
1515
out = conv(x1, edge_index)
1616
assert out.size() == (4, 64)
17-
assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out)
17+
assert torch.allclose(conv(x1, edge_index), out)
1818
assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)
1919

20-
t = '(Tensor, Tensor, OptTensor, Size, NoneType) -> Tensor'
20+
t = '(Tensor, Tensor, OptTensor, NoneType) -> Tensor'
2121
jit = torch.jit.script(conv.jittable(t))
2222
assert torch.allclose(jit(x1, edge_index), out)
23-
assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out)
2423

25-
t = '(Tensor, SparseTensor, OptTensor, Size, NoneType) -> Tensor'
24+
t = '(Tensor, SparseTensor, OptTensor, NoneType) -> Tensor'
2625
jit = torch.jit.script(conv.jittable(t))
2726
assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6)
2827

@@ -39,7 +38,7 @@ def test_gatv2_conv():
3938
assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 7
4039
assert conv._alpha is None
4140

42-
t = ('(Tensor, Tensor, OptTensor, Size, bool) -> '
41+
t = ('(Tensor, Tensor, OptTensor, bool) -> '
4342
'Tuple[Tensor, Tuple[Tensor, Tensor]]')
4443
jit = torch.jit.script(conv.jittable(t))
4544
result = jit(x1, edge_index, return_attention_weights=True)
@@ -49,7 +48,7 @@ def test_gatv2_conv():
4948
assert result[1][1].min() >= 0 and result[1][1].max() <= 1
5049
assert conv._alpha is None
5150

52-
t = ('(Tensor, SparseTensor, OptTensor, Size, bool) -> '
51+
t = ('(Tensor, SparseTensor, OptTensor, bool) -> '
5352
'Tuple[Tensor, SparseTensor]')
5453
jit = torch.jit.script(conv.jittable(t))
5554
result = jit(x1, adj.t(), return_attention_weights=True)
@@ -60,15 +59,14 @@ def test_gatv2_conv():
6059
adj = adj.sparse_resize((4, 2))
6160
out1 = conv((x1, x2), edge_index)
6261
assert out1.size() == (2, 64)
63-
assert torch.allclose(conv((x1, x2), edge_index, size=(4, 2)), out1)
62+
assert torch.allclose(conv((x1, x2), edge_index), out1)
6463
assert torch.allclose(conv((x1, x2), adj.t()), out1, atol=1e-6)
6564

66-
t = '(OptPairTensor, Tensor, OptTensor, Size, NoneType) -> Tensor'
65+
t = '(OptPairTensor, Tensor, OptTensor, NoneType) -> Tensor'
6766
jit = torch.jit.script(conv.jittable(t))
6867
assert torch.allclose(jit((x1, x2), edge_index), out1)
69-
assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1)
7068

71-
t = '(OptPairTensor, SparseTensor, OptTensor, Size, NoneType) -> Tensor'
69+
t = '(OptPairTensor, SparseTensor, OptTensor, NoneType) -> Tensor'
7270
jit = torch.jit.script(conv.jittable(t))
7371
assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6)
7472

test/nn/conv/test_heat_conv.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_heat_conv():
2020
assert out.size() == (4, 32)
2121
assert torch.allclose(conv(x, adj.t(), node_type, edge_type), out)
2222

23-
t = '(Tensor, Tensor, Tensor, Tensor, OptTensor, Size) -> Tensor'
23+
t = '(Tensor, Tensor, Tensor, Tensor, OptTensor) -> Tensor'
2424
jit = torch.jit.script(conv.jittable(t))
2525
assert torch.allclose(jit(x, edge_index, node_type, edge_type, edge_attr),
2626
out)
@@ -33,6 +33,6 @@ def test_heat_conv():
3333
assert out.size() == (4, 16)
3434
assert torch.allclose(conv(x, adj.t(), node_type, edge_type), out)
3535

36-
t = '(Tensor, SparseTensor, Tensor, Tensor, OptTensor, Size) -> Tensor'
36+
t = '(Tensor, SparseTensor, Tensor, Tensor, OptTensor) -> Tensor'
3737
jit = torch.jit.script(conv.jittable(t))
3838
assert torch.allclose(jit(x, adj.t(), node_type, edge_type), out)

torch_geometric/nn/conv/gatv2_conv.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import Union, Tuple, Optional
2-
from torch_geometric.typing import (Adj, Size, OptTensor, PairTensor)
2+
from torch_geometric.typing import (Adj, OptTensor, PairTensor)
33

44
import torch
55
from torch import Tensor
@@ -163,12 +163,12 @@ def reset_parameters(self):
163163
zeros(self.bias)
164164

165165
def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
166-
edge_attr: OptTensor = None, size: Size = None,
166+
edge_attr: OptTensor = None,
167167
return_attention_weights: bool = None):
168-
# type: (Union[Tensor, PairTensor], Tensor, OptTensor, Size, NoneType) -> Tensor # noqa
169-
# type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, Size, NoneType) -> Tensor # noqa
170-
# type: (Union[Tensor, PairTensor], Tensor, OptTensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
171-
# type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, Size, bool) -> Tuple[Tensor, SparseTensor] # noqa
168+
# type: (Union[Tensor, PairTensor], Tensor, OptTensor, NoneType) -> Tensor # noqa
169+
# type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, NoneType) -> Tensor # noqa
170+
# type: (Union[Tensor, PairTensor], Tensor, OptTensor, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
171+
# type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, bool) -> Tuple[Tensor, SparseTensor] # noqa
172172
r"""
173173
Args:
174174
return_attention_weights (bool, optional): If set to :obj:`True`,
@@ -202,8 +202,6 @@ def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
202202
num_nodes = x_l.size(0)
203203
if x_r is not None:
204204
num_nodes = min(num_nodes, x_r.size(0))
205-
if size is not None:
206-
num_nodes = min(size[0], size[1])
207205
edge_index, edge_attr = remove_self_loops(
208206
edge_index, edge_attr)
209207
edge_index, edge_attr = add_self_loops(
@@ -220,7 +218,7 @@ def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
220218

221219
# propagate_type: (x: PairTensor, edge_attr: OptTensor)
222220
out = self.propagate(edge_index, x=(x_l, x_r), edge_attr=edge_attr,
223-
size=size)
221+
size=None)
224222

225223
alpha = self._alpha
226224
self._alpha = None

torch_geometric/nn/conv/heat_conv.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import Optional
2-
from torch_geometric.typing import Adj, Size, OptTensor
2+
from torch_geometric.typing import Adj, OptTensor
33

44
import torch
55
from torch import Tensor
@@ -89,8 +89,7 @@ def reset_parameters(self):
8989
self.lin.reset_parameters()
9090

9191
def forward(self, x: Tensor, edge_index: Adj, node_type: Tensor,
92-
edge_type: Tensor, edge_attr: OptTensor = None,
93-
size: Size = None) -> Tensor:
92+
edge_type: Tensor, edge_attr: OptTensor = None) -> Tensor:
9493
""""""
9594
x = self.hetero_lin(x, node_type)
9695

@@ -99,7 +98,7 @@ def forward(self, x: Tensor, edge_index: Adj, node_type: Tensor,
9998

10099
# propagate_type: (x: Tensor, edge_type_emb: Tensor, edge_attr: OptTensor) # noqa
101100
out = self.propagate(edge_index, x=x, edge_type_emb=edge_type_emb,
102-
edge_attr=edge_attr, size=size)
101+
edge_attr=edge_attr, size=None)
103102

104103
if self.concat:
105104
if self.root_weight:

Comments (0)