Skip to content

Commit

Permalink
add LSPE
Browse files Browse the repository at this point in the history
  • Loading branch information
yuehhua committed Jul 9, 2022
1 parent 770fd14 commit 3292a49
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/GeometricFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ include("models.jl")

include("sampling.jl")
include("embedding/node2vec.jl")
include("layers/positional.jl")

using .Datasets

Expand Down
123 changes: 123 additions & 0 deletions src/layers/positional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,126 @@ output_dim(l::EEquivGraphPE) = size(l.nn.weight, 1)

positional_encode(wg::WithGraph{<:EEquivGraphPE}, args...) = wg(args...)
positional_encode(l::EEquivGraphPE, args...) = l(args...)

"""
LSPE(f_h, f_e, f_p, pe_dim; init=glorot_uniform,
init_pe=random_walk_pe)
Learnable structural positional encoding layer.
# Arguments
- `f_h::MessagePassing`: Neural network layer for node update.
- `f_e`: Neural network layer for edge update.
- `f_p`: Neural network layer for positional encoding.
- `pe_dim::Int`: Dimension of positional encoding.
- `init`: Initializer for layer weights.
- `init_pe`: Initializer for positional encoding.
"""
struct LSPE{H<:MessagePassing,E,F,P} <: AbstractPE
f_h::H
f_e::E
f_p::F
pe::P
end

function LSPE(f_h::MessagePassing, f_e, f_p, pe_dim::Int; init=glorot_uniform, init_pe=random_walk_pe)
pe = init_pe(A, pe_dim)
return LSPE(f_h, f_e, f_p, pe)
end

# For variable graph
function (l::LSPE)(fg::AbstractFeaturedGraph)
X = node_feature(fg)
E = edge_feature(fg)
GraphSignals.check_num_nodes(fg, X)
GraphSignals.check_num_edges(fg, E)
E, V = propagate(l, graph(fg), E, X)
return ConcreteFeaturedGraph(fg, nf=V, ef=E)
end

# For static graph
function (l::LSPE)(el::NamedTuple, X::AbstractArray, E::AbstractArray)
GraphSignals.check_num_nodes(el.N, X)
GraphSignals.check_num_edges(el.E, E)
E, V = propagate(l, graph(fg), E, X)
return V, E
end

update_vertex(l::LSPE, el::NamedTuple, X, E::AbstractArray) = l.f_h(el, X, E)
update_vertex(l::LSPE, el::NamedTuple, X, E::Nothing) = l.f_h(el, X)

update_edge(l::LSPE, h_i, h_j, e_ij) = l.f_e(e_ij)

positional_encode(l::LSPE, p_i, p_j, e_ij) = l.f_p(p_i)

propagate(l::LSPE, sg::SparseGraph, E, V) = propagate(l, to_namedtuple(sg), E, V)

function propagate(l::LSPE, el::NamedTuple, E, V)
e_ij = _gather(E, el.es)
h_i = _gather(V, el.xs)
h_j = _gather(V, el.nbrs)
p_i = _gather(l.pe, el.xs)
p_j = _gather(l.pe, el.nbrs)

V = update_vertex(l, el, vcat(V, l.pe), E)
E = update_edge(l, h_i, h_j, e_ij)
l.pe = positional_encode(l, p_i, p_j, e_ij)
return E, V
end

function Base.show(io::IO, l::LSPE)
print(io, "LSPE(node_layer=", l.f_h)
print(io, ", edge_layer=", l.f_e)
print(io, ", positional_encode=", l.f_p, ")")
end


"""
random_walk_pe(A, k)
Returns positional encoding (PE) of size `(k, N)` where N is node number.
PE is generated by `k`-step random walk over given graph.
# Arguments
- `A`: Adjacency matrix of a graph.
- `k::Int`: First dimension of PE.
"""
function random_walk_pe(A::AbstractMatrix, k::Int)
N = size(A, 1)
@assert k N "k must less or equal to number of nodes"
inv_D = GraphSignals.degree_matrix(A, Float32, inverse=true)

RW = similar(A, size(A)..., k)
RW[:, :, 1] .= A * inv_D
for i in 2:k
RW[:, :, i] .= RW[:, :, i-1] * RW[:, :, 1]
end

pe = similar(RW, k, N)
for i in 1:N
pe[:, i] .= RW[i, i, :]
end

return pe
end

"""
laplacian_pe(A, k)
Returns positional encoding (PE) of size `(k, N)` where `N` is node number.
PE is generated from eigenvectors of a graph Laplacian truncated by `k`.
# Arguments
- `A`: Adjacency matrix of a graph.
- `k::Int`: First dimension of PE.
"""
function laplacian_pe(A::AbstractMatrix, k::Int)
N = size(A, 1)
@assert k N "k must less or equal to number of nodes"
L = GraphSignals.normalized_laplacian(A)
U = eigvecs(L)
return U[1:k, :]
end
4 changes: 4 additions & 0 deletions test/layers/positional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,8 @@
@test length(g.grads) == 2
end
end

@testset "LSPE" begin

end
end

0 comments on commit 3292a49

Please sign in to comment.