|
"""
    AbstractPE

Abstract supertype for positional encodings used with GNN layers.

Concrete subtypes are expected to provide a method of
[`positional_encode`](@ref).
"""
abstract type AbstractPE end

# Fallback: a subtype that forgets to implement `positional_encode` fails
# loudly at the call site instead of silently doing nothing.
function positional_encode(l::AbstractPE, args...)
    throw(ErrorException("positional_encode function for $l is not implemented."))
end
| 9 | + |
"""
    LSPE(mp)

Learnable structural positional encoding layer.

Wraps a message-passing layer and carries a learnable update network together
with the current positional-encoding matrix.

# Arguments

- `mp`: message-passing layer.
"""
struct LSPE{A<:MessagePassing,F,P} <: AbstractPE
    layer::A   # wrapped message-passing layer
    f_p::F     # learnable update network applied to positional encodings
    pe::P      # positional-encoding matrix; presumably (pe_dim, N) — TODO confirm
end
| 24 | + |
"""
    LSPE(mp, A, pe_dim; init=glorot_uniform, init_pe=random_walk_pe)

Build a learnable structural positional encoding layer around the
message-passing layer `mp`.

# Arguments

- `mp::MessagePassing`: message-passing layer to wrap.
- `A::AbstractMatrix`: adjacency matrix of the graph, used to initialize the
  positional encodings.
- `pe_dim::Int`: dimension of the positional encoding.
- `init`: initializer for the learnable update network `f_p`.
- `init_pe`: function `(A, pe_dim) -> pe` producing the initial encodings.
"""
function LSPE(mp::MessagePassing, A::AbstractMatrix, pe_dim::Int;
              init=glorot_uniform, init_pe=random_walk_pe)
    # NOTE(review): the previous version referenced the undefined variables
    # `A` and `f_p` (the `f_p = Dense(...)` line was commented out) and threw
    # `UndefVarError` on every call. The adjacency matrix is now an explicit
    # argument and `f_p` is a learnable pe_dim -> pe_dim update network;
    # confirm the intended `f_p` architecture against the LSPE paper.
    f_p = Dense(pe_dim => pe_dim; init=init)
    pe = init_pe(A, pe_dim)
    return LSPE(mp, f_p, pe)
end
| 30 | + |
"""
    positional_encode(l::LSPE, p_i, p_j, e_ij)

Positional message for `LSPE`: the encoding of the neighbor node, `p_j`, is
forwarded unchanged; `p_i` and `e_ij` are ignored.
"""
function positional_encode(::LSPE, p_i, p_j, e_ij)
    return p_j
end
| 32 | + |
"""
    message(l::LSPE, h_i, h_j, e_ij, p_i, p_j)

Concatenate node features with their positional encodings and delegate the
message computation to the wrapped message-passing layer. A `nothing` feature
stays `nothing` (no encoding is attached).
"""
function message(l::LSPE, h_i, h_j, e_ij, p_i, p_j)
    augment(h, p) = isnothing(h) ? nothing : vcat(h, p)
    return message(l.layer, augment(h_i, p_i), augment(h_j, p_j), e_ij)
end
| 38 | + |
# Compact display: only the wrapped layer is informative to the user.
function Base.show(io::IO, l::LSPE)
    print(io, "LSPE(")
    print(io, l.layer)
    print(io, ")")
end
| 42 | + |
| 43 | + |
"""
    random_walk_pe(A, k)

Return a positional encoding (PE) of size `(k, N)`, where `N` is the number
of nodes. The PE is generated by a `k`-step random walk over the given graph.

# Arguments

- `A`: Adjacency matrix of a graph.
- `k::Int`: First dimension of the PE; must satisfy `k ≤ N`.

# Throws

- `ArgumentError`: if `k` exceeds the number of nodes.
"""
function random_walk_pe(A::AbstractMatrix, k::Int)
    N = size(A, 1)
    # Validate at the API boundary with an explicit throw: `@assert` may be
    # disabled at higher optimization levels and must not guard user input.
    k ≤ N || throw(ArgumentError("k must be less than or equal to the number of nodes ($N), got $k"))
    inv_D = GraphSignals.degree_matrix(A, Float32, inverse=true)

    # RW[:, :, i] holds the i-step random-walk transition matrix (A * D⁻¹)^i.
    RW = similar(A, size(A)..., k)
    RW[:, :, 1] .= A * inv_D
    for i in 2:k
        RW[:, :, i] .= RW[:, :, i-1] * RW[:, :, 1]
    end

    # Node i's encoding is its return probability after 1..k steps, i.e. the
    # diagonal entry of each transition power.
    pe = similar(RW, k, N)
    for i in 1:N
        pe[:, i] .= RW[i, i, :]
    end

    return pe
end
| 73 | + |
"""
    laplacian_pe(A, k)

Return a positional encoding (PE) of size `(k, N)`, where `N` is the number
of nodes. The PE of each node is built from the first `k` eigenvectors of the
normalized graph Laplacian.

# Arguments

- `A`: Adjacency matrix of a graph.
- `k::Int`: First dimension of the PE; must satisfy `k ≤ N`.

# Throws

- `ArgumentError`: if `k` exceeds the number of nodes.
"""
function laplacian_pe(A::AbstractMatrix, k::Int)
    N = size(A, 1)
    # Explicit throw instead of `@assert`: assertions may be compiled out and
    # must not be relied on for input validation.
    k ≤ N || throw(ArgumentError("k must be less than or equal to the number of nodes ($N), got $k"))
    L = GraphSignals.normalized_laplacian(A)
    # `eigvecs` returns eigenvectors as *columns* (ascending eigenvalue order
    # for symmetric matrices), so node i's encoding is row i of the first k
    # columns. The previous `U[1:k, :]` sliced rows, mixing the first k
    # components of all N eigenvectors instead of truncating to k of them.
    U = eigvecs(L)
    return permutedims(U[:, 1:k])
end