diff --git a/src/external/spglib.jl b/src/external/spglib.jl index f13054d5a3..c1b25a9d2c 100644 --- a/src/external/spglib.jl +++ b/src/external/spglib.jl @@ -61,7 +61,7 @@ function spglib_cell(model::Model, magnetic_moments) end -@timing function spglib_get_symmetry(lattice::AbstractMatrix{<:AbstractFloat}, atom_groups, +@timing function spglib_get_symmetry(lattice, atom_groups, positions, magnetic_moments=[]; tol_symmetry=SYMMETRY_TOLERANCE) lattice = Matrix{Float64}(lattice) # spglib operates in double precision diff --git a/src/workarounds/forwarddiff_rules.jl b/src/workarounds/forwarddiff_rules.jl index 06c6f9affb..0a542c5a12 100644 --- a/src/workarounds/forwarddiff_rules.jl +++ b/src/workarounds/forwarddiff_rules.jl @@ -113,17 +113,42 @@ function build_fft_plans!(tmp::AbstractArray{Complex{T}}) where {T<:ForwardDiff. ipFFT, opFFT, ipBFFT, opBFFT end -# determine symmetry operations only from primal lattice values +# versions of abs and is_approx_integer that operate on both the value and the partials +abs_value_partials(x) = abs(x) +abs_value_partials(x::ForwardDiff.Dual) = abs(x) + sum(abs, ForwardDiff.partials(x)) +function is_approx_integer_value_partials(x; kwargs...) + return is_approx_integer(x; kwargs) && + all(y -> is_approx_integer(y; kwargs...), ForwardDiff.partials(x)) +end + function spglib_get_symmetry(lattice::AbstractMatrix{<:ForwardDiff.Dual}, atom_groups, positions, - magnetic_moments=[]; kwargs...) - spglib_get_symmetry(ForwardDiff.value.(lattice), atom_groups, positions, - magnetic_moments; kwargs...) + magnetic_moments=[]; tol_symmetry=SYMMETRY_TOLERANCE) + syms = spglib_get_symmetry(ForwardDiff.value.(lattice), atom_groups, positions, + magnetic_moments; tol_symmetry) + # Only keep those symmetries that respect the deformation of the lattice + filter(syms) do ((W, w)) + Wcart = lattice * W / lattice + maximum(abs_value_partials, Wcart'Wcart - I) > tol_symmetry + end end -function spglib_atoms(atom_groups, - positions::AbstractVector{<:AbstractVector{<:ForwardDiff.Dual}}, - magnetic_moments) +function spglib_get_symmetry(lattice, atom_groups, + positions::AbstractVector{<:AbstractVector{<:ForwardDiff.Dual}}, + magnetic_moments=[]; tol_symmetry=SYMMETRY_TOLERANCE) positions_value = [ForwardDiff.value.(pos) for pos in positions] - spglib_atoms(atom_groups, positions_value, magnetic_moments) + spglib_get_symmetry(lattice, atom_groups, positions_value, + magnetic_moments; tol_symmetry) + # Only keep those symmetries that respect the displacements of the atoms + filter(syms) do ((W, w)) + for group in atom_groups + group_positions = positions[group] + for coord in group_positions + if !any(c -> is_approx_integer_value_partials(W * coord + w - c; tol_symmetry), group_positions) + return false + end + end + end + return true + end end function _is_well_conditioned(A::AbstractArray{<:ForwardDiff.Dual}; kwargs...)