Skip to content

Commit

Permalink
add LSPE
Browse files Browse the repository at this point in the history
  • Loading branch information
yuehhua committed Jun 21, 2022
1 parent 2d2f20d commit 4bc6507
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/GeometricFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ include("models.jl")

include("sampling.jl")
include("embedding/node2vec.jl")
include("layers/positional.jl")

using .Datasets

Expand Down
83 changes: 83 additions & 0 deletions src/layers/positional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,86 @@ output_dim(l::EEquivGraphPE) = size(l.nn.weight, 1)

positional_encode(wg::WithGraph{<:EEquivGraphPE}, args...) = wg(args...)
positional_encode(l::EEquivGraphPE, args...) = l(args...)

"""
LSPE(mp)
Learnable structural positional encoding layer.
# Arguments
- `mp`: message-passing layer.
"""
struct LSPE{A<:MessagePassing,F,P} <: AbstractPE
layer::A
f_p::F
pe::P
end

function LSPE(mp::MessagePassing, pe_dim::Int; init=glorot_uniform, init_pe=random_walk_pe)
# f_p = Dense(; init=init)
pe = init_pe(A, pe_dim)
return LSPE(mp, f_p, pe)
end

positional_encode(::LSPE, p_i, p_j, e_ij) = p_j

function message(l::LSPE, h_i, h_j, e_ij, p_i, p_j)
x_i = isnothing(h_i) ? nothing : vcat(h_i, p_i)
x_j = isnothing(h_j) ? nothing : vcat(h_j, p_j)
return message(l.layer, x_i, x_j, e_ij)
end

function Base.show(io::IO, l::LSPE)
print(io, "LSPE(", l.layer, ")")
end


"""
random_walk_pe(A, k)
Returns positional encoding (PE) of size `(k, N)` where N is node number.
PE is generated by `k`-step random walk over given graph.
# Arguments
- `A`: Adjacency matrix of a graph.
- `k::Int`: First dimension of PE.
"""
function random_walk_pe(A::AbstractMatrix, k::Int)
N = size(A, 1)
@assert k N "k must less or equal to number of nodes"
inv_D = GraphSignals.degree_matrix(A, Float32, inverse=true)

RW = similar(A, size(A)..., k)
RW[:, :, 1] .= A * inv_D
for i in 2:k
RW[:, :, i] .= RW[:, :, i-1] * RW[:, :, 1]
end

pe = similar(RW, k, N)
for i in 1:N
pe[:, i] .= RW[i, i, :]
end

return pe
end

"""
laplacian_pe(A, k)
Returns positional encoding (PE) of size `(k, N)` where `N` is node number.
PE is generated from eigenvectors of a graph Laplacian truncated by `k`.
# Arguments
- `A`: Adjacency matrix of a graph.
- `k::Int`: First dimension of PE.
"""
function laplacian_pe(A::AbstractMatrix, k::Int)
N = size(A, 1)
@assert k N "k must less or equal to number of nodes"
L = GraphSignals.normalized_laplacian(A)
U = eigvecs(L)
return U[1:k, :]
end

0 comments on commit 4bc6507

Please sign in to comment.