Skip to content

Commit

Permalink
Speed up AtomicLocal forces (#1024)
Browse files Browse the repository at this point in the history
Co-authored-by: Michael F. Herbst <info@michael-herbst.com>
  • Loading branch information
abussy and mfherbst authored Nov 22, 2024
1 parent 07f6de2 commit 4040702
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 26 deletions.
41 changes: 23 additions & 18 deletions src/terms/local.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,21 @@ function (external::ExternalFromFourier)(basis::PlaneWaveBasis{T}) where {T}
TermExternal(irfft(basis, pot_fourier))
end


# Returns the form factors at unique values of |G + q| (in Cartesian coordinates).
# Uses a hash map for O(1) lookup.
function atomic_local_form_factors(basis::PlaneWaveBasis{T}, Gqs_cart::AbstractArray) where{T}
form_factors = IdDict{Tuple{Int,T},T}() # IdDict for Dual compatibility
for G in Gqs_cart
p = norm(G)
for (igroup, group) in enumerate(basis.model.atom_groups)
if !haskey(form_factors, (igroup, p))
element = basis.model.atoms[first(group)]
form_factors[(igroup, p)] = local_potential_fourier(element, p)
end
end
end
form_factors
end

## Atomic local potential

Expand All @@ -93,19 +107,7 @@ function compute_local_potential(basis::PlaneWaveBasis{T}; positions=basis.model
# TODO Bring Gqs_cart on the CPU for compatibility with the pseudopotentials which
# are not isbits ... might be able to solve this by restructuring the loop

# Pre-compute the form factors at unique values of |G| to speed up
# the potential Fourier transform (by a lot). Using a hash map gives O(1)
# lookup.
form_factors = IdDict{Tuple{Int,T},T}() # IdDict for Dual compatibility
for G in Gqs_cart
p = norm(G)
for (igroup, group) in enumerate(model.atom_groups)
if !haskey(form_factors, (igroup, p))
element = model.atoms[first(group)]
form_factors[(igroup, p)] = local_potential_fourier(element, p)
end
end
end
form_factors = atomic_local_form_factors(basis, Gqs_cart)

Gqs = [G + q for G in to_cpu(G_vectors(basis))] # TODO Again for GPU compatibility
pot_fourier = map(enumerate(Gqs)) do (iG, G)
Expand Down Expand Up @@ -138,17 +140,20 @@ end
ρ_fourier = fft(basis, total_density(ρ))
real_ifSreal = S <: Real ? real : identity

# TODO: Right now, forces are not GPU compatible. Refer to compute_local_potential
# comments when working on this
Gqs_cart = [model.recip_lattice * (G + q) for G in G_vectors(basis)]
form_factors = atomic_local_form_factors(basis, Gqs_cart)

# energy = sum of form_factor(G) * struct_factor(G) * rho(G)
# where struct_factor(G) = e^{-i G·r}
forces = [zero(Vec3{S}) for _ = 1:length(model.positions)]
for group in model.atom_groups
for (igroup, group) in enumerate(model.atom_groups)
element = model.atoms[first(group)]
form_factors = [complex(S)(local_potential_fourier(element, norm(recip_lattice * (G + q))))
for G in G_vectors(basis)]
for idx in group
r = model.positions[idx]
forces[idx] = -real_ifSreal(sum(conj(ρ_fourier[iG])
* form_factors[iG]
* form_factors[(igroup, norm(Gqs_cart[iG]))]
* cis2pi(-dot(G + q, r))
* (-2T(π)) * (G + q) * im
/ sqrt(model.unit_cell_volume)
Expand Down
16 changes: 8 additions & 8 deletions test/external/atomsbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
@test system[:, :magnetic_moment] == magnetic_moments

parsed = DFTK.parse_system(system)
@test parsed.lattice lattice atol=5e-13
@test parsed.positions positions atol=5e-13
@test parsed.lattice lattice atol=1e-12
@test parsed.positions positions atol=1e-12
for i = 1:4
@test iszero(parsed.magnetic_moments[i][1:2])
@test parsed.magnetic_moments[i][3] == magnetic_moments[i]
Expand Down Expand Up @@ -118,8 +118,8 @@ end
system = periodic_system(atoms, lattice; fractional=true)

let model = Model(system)
@test model.lattice pos_lattice atol=5e-13
@test model.positions pos_units atol=5e-13
@test model.lattice pos_lattice atol=1e-12
@test model.positions pos_units atol=1e-12
@test model.spin_polarization == :none

@test length(model.atoms) == 4
Expand All @@ -139,8 +139,8 @@ end
@test system[4, :pseudopotential] == "hgh/pbe/c-q4.hgh"

parsed = DFTK.parse_system(system)
@test parsed.lattice pos_lattice atol=5e-13
@test parsed.positions pos_units atol=5e-13
@test parsed.lattice pos_lattice atol=1e-12
@test parsed.positions pos_units atol=1e-12
@test isempty(parsed.magnetic_moments)

@test length(parsed.atoms) == 4
Expand All @@ -159,8 +159,8 @@ end
@test system[4, :pseudopotential] == "hgh/lda/c-q4.hgh"

model = Model(system)
@test model.lattice pos_lattice atol=5e-13
@test model.positions pos_units atol=5e-13
@test model.lattice pos_lattice atol=1e-12
@test model.positions pos_units atol=1e-12
@test model.spin_polarization == :none

@test length(model.atoms) == 4
Expand Down

0 comments on commit 4040702

Please sign in to comment.