Skip to content

Commit b85f098

Browse files
Feature: TGCN should support nonlinear activations (#596)
* Adding TGCN with parameters for custom activation functions in hidden and gate activations * Only uploading Gate activation and uploading Lux TGCN custom nonlinear functions * Adding gate_activation=relu for test GNNChain in TGCN and fixing documentation * try to fix doctest * Changing f(A,x) activation nonlinear function of TGCN and the corresponding tests * leave sigmoid as before * leave docs of GNNLux as before * fix documentation and test * Solving recommendations
1 parent 58fcd7d commit b85f098

File tree

4 files changed

+72
-13
lines changed

4 files changed

+72
-13
lines changed

GNNLux/src/layers/temporalconv.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@ LuxCore.apply(m::GNNContainerLayer, g, x, ps, st) = m(g, x, ps, st)
3333
init_state::Function
3434
end
3535

36-
function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
36+
function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform,
37+
init_state = zeros32, init_bias = zeros32, add_self_loops = false,
38+
use_edge_weight = true, act = sigmoid)
3739
in_dims, out_dims = ch
38-
conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight)
40+
conv = GCNConv(ch, act; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight)
3941
gru = Lux.GRUCell(out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state)
4042
return TGCNCell(in_dims, out_dims, conv, gru, init_state)
4143
end
@@ -57,7 +59,7 @@ function Base.show(io::IO, tgcn::TGCNCell)
5759
end
5860

5961
"""
60-
TGCN(in => out; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
62+
TGCN(in => out; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true, act = sigmoid)
6163
6264
Temporal Graph Convolutional Network (T-GCN) recurrent layer from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf).
6365
@@ -76,7 +78,7 @@ Performs a layer of GCNConv to model spatial dependencies, followed by a Gated R
7678
If `add_self_loops=true` the new weights will be set to 1.
7779
This option is ignored if the `edge_weight` is explicitly provided in the forward pass.
7880
Default `false`.
79-
81+
- `act`: Activation function used in the GCNConv layer. Default `sigmoid`.
8082
8183
8284
# Examples
@@ -91,9 +93,12 @@ rng = Random.default_rng()
9193
g = rand_graph(rng, 5, 10)
9294
x = rand(rng, Float32, 2, 5)
9395
94-
# create TGCN layer
96+
# create TGCN layer
9597
tgcn = TGCN(2 => 6)
9698
99+
# create TGCN layer with custom activation
100+
tgcn_relu = TGCN(2 => 6, act = relu)
101+
97102
# setup layer
98103
ps, st = LuxCore.setup(rng, tgcn)
99104

GNNLux/test/layers/temporalconv.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,25 @@
1010
tx = [x for _ in 1:5]
1111

1212
@testset "TGCN" begin
13+
# Test with default activation (sigmoid)
1314
l = TGCN(3=>3)
1415
ps = LuxCore.initialparameters(rng, l)
1516
st = LuxCore.initialstates(rng, l)
17+
y1, _ = l(g, x, ps, st)
1618
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
1719
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
20+
21+
# Test with custom activation (relu)
22+
l_relu = TGCN(3=>3, act = relu)
23+
ps_relu = LuxCore.initialparameters(rng, l_relu)
24+
st_relu = LuxCore.initialstates(rng, l_relu)
25+
y2, _ = l_relu(g, x, ps_relu, st_relu)
26+
27+
# Outputs should be different with different activation functions
28+
@test !isapprox(y1, y2, rtol=1.0f-2)
29+
30+
loss_relu = (x, ps) -> sum(first(l_relu(g, x, ps, st_relu)))
31+
test_gradients(loss_relu, x, ps_relu; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
1832
end
1933

2034
@testset "A3TGCN" begin

GraphNeuralNetworks/src/layers/temporalconv.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,7 @@ EvolveGCNO(args...; kws...) = GNNRecurrence(EvolveGCNOCell(args...; kws...))
758758

759759

760760
"""
761-
TGCNCell(in => out; kws...)
761+
TGCNCell(in => out; act = relu, kws...)
762762
763763
Recurrent graph convolutional cell from the paper
764764
[T-GCN: A Temporal Graph Convolutional
@@ -824,12 +824,14 @@ end
824824

825825
Flux.@layer :noexpand TGCNCell
826826

827-
function TGCNCell((in, out)::Pair{Int, Int}; kws...)
828-
conv_z = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
827+
function TGCNCell((in, out)::Pair{Int, Int};
828+
act = relu,
829+
kws...)
830+
conv_z = GNNChain(GCNConv(in => out, act; kws...), GCNConv(out => out; kws...))
829831
dense_z = Dense(2*out => out, sigmoid)
830-
conv_r = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
832+
conv_r = GNNChain(GCNConv(in => out, act; kws...), GCNConv(out => out; kws...))
831833
dense_r = Dense(2*out => out, sigmoid)
832-
conv_h = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
834+
conv_h = GNNChain(GCNConv(in => out, act; kws...), GCNConv(out => out; kws...))
833835
dense_h = Dense(2*out => out, tanh)
834836
return TGCNCell(in, out, conv_z, dense_z, conv_r, dense_r, conv_h, dense_h)
835837
end
@@ -868,6 +870,8 @@ See [`GNNRecurrence`](@ref) for more details.
868870
# Examples
869871
870872
```jldoctest
873+
julia> using Flux # Ensure activation functions are available
874+
871875
julia> num_nodes, num_edges = 5, 10;
872876
873877
julia> d_in, d_out = 2, 3;
@@ -876,9 +880,14 @@ julia> timesteps = 5;
876880
877881
julia> g = rand_graph(num_nodes, num_edges);
878882
879-
julia> x = rand(Float32, d_in, timesteps, num_nodes);
883+
julia> x = rand(Float32, d_in, timesteps, g.num_nodes);
884+
885+
julia> layer = TGCN(d_in => d_out) # Default activation (relu)
886+
GNNRecurrence(
887+
TGCNCell(2 => 3), # 126 parameters
888+
) # Total: 18 arrays, 126 parameters, 1.469 KiB.
880889
881-
julia> layer = TGCN(d_in => d_out)
890+
julia> layer_tanh = TGCN(d_in => d_out, act = tanh) # Custom activation
882891
GNNRecurrence(
883892
TGCNCell(2 => 3), # 126 parameters
884893
) # Total: 18 arrays, 126 parameters, 1.469 KiB.
@@ -889,5 +898,6 @@ julia> size(y) # (d_out, timesteps, num_nodes)
889898
(3, 5, 5)
890899
```
891900
"""
892-
TGCN(args...; kws...) = GNNRecurrence(TGCNCell(args...; kws...))
901+
TGCN(args...; kws...) =
902+
GNNRecurrence(TGCNCell(args...; kws...))
893903

GraphNeuralNetworks/test/layers/temporalconv.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ end
2525

2626
@testitem "TGCNCell" setup=[TemporalConvTestModule, TestModule] begin
2727
using .TemporalConvTestModule, .TestModule
28+
29+
# Test with default activation function
2830
cell = GraphNeuralNetworks.TGCNCell(in_channel => out_channel)
2931
y, h = cell(g, g.x)
3032
@test y === h
@@ -33,10 +35,25 @@ end
3335
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_HIGH)
3436
# with initial state
3537
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_HIGH)
38+
39+
# Test with custom activation function
40+
custom_activation = tanh
41+
cell_custom = GraphNeuralNetworks.TGCNCell(in_channel => out_channel, act = custom_activation)
42+
y_custom, h_custom = cell_custom(g, g.x)
43+
@test y_custom === h_custom
44+
@test size(h_custom) == (out_channel, g.num_nodes)
45+
# Test that outputs differ when using different activation functions
46+
@test !isapprox(y, y_custom, rtol=RTOL_HIGH)
47+
# with no initial state
48+
test_gradients(cell_custom, g, g.x, loss=cell_loss, rtol=RTOL_HIGH)
49+
# with initial state
50+
test_gradients(cell_custom, g, g.x, h_custom, loss=cell_loss, rtol=RTOL_HIGH)
3651
end
3752

3853
@testitem "TGCN" setup=[TemporalConvTestModule, TestModule] begin
3954
using .TemporalConvTestModule, .TestModule
55+
56+
# Test with default activation function
4057
layer = TGCN(in_channel => out_channel)
4158
x = rand(Float32, in_channel, timesteps, g.num_nodes)
4259
state0 = rand(Float32, out_channel, g.num_nodes)
@@ -48,6 +65,19 @@ end
4865
# with initial state
4966
test_gradients(layer, g, x, state0, rtol = RTOL_HIGH)
5067

68+
# Test with custom activation function
69+
custom_activation = tanh
70+
layer_custom = TGCN(in_channel => out_channel, act = custom_activation)
71+
y_custom = layer_custom(g, x)
72+
@test layer_custom isa GNNRecurrence
73+
@test size(y_custom) == (out_channel, timesteps, g.num_nodes)
74+
# Test that outputs differ when using different activation functions
75+
@test !isapprox(y, y_custom, rtol = RTOL_HIGH)
76+
# with no initial state
77+
test_gradients(layer_custom, g, x, rtol = RTOL_HIGH)
78+
# with initial state
79+
test_gradients(layer_custom, g, x, state0, rtol = RTOL_HIGH)
80+
5181
# interplay with GNNChain
5282
model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1))
5383
y = model(g, x)

0 commit comments

Comments
 (0)