Skip to content

Commit

Permalink
add LSPE
Browse files Browse the repository at this point in the history
  • Loading branch information
yuehhua committed Jul 10, 2022
1 parent dc67832 commit f813b78
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 6 deletions.
9 changes: 9 additions & 0 deletions docs/bibliography.bib
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,12 @@ @inproceedings{Satorras2021
url = {http://arxiv.org/abs/2102.09844},
year = {2021},
}

@article{Dwivedi2021,
abstract = {Graph neural networks (GNNs) have become the standard learning architectures for graphs. GNNs have been applied to numerous domains ranging from quantum chemistry, recommender systems to knowledge graphs and natural language processing. A major issue with arbitrary graphs is the absence of canonical positional information of nodes, which decreases the representation power of GNNs to distinguish e.g. isomorphic nodes and other graph symmetries. An approach to tackle this issue is to introduce Positional Encoding (PE) of nodes, and inject it into the input layer, like in Transformers. Possible graph PE are Laplacian eigenvectors. In this work, we propose to decouple structural and positional representations to make easy for the network to learn these two essential properties. We introduce a novel generic architecture which we call LSPE (Learnable Structural and Positional Encodings). We investigate several sparse and fully-connected (Transformer-like) GNNs, and observe a performance increase for molecular datasets, from 2.87% up to 64.14% when considering learnable PE for both GNN classes.},
author = {Vijay Prakash Dwivedi and Anh Tuan Luu and Thomas Laurent and Yoshua Bengio and Xavier Bresson},
month = {10},
title = {Graph Neural Networks with Learnable Structural and Positional Representations},
url = {http://arxiv.org/abs/2110.07875},
year = {2021},
}
28 changes: 26 additions & 2 deletions docs/src/manual/positional.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
# Positional Encoding Layers
# Positional Encoding

## ``E(n)``-equivariant Positional Encoding Layer
## Positional Encoding Methods

```@docs
AbstractPositionalEncoding
RandomWalkPE
LaplacianPE
positional_encode
```

## Positional Encoding Layers

### ``E(n)``-equivariant Positional Encoding Layer

It employs message-passing scheme and can be defined by following functions:

Expand All @@ -17,3 +28,16 @@ EEquivGraphPE
Reference: [Satorras2021](@cite)

---

### Learnable Structural Positional Encoding layer

(WIP)

```@docs
LSPE
```

Reference: [Dwivedi2021](@cite)

---

5 changes: 4 additions & 1 deletion src/GeometricFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ export
MessagePassing,

# layers/positional
AbstractPE,
AbstractPositionalEncoding,
RandomWalkPE,
LaplacianPE,
positional_encode,
EEquivGraphPE,
LSPE,

# layers/graph_conv
GCNConv,
Expand Down
150 changes: 147 additions & 3 deletions src/layers/positional.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,81 @@
"""
AbstractPE
AbstractPositionalEncoding
Abstract type of positional encoding for GNN.
"""
abstract type AbstractPE end
abstract type AbstractPositionalEncoding end

"""
RandomWalkPE{K}
Concrete type of positional encoding from random walk method.
See also [`positional_encode`](@ref) for generating positional encoding.
"""
struct RandomWalkPE{K} <: AbstractPositionalEncoding end

"""
LaplacianPE{K}
Concrete type of positional encoding from graph Laplacian method.
See also [`positional_encode`](@ref) for generating positional encoding.
"""
struct LaplacianPE{K} <: AbstractPositionalEncoding end

"""
positional_encode(RandomWalkPE{K}, A)
Returns positional encoding (PE) of size `(K, N)` where N is node number.
PE is generated by `K`-step random walk over given graph.
# Arguments
- `K::Int`: First dimension of PE.
- `A`: Adjacency matrix of a graph.
See also [`RandomWalkPE`](@ref) for random walk method.
"""
function positional_encode(::Type{RandomWalkPE{K}}, A::AbstractMatrix) where {K}
N = size(A, 1)
@assert K N "K=$K must less or equal to number of nodes ($N)"
inv_D = GraphSignals.degree_matrix(A, Float32, inverse=true)

RW = similar(A, size(A)..., K)
RW[:, :, 1] .= A * inv_D
for i in 2:K
RW[:, :, i] .= RW[:, :, i-1] * RW[:, :, 1]
end

pe = similar(RW, K, N)
for i in 1:N
pe[:, i] .= RW[i, i, :]
end

return pe
end

"""
positional_encode(LaplacianPE{K}, A)
Returns positional encoding (PE) of size `(K, N)` where `N` is node number.
PE is generated from eigenvectors of a graph Laplacian truncated by `K`.
# Arguments
- `K::Int`: First dimension of PE.
- `A`: Adjacency matrix of a graph.
See also [`LaplacianPE`](@ref) for graph Laplacian method.
"""
function positional_encode(::Type{LaplacianPE{K}}, A::AbstractMatrix) where {K}
N = size(A, 1)
@assert K N "K=$K must less or equal to number of nodes ($N)"
L = GraphSignals.normalized_laplacian(A)
U = eigvecs(L)
return U[1:K, :]
end

positional_encode(l::AbstractPE, args...) = throw(ErrorException("positional_encode function for $l is not implemented."))

"""
EEquivGraphPE(in_dim=>out_dim; init=glorot_uniform, bias=true)
Expand Down Expand Up @@ -78,3 +148,77 @@ output_dim(l::EEquivGraphPE) = size(l.nn.weight, 1)

positional_encode(wg::WithGraph{<:EEquivGraphPE}, args...) = wg(args...)
positional_encode(l::EEquivGraphPE, args...) = l(args...)

"""
LSPE(fg, f_h, f_e, f_p, k; pe_method=RandomWalkPE)
Learnable structural positional encoding layer. `LSPE` layer can be seen as a GNN layer
warpped in `WithGraph`.
# Arguments
- `fg::FeaturedGraph`: A given graoh for positional encoding.
- `f_h::MessagePassing`: Neural network layer for node update.
- `f_e`: Neural network layer for edge update.
- `f_p`: Neural network layer for positional encoding.
- `k::Int`: Dimension of positional encoding.
- `pe_method`: Initializer for positional encoding.
"""
struct LSPE{H<:MessagePassing,E,F,P} <: AbstractPositionalEncoding
f_h::H
f_e::E
f_p::F
pe::P
end

function LSPE(fg::AbstractFeaturedGraph, f_h::MessagePassing, f_e, f_p, k::Int;
pe_method=RandomWalkPE)
A = GraphSignals.adjacency_matrix(fg)
return LSPE(f_h, f_e, f_p, positional_encode(pe_method{k}, A))
end

# For variable graph
function (l::LSPE)(fg::AbstractFeaturedGraph)
X = node_feature(fg)
E = edge_feature(fg)
GraphSignals.check_num_nodes(fg, X)
GraphSignals.check_num_edges(fg, E)
E, V = propagate(l, graph(fg), E, X)
return ConcreteFeaturedGraph(fg, nf=V, ef=E)
end

# For static graph
function (l::LSPE)(el::NamedTuple, X::AbstractArray, E::AbstractArray)
GraphSignals.check_num_nodes(el.N, X)
GraphSignals.check_num_edges(el.E, E)
E, V = propagate(l, graph(fg), E, X)
return V, E
end

update_vertex(l::LSPE, el::NamedTuple, X, E::AbstractArray) = l.f_h(el, X, E)
update_vertex(l::LSPE, el::NamedTuple, X, E::Nothing) = l.f_h(el, X)

update_edge(l::LSPE, h_i, h_j, e_ij) = l.f_e(e_ij)

positional_encode(l::LSPE, p_i, p_j, e_ij) = l.f_p(p_i)

propagate(l::LSPE, sg::SparseGraph, E, V) = propagate(l, to_namedtuple(sg), E, V)

function propagate(l::LSPE, el::NamedTuple, E, V)
e_ij = _gather(E, el.es)
h_i = _gather(V, el.xs)
h_j = _gather(V, el.nbrs)
p_i = _gather(l.pe, el.xs)
p_j = _gather(l.pe, el.nbrs)

V = update_vertex(l, el, vcat(V, l.pe), E)
E = update_edge(l, h_i, h_j, e_ij)
l.pe = positional_encode(l, p_i, p_j, e_ij)
return E, V
end

function Base.show(io::IO, l::LSPE)
print(io, "LSPE(node_layer=", l.f_h)
print(io, ", edge_layer=", l.f_e)
print(io, ", positional_encode=", l.f_p, ")")
end
13 changes: 13 additions & 0 deletions test/layers/positional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,17 @@
@test length(g.grads) == 2
end
end

@testset "LSPE" begin
K = 3
f_h = GraphConv(in_channel=>out_channel)
f_e = Dense(in_channel, out_channel)
f_p = Dense(in_channel, out_channel)
l = LSPE(fg, f_h, f_e, f_p, K)

# nf = rand(T, out_channel, N)
# fg = FeaturedGraph(adj, nf=nf)
# fg_ = l(fg)
# @test size(node_feature(fg_)) == (out_channel, N)
end
end

0 comments on commit f813b78

Please sign in to comment.