@@ -14,15 +14,14 @@ def test_gatv2_conv():
14
14
assert conv .__repr__ () == 'GATv2Conv(8, 32, heads=2)'
15
15
out = conv (x1 , edge_index )
16
16
assert out .size () == (4 , 64 )
17
- assert torch .allclose (conv (x1 , edge_index , size = ( 4 , 4 ) ), out )
17
+ assert torch .allclose (conv (x1 , edge_index ), out )
18
18
assert torch .allclose (conv (x1 , adj .t ()), out , atol = 1e-6 )
19
19
20
- t = '(Tensor, Tensor, OptTensor, Size, NoneType) -> Tensor'
20
+ t = '(Tensor, Tensor, OptTensor, NoneType) -> Tensor'
21
21
jit = torch .jit .script (conv .jittable (t ))
22
22
assert torch .allclose (jit (x1 , edge_index ), out )
23
- assert torch .allclose (jit (x1 , edge_index , size = (4 , 4 )), out )
24
23
25
- t = '(Tensor, SparseTensor, OptTensor, Size, NoneType) -> Tensor'
24
+ t = '(Tensor, SparseTensor, OptTensor, NoneType) -> Tensor'
26
25
jit = torch .jit .script (conv .jittable (t ))
27
26
assert torch .allclose (jit (x1 , adj .t ()), out , atol = 1e-6 )
28
27
@@ -39,7 +38,7 @@ def test_gatv2_conv():
39
38
assert result [1 ].sizes () == [4 , 4 , 2 ] and result [1 ].nnz () == 7
40
39
assert conv ._alpha is None
41
40
42
- t = ('(Tensor, Tensor, OptTensor, Size, bool) -> '
41
+ t = ('(Tensor, Tensor, OptTensor, bool) -> '
43
42
'Tuple[Tensor, Tuple[Tensor, Tensor]]' )
44
43
jit = torch .jit .script (conv .jittable (t ))
45
44
result = jit (x1 , edge_index , return_attention_weights = True )
@@ -49,7 +48,7 @@ def test_gatv2_conv():
49
48
assert result [1 ][1 ].min () >= 0 and result [1 ][1 ].max () <= 1
50
49
assert conv ._alpha is None
51
50
52
- t = ('(Tensor, SparseTensor, OptTensor, Size, bool) -> '
51
+ t = ('(Tensor, SparseTensor, OptTensor, bool) -> '
53
52
'Tuple[Tensor, SparseTensor]' )
54
53
jit = torch .jit .script (conv .jittable (t ))
55
54
result = jit (x1 , adj .t (), return_attention_weights = True )
@@ -60,15 +59,14 @@ def test_gatv2_conv():
60
59
adj = adj .sparse_resize ((4 , 2 ))
61
60
out1 = conv ((x1 , x2 ), edge_index )
62
61
assert out1 .size () == (2 , 64 )
63
- assert torch .allclose (conv ((x1 , x2 ), edge_index , size = ( 4 , 2 ) ), out1 )
62
+ assert torch .allclose (conv ((x1 , x2 ), edge_index ), out1 )
64
63
assert torch .allclose (conv ((x1 , x2 ), adj .t ()), out1 , atol = 1e-6 )
65
64
66
- t = '(OptPairTensor, Tensor, OptTensor, Size, NoneType) -> Tensor'
65
+ t = '(OptPairTensor, Tensor, OptTensor, NoneType) -> Tensor'
67
66
jit = torch .jit .script (conv .jittable (t ))
68
67
assert torch .allclose (jit ((x1 , x2 ), edge_index ), out1 )
69
- assert torch .allclose (jit ((x1 , x2 ), edge_index , size = (4 , 2 )), out1 )
70
68
71
- t = '(OptPairTensor, SparseTensor, OptTensor, Size, NoneType) -> Tensor'
69
+ t = '(OptPairTensor, SparseTensor, OptTensor, NoneType) -> Tensor'
72
70
jit = torch .jit .script (conv .jittable (t ))
73
71
assert torch .allclose (jit ((x1 , x2 ), adj .t ()), out1 , atol = 1e-6 )
74
72
0 commit comments