Skip to content

Reduce RAM usage of nonlocal term #1088

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,8 @@ end
positions = [ones(3)/8, -ones(3)/8]
magnetic_moments = [2, -2]

@compile_workload begin
precompilation_workflow(lattice, atoms, positions, magnetic_moments)
end
# @compile_workload begin
# precompilation_workflow(lattice, atoms, positions, magnetic_moments)
# end
end
end # module DFTK
141 changes: 125 additions & 16 deletions src/terms/nonlocal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,119 @@
proj_coeffs
end

"""
Projector set of a single atom (independent of the atom's position),
and the structure factor for the atom position. Used inside NonlocalProjectors
such that the projector set can be reused for multiple atoms in the same atom group.
"""
struct AtomProjectors{VT <: AbstractVector, PT <: AbstractMatrix}
# nbasis
structure_factors::VT
# nbasis x nproj
projectors::PT
end

"""
Matrix-like type to represent the nonlocal projection vectors P without
allocating the full matrix.
This type extends AbstractMatrix, but it does not implement all
the required methods, only those that were shown to be needed.
In particular, random access to the matrix elements is not supported.
"""
struct NonlocalProjectors{T <: Real,
                          ST <: AbstractVector{Complex{T}},
                          PT <: AtomProjectors,
                          } <: AbstractMatrix{Complex{T}}
    # Scratch buffer (length nbasis) used to form a single column
    # structure_factors .* projector on the fly during products.
    # TODO: this shared buffer is a real problem wrt. thread-safety, no?
    # DftHamiltonianBlock should handle it fine, but GenericHamiltonianBlock
    # seems to be parallelizing over bands which will cause problems!
    proj_scratch::ST
    # One AtomProjectors per atom; the projector matrix inside can be shared
    # between atoms of the same group, only the structure factors differ.
    atoms::Vector{PT}
end
"""
Build a `NonlocalProjectors` from per-atom projector sets, allocating one
shared basis-length scratch buffer of the appropriate promoted element type.
"""
function NonlocalProjectors(atom_projs::Vector{<:AtomProjectors})
    first_at = first(atom_projs)
    scratch_T = promote_type(eltype(first_at.structure_factors),
                             eltype(first_at.projectors))
    scratch = similar(first_at.structure_factors, scratch_T)
    return NonlocalProjectors(scratch, atom_projs)
end

# Logical size of the (never materialized) projection matrix:
# nbasis rows, and one column per projector summed over all atoms.
function Base.size(P::NonlocalProjectors)
    nbasis = length(P.proj_scratch)
    nproj_total = sum(at -> size(at.projectors, 2), P.atoms)
    return (nbasis, nproj_total)
end
# Densify the lazy projector matrix: column iproj is the element-wise product
# of the owning atom's structure factors with the iproj-th form-factor column.
# Only needed for debugging / dense reference computations.
function Base.Matrix(P::NonlocalProjectors{T}) where {T}
    n, m = size(P)
    out = zeros(Complex{T}, n, m)
    iproj = 1
    for at in P.atoms
        for proj in eachcol(at.projectors)
            out[:, iproj] .= at.structure_factors .* proj
            iproj += 1
        end
    end
    out
end

# Compact printing; deliberately avoids the AbstractMatrix fallback,
# which would materialize elements (random access is not supported).
function Base.show(io::IO, P::NonlocalProjectors)
    print(io, "DFTK.NonlocalProjectors{")
    show(io, P.atoms)
    print(io, "}")
end
function Base.show(io::IO, ::MIME"text/plain", P::NonlocalProjectors)
    print(io, summary(P))
end

# Add a level of indirection here to avoid ambiguity with the mul! method provided by Julia.
# 3-arg form computes projections C = P' ψk without materializing P.
LinearAlgebra.mul!(C::AbstractVector, A::Adjoint{<:Any, <:NonlocalProjectors},
                   ψk::AbstractVector) = _mul!(C, A, ψk)
LinearAlgebra.mul!(C::AbstractMatrix, A::Adjoint{<:Any, <:NonlocalProjectors},
                   ψk::AbstractMatrix) = _mul!(C, A, ψk)

# 5-arg form computes C = α P B + β C, used to accumulate the nonlocal term.
LinearAlgebra.mul!(C::AbstractVector, A::NonlocalProjectors, B::AbstractVector,
                   α::Number, β::Number) = _mul!(C, A, B, α, β)
LinearAlgebra.mul!(C::AbstractMatrix, A::NonlocalProjectors, B::AbstractMatrix,
                   α::Number, β::Number) = _mul!(C, A, B, α, β)

# Compute C = P' ψk one projector column at a time: each column is formed
# in the shared scratch buffer (structure factor × form factor), then its
# dot product with every band is written into row iproj of C.
# NOTE(review): reuses A.parent.proj_scratch — not safe to call concurrently
# on the same NonlocalProjectors from several threads.
function _mul!(C::AbstractVecOrMat, A::Adjoint{<:Any, <:NonlocalProjectors},
               ψk::AbstractVecOrMat)
    if size(C, 1) != size(A, 1) || size(A, 2) != size(ψk, 1) || size(ψk, 2) != size(C, 2)
        throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(ψk)), C has size $(size(C))"))
    end

    iproj = 1
    proj_scratch = A.parent.proj_scratch
    for at in A.parent.atoms
        for proj in eachcol(at.projectors)
            proj_scratch .= at.structure_factors .* proj
            # Row iproj of C ← (column iproj of P)' * ψk for all bands at once.
            @views mul!(C[iproj:iproj, :], proj_scratch', ψk)
            iproj += 1
        end
    end
    C
end

# Compute C = α P B + β C without materializing P: each projector column is
# formed in the shared scratch buffer and added into every band of C with
# its coefficient from B.
# NOTE(review): reuses A.proj_scratch — not thread-safe for concurrent calls
# on the same NonlocalProjectors.
function _mul!(C::AbstractArray, A::NonlocalProjectors, B::AbstractArray,
               α::Number, β::Number)
    if size(C, 1) != size(A, 1) || size(A, 2) != size(B, 1) || size(B, 2) != size(C, 2)
        throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))
    end

    # Follow the mul!/BLAS convention that β == 0 is a "strong zero":
    # C is overwritten, so NaN/Inf (or undef) entries in C do not propagate,
    # which C .*= 0 would not guarantee.
    if iszero(β)
        fill!(C, zero(eltype(C)))
    else
        C .*= β
    end

    iproj = 1
    proj_scratch = A.proj_scratch
    for at in A.atoms
        for proj in eachcol(at.projectors)
            # TODO: does this use BLAS?
            proj_scratch .= at.structure_factors .* proj
            for iband in axes(B, 2)
                @views C[:, iband] .+= proj_scratch .* (α * B[iproj, iband])
            end
            iproj += 1
        end
    end
    C
end

@doc raw"""
Build projection vectors for an atoms array generated by term_nonlocal
Expand Down Expand Up @@ -171,36 +284,31 @@
psps::AbstractVector{<: NormConservingPsp},
psp_positions) where {T}
unit_cell_volume = basis.model.unit_cell_volume
n_proj = count_n_proj(psps, psp_positions)
n_G = length(G_vectors(basis, kpt))
proj_vectors = zeros(Complex{eltype(psp_positions[1][1])}, n_G, n_proj)
G_plus_k = to_cpu(Gplusk_vectors(basis, kpt))

# Compute the columns of proj_vectors = 1/√Ω \hat proj_i(k+G)
# Since the proj_i are translates of each others, \hat proj_i(k+G) decouples as
# \hat proj_i(p) = ∫ proj(r-R) e^{-ip·r} dr = e^{-ip·R} \hat proj(p).
# The first term is the structure factor, the second the form factor.
offset = 0 # offset into proj_vectors
for (psp, positions) in zip(psps, psp_positions)
atom_projectors = reduce(vcat, map(zip(psps, psp_positions)) do (psp, positions)
# Compute position-independent form factors
G_plus_k_cart = to_cpu(Gplusk_vectors_cart(basis, kpt))
form_factors = build_projector_form_factors(psp, G_plus_k_cart)
psp_form_factors = build_projector_form_factors(psp, G_plus_k_cart)
psp_form_factors ./= sqrt(unit_cell_volume)
# Offload potential values to a device (like a GPU),
# and make sure to share this allocation for all atoms in the group
psp_form_factors = to_device(basis.architecture, psp_form_factors)

# Combine with structure factors
for r in positions
map(positions) do r
# k+G in this formula can also be G, this only changes an unimportant phase factor
structure_factors = map(p -> cis2pi(-dot(p, r)), G_plus_k)
@views for iproj = 1:count_n_proj(psp)
proj_vectors[:, offset+iproj] .=
structure_factors .* form_factors[:, iproj] ./ sqrt(unit_cell_volume)
end
offset += count_n_proj(psp)
structure_factors = to_device(basis.architecture, map(p -> cis2pi(-dot(p, r)), G_plus_k))
AtomProjectors(structure_factors, psp_form_factors)
end
end
@assert offset == n_proj
end)

# Offload potential values to a device (like a GPU)
to_device(basis.architecture, proj_vectors)
NonlocalProjectors(atom_projectors)
end

"""
Expand Down Expand Up @@ -282,6 +390,7 @@
D = build_projection_coefficients(basis, psp_groups)
P = build_projection_vectors(basis, kpt, psp_groups, positions)
P_minus_q = build_projection_vectors(basis, kpt_minus_q, psp_groups, positions)
# TODO: probably needs an extra parenthesis to first compute P'ψ
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So... I noticed that Julia has custom * overloads for matrix multiplication with more than 2 operands to select the chain of operations that will minimize the total cost. Presumably it will almost always compute P_minus_q' * ψk first. But if it doesn't we are in trouble, so this should probably be changed to (P_minus_q' * ψk).

P * (D * P_minus_q' * ψk)
end

Expand Down
2 changes: 1 addition & 1 deletion src/terms/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ end
# Apply the nonlocal operator in Fourier space: Hψ.fourier += P * (D * (P' ψ.fourier)).
# Uses 5-arg mul! with α = β = 1 to accumulate into the existing Hψ buffer;
# the inner parentheses ensure P' ψ is formed first (cheap nproj-sized vectors).
function apply!(Hψ, op::NonlocalOperator, ψ)
    mul!(Hψ.fourier, op.P, (op.D * (op.P' * ψ.fourier)), 1, 1)
end
# Densify the nonlocal operator P D P†. op.P may be a lazy matrix type that
# only supports products (no random access), so its adjoint is materialized
# via Matrix(op.P)' before the triple product.
Matrix(op::NonlocalOperator) = op.P * op.D * Matrix(op.P)'

"""
Magnetic field operator A⋅(-i∇).
Expand Down
39 changes: 39 additions & 0 deletions test/PspUpf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,42 @@ end
end
end
end

@testitem "Test nonlocal term operations" tags=[:psp] setup=[mPspUpf] begin
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about this test. The goal was just for me to be able to try my changes with <5s waiting time.

using DFTK
using LinearAlgebra

lattice = 5 * I(3)
positions = [zeros(3), 1/3 .* ones(3), 2/3 .* ones(3)]
for (element, psp) in mPspUpf.upf_pseudos
if sum(psp.r2_ρion) > 0 # Otherwise, it's all 0 in the UPF as a placeholder
el = ElementPsp(element, psp)
atoms = [el, el, el]
model = model_DFT(lattice, atoms, positions; functionals=LDA())
basis = PlaneWaveBasis(model; Ecut=5, kgrid=[2, 2, 2])
n_bands = 7
ψ = [DFTK.random_orbitals(basis, kpt, n_bands) for kpt in basis.kpoints]
occ = [2.0 * ones(n_bands) for _ in basis.kpoints]
ρ = DFTK.compute_density(basis, ψ, occ)

energies, ham = DFTK.energy_hamiltonian(basis, ψ, occ; ρ)
hamψ = ham * ψ

hblock = ham.blocks[1]
nonloc = hblock.nonlocal_op
nonlocal_dense = Matrix(nonloc)
ψk = ψ[1]

Hψk = zero(ψk)
DFTK.apply!((;fourier=Hψk), nonloc, (; fourier=ψk))
Hψk_dense = nonlocal_dense * ψk

Pψk = nonloc.P' * ψk
DPψk = nonloc.D * Pψk
@show norm(nonloc.P * DPψk) norm(Matrix(nonloc.P) * DPψk)
@assert @show(norm(nonloc.P * DPψk - Matrix(nonloc.P) * DPψk)) < 1e-10

@assert @show(norm(Hψk - Hψk_dense)) < 1e-10
end
end
end
Loading