From 404070217e7fdbe91de92060bf1137d4e3b18b5c Mon Sep 17 00:00:00 2001 From: Augustin Bussy Date: Fri, 22 Nov 2024 18:01:23 +0100 Subject: [PATCH] Speed up AtomicLocal forces (#1024) Co-authored-by: Michael F. Herbst --- src/terms/local.jl | 41 +++++++++++++++++++++----------------- test/external/atomsbase.jl | 16 +++++++-------- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/src/terms/local.jl b/src/terms/local.jl index 67ac994c35..7a79b7d440 100644 --- a/src/terms/local.jl +++ b/src/terms/local.jl @@ -70,7 +70,21 @@ function (external::ExternalFromFourier)(basis::PlaneWaveBasis{T}) where {T} TermExternal(irfft(basis, pot_fourier)) end - +# Returns the form factors at unique values of |G + q| (in Cartesian coordinates). +# Uses a hash map for O(1) lookup. +function atomic_local_form_factors(basis::PlaneWaveBasis{T}, Gqs_cart::AbstractArray) where{T} + form_factors = IdDict{Tuple{Int,T},T}() # IdDict for Dual compatibility + for G in Gqs_cart + p = norm(G) + for (igroup, group) in enumerate(basis.model.atom_groups) + if !haskey(form_factors, (igroup, p)) + element = basis.model.atoms[first(group)] + form_factors[(igroup, p)] = local_potential_fourier(element, p) + end + end + end + form_factors +end ## Atomic local potential @@ -93,19 +107,7 @@ function compute_local_potential(basis::PlaneWaveBasis{T}; positions=basis.model # TODO Bring Gqs_cart on the CPU for compatibility with the pseudopotentials which # are not isbits ... might be able to solve this by restructuring the loop - # Pre-compute the form factors at unique values of |G| to speed up - # the potential Fourier transform (by a lot). Using a hash map gives O(1) - # lookup. - form_factors = IdDict{Tuple{Int,T},T}() # IdDict for Dual compatibility - for G in Gqs_cart - p = norm(G) - for (igroup, group) in enumerate(model.atom_groups) - if !haskey(form_factors, (igroup, p)) - element = model.atoms[first(group)] - form_factors[(igroup, p)] = local_potential_fourier(element, p) - end - end - end + form_factors = atomic_local_form_factors(basis, Gqs_cart) Gqs = [G + q for G in to_cpu(G_vectors(basis))] # TODO Again for GPU compatibility pot_fourier = map(enumerate(Gqs)) do (iG, G) @@ -138,17 +140,20 @@ end ρ_fourier = fft(basis, total_density(ρ)) real_ifSreal = S <: Real ? real : identity + # TODO: Right now, forces are not GPU compatible. Refer to compute_local_potential + # comments when working on this + Gqs_cart = [model.recip_lattice * (G + q) for G in G_vectors(basis)] + form_factors = atomic_local_form_factors(basis, Gqs_cart) + # energy = sum of form_factor(G) * struct_factor(G) * rho(G) # where struct_factor(G) = e^{-i G·r} forces = [zero(Vec3{S}) for _ = 1:length(model.positions)] - for group in model.atom_groups + for (igroup, group) in enumerate(model.atom_groups) element = model.atoms[first(group)] - form_factors = [complex(S)(local_potential_fourier(element, norm(recip_lattice * (G + q)))) - for G in G_vectors(basis)] for idx in group r = model.positions[idx] forces[idx] = -real_ifSreal(sum(conj(ρ_fourier[iG]) - * form_factors[iG] + * form_factors[(igroup, norm(Gqs_cart[iG]))] * cis2pi(-dot(G + q, r)) * (-2T(π)) * (G + q) * im / sqrt(model.unit_cell_volume) diff --git a/test/external/atomsbase.jl b/test/external/atomsbase.jl index eb1cb72ebd..570e37add2 100644 --- a/test/external/atomsbase.jl +++ b/test/external/atomsbase.jl @@ -25,8 +25,8 @@ @test system[:, :magnetic_moment] == magnetic_moments parsed = DFTK.parse_system(system) - @test parsed.lattice ≈ lattice atol=5e-13 - @test parsed.positions ≈ positions atol=5e-13 + @test parsed.lattice ≈ lattice atol=1e-12 + @test parsed.positions ≈ positions atol=1e-12 for i = 1:4 @test iszero(parsed.magnetic_moments[i][1:2]) @test parsed.magnetic_moments[i][3] == magnetic_moments[i] @@ -118,8 +118,8 @@ end system = periodic_system(atoms, lattice; fractional=true) let model = Model(system) - @test model.lattice ≈ pos_lattice atol=5e-13 - @test model.positions ≈ pos_units atol=5e-13 + @test model.lattice ≈ pos_lattice atol=1e-12 + @test model.positions ≈ pos_units atol=1e-12 @test model.spin_polarization == :none @test length(model.atoms) == 4 @@ -139,8 +139,8 @@ end @test system[4, :pseudopotential] == "hgh/pbe/c-q4.hgh" parsed = DFTK.parse_system(system) - @test parsed.lattice ≈ pos_lattice atol=5e-13 - @test parsed.positions ≈ pos_units atol=5e-13 + @test parsed.lattice ≈ pos_lattice atol=1e-12 + @test parsed.positions ≈ pos_units atol=1e-12 @test isempty(parsed.magnetic_moments) @test length(parsed.atoms) == 4 @@ -159,8 +159,8 @@ end @test system[4, :pseudopotential] == "hgh/lda/c-q4.hgh" model = Model(system) - @test model.lattice ≈ pos_lattice atol=5e-13 - @test model.positions ≈ pos_units atol=5e-13 + @test model.lattice ≈ pos_lattice atol=1e-12 + @test model.positions ≈ pos_units atol=1e-12 @test model.spin_polarization == :none @test length(model.atoms) == 4