Skip to content

Commit 58fcd7d

Browse files
fix corner case in normalize_graphdata (#589)
1 parent cb2ad5f commit 58fcd7d

File tree

4 files changed

+27
-14
lines changed

4 files changed

+27
-14
lines changed

GNNGraphs/src/gnngraph.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,9 @@ function GNNGraph(data::D;
146146
edata = normalize_graphdata(edata, default_name = :e, n = num_edges,
147147
duplicate_if_needed = true)
148148

149-
# don't force the shape of the data when there is only one graph
150-
gdata = normalize_graphdata(gdata, default_name = :u,
151-
n = num_graphs > 1 ? num_graphs : -1)
149+
gdata = normalize_graphdata(gdata, default_name = :u, n = num_graphs, glob=true)
152150

153-
GNNGraph(graph,
151+
return GNNGraph(graph,
154152
num_nodes, num_edges, num_graphs,
155153
graph_indicator,
156154
ndata, edata, gdata)
@@ -203,7 +201,7 @@ function GNNGraph(g::GNNGraph; ndata = g.ndata, edata = g.edata, gdata = g.gdata
203201
ndata = normalize_graphdata(ndata, default_name = :x, n = g.num_nodes)
204202
edata = normalize_graphdata(edata, default_name = :e, n = g.num_edges,
205203
duplicate_if_needed = true)
206-
gdata = normalize_graphdata(gdata, default_name = :u, n = g.num_graphs)
204+
gdata = normalize_graphdata(gdata, default_name = :u, n = g.num_graphs, glob=true)
207205

208206
if !isnothing(graph_type)
209207
if graph_type == :coo

GNNGraphs/src/gnnheterograph/gnnheterograph.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,7 @@ function GNNHeteroGraph(data::EDict;
144144
ndata = normalize_heterographdata(ndata, default_name = :x, ns = num_nodes)
145145
edata = normalize_heterographdata(edata, default_name = :e, ns = num_edges,
146146
duplicate_if_needed = true)
147-
gdata = normalize_graphdata(gdata, default_name = :u,
148-
n = num_graphs > 1 ? num_graphs : -1)
147+
gdata = normalize_graphdata(gdata, default_name = :u, n = num_graphs, glob = true)
149148
end
150149

151150
return GNNHeteroGraph(graph,

GNNGraphs/src/utils.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,19 +129,24 @@ function normalize_graphdata(data; default_name::Symbol, kws...)
129129
normalize_graphdata(NamedTuple{(default_name,)}((data,)); default_name, kws...)
130130
end
131131

132-
function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_needed = false)
132+
function normalize_graphdata(data::NamedTuple; default_name::Symbol, n::Int,
133+
duplicate_if_needed::Bool = false, glob::Bool = false)
133134
# This had to workaround two Zygote bugs with NamedTuples
134135
# https://github.com/FluxML/Zygote.jl/issues/1071
135-
# https://github.com/FluxML/Zygote.jl/issues/1072
136+
# https://github.com/FluxML/Zygote.jl/issues/1072 # TODO fixed. Can we simplify something?
137+
136138

137139
if n > 1
138140
@assert all(x -> x isa AbstractArray, data) "Non-array features provided."
139141
end
140142

141-
if n <= 1
142-
# If last array dimension is not 1, add a new dimension.
143-
# This is mostly useful to reshape global feature vectors
144-
# of size D to Dx1 matrices.
143+
if n <= 1 && glob == true
144+
@assert n == 1
145+
n = -1 # relax the case of a single graph, allowing to store arbitrary types
146+
# # # If last array dimension is not 1, add a new dimension.
147+
# # # This is mostly useful to reshape global feature vectors
148+
# # # of size D to Dx1 matrices.
149+
# TODO remove this and handle better the batching of global features
145150
unsqz_last(v::AbstractArray) = size(v)[end] != 1 ? reshape(v, size(v)..., 1) : v
146151
unsqz_last(v) = v
147152

@@ -161,7 +166,7 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
161166

162167
for x in data
163168
if x isa AbstractArray
164-
@assert size(x)[end]==n "Wrong size in last dimension for feature array, expected $n but got $(size(x)[end])."
169+
@assert size(x)[end] == n "Wrong size in last dimension for feature array, expected $n but got $(size(x)[end])."
165170
end
166171
end
167172
end

GNNGraphs/test/gnngraph.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,17 @@ end
228228
@test g.num_nodes == 1
229229
@test g.num_edges == 0
230230
@test g.ndata.a == [1]
231+
232+
g = GNNGraph((Int[], Int[]); ndata=(;a=[1]), edata=(;b=Int[]), num_nodes=1)
233+
@test g.num_nodes == 1
234+
@test g.num_edges == 0
235+
@test g.ndata.a == [1]
236+
@test g.edata.b == Int[]
237+
238+
g = GNNGraph(; edata=(;b=Int[]))
239+
@test g.num_nodes == 0
240+
@test g.num_edges == 0
241+
@test g.edata.b == Int[]
231242
end
232243

233244

0 commit comments

Comments
 (0)