Skip to content

Commit

Permalink
Implement potentialoperator for composite bases
Browse files Browse the repository at this point in the history
  • Loading branch information
david-pl committed Apr 10, 2018
1 parent 395f707 commit 3df13c6
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 1 deletion.
52 changes: 51 additions & 1 deletion src/particle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,64 @@ function potentialoperator(b::MomentumBasis, V::Function)
transform(b, b_pos)*full(potentialoperator(b_pos, V))*transform(b_pos, b)
end

"""
potentialoperator(b::CompositeBasis, V(x, y, z, ...))
Operator representing a potential ``V`` in more than one dimension.
# Arguments
* `b`: Composite basis consisting purely either of `PositionBasis` or
`MomentumBasis`. Note, that calling this with a composite basis in
momentum space might consume a large amount of memory.
* `V`: Function describing the potential. ATTENTION: The number of arguments
accepted by `V` must match the spatial dimension. Furthermore, the order
of the arguments has to match that of the order of the tensor product of
bases (e.g. if `b=bx⊗by⊗bz`, then `V(x,y,z)`).
"""
function potentialoperator(b::CompositeBasis, V::Function)
if isa(b.bases[1], PositionBasis)
potentialoperator_position(b, V)
elseif isa(b.bases[1], MomentumBasis)
potentialoperator_momentum(b, V)
else
throw(IncompatibleBases())
end
end
function potentialoperator_position(b::CompositeBasis, V::Function)
for base=b.bases
@assert isa(base, PositionBasis)
end

points = [samplepoints(b1) for b1=b.bases]
dims = length.(points)
n = length(b.bases)
data = Array{Complex128}(dims...)
@inbounds for i=1:length(data)
index = ind2sub(data, i)
args = (points[j][index[j]] for j=1:n)
data[i] = V(args...)
end

diagonaloperator(b, data[:])
end
function potentialoperator_momentum(b::CompositeBasis, V::Function)
bases_pos = []
for base=b.bases
@assert isa(base, MomentumBasis)
push!(bases_pos, PositionBasis(base))
end
b_pos = tensor(bases_pos...)
transform(b, b_pos)*full(potentialoperator_position(b_pos, V))*transform(b_pos, b)
end

"""
FFTOperator
Abstract type for all implementations of FFT operators.
"""
abstract type FFTOperator <: Operator end

PlanFFT = Base.DFT.FFTW.cFFTWPlan
const PlanFFT = Base.DFT.FFTW.cFFTWPlan

"""
FFTOperators
Expand Down
42 changes: 42 additions & 0 deletions test/test_particle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,46 @@ difference = (full(Txp) - permutesystems(full(Txp2), [2, 1, 3])).data
difference = (full(dagger(Txp)) - permutesystems(full(Tpx2), [2, 1, 3])).data
@test isapprox(difference, zeros(difference); atol=1e-13)

# Test potentialoperator in more than 1D
N = [21, 18]
xmin = [-32.5, -10π]
xmax = [24.1, 9π]

basis_position = [PositionBasis(xmin[i], xmax[i], N[i]) for i=1:2]
basis_momentum = MomentumBasis.(basis_position)

bcomp_pos = tensor(basis_position...)
bcomp_mom = tensor(basis_momentum...)
V(x, y) = sin(x*y) + cos(x)
xsample, ysample = samplepoints.(basis_position)
V_op = diagonaloperator(bcomp_pos, [V(x, y) for y in ysample for x in xsample])
V_op2 = potentialoperator(bcomp_pos, V)
@test V_op == V_op2

basis_position = PositionBasis.(basis_momentum)
bcomp_pos = tensor(basis_position...)
Txp = transform(bcomp_pos, bcomp_mom)
Tpx = transform(bcomp_mom, bcomp_pos)
xsample, ysample = samplepoints.(basis_position)
V_op = Tpx*full(diagonaloperator(bcomp_pos, [V(x, y) for y in ysample for x in xsample]))*Txp
V_op2 = potentialoperator(bcomp_mom, V)
@test V_op == V_op2

N = [17, 12, 9]
xmin = [-32.5, -10π, -0.1]
xmax = [24.1, 9π, 22.0]

basis_position = [PositionBasis(xmin[i], xmax[i], N[i]) for i=1:3]
basis_momentum = MomentumBasis.(basis_position)

bcomp_pos = tensor(basis_position...)
bcomp_mom = tensor(basis_momentum...)
V(x, y, z) = exp(-z^2) + sin(x*y) + cos(x)
xsample, ysample, zsample = samplepoints.(basis_position)
V_op = diagonaloperator(bcomp_pos, [V(x, y, z) for z in zsample for y in ysample for x in xsample])
V_op2 = potentialoperator(bcomp_pos, V)
@test V_op == V_op2

# Test error messages
b1 = PositionBasis(-1, 1, 50)
b2 = MomentumBasis(-1, 1, 30)
Expand All @@ -364,4 +404,6 @@ bc2 = b1 ⊗ b2
@test_throws bases.IncompatibleBases transform(bc1, bc2)
@test_throws bases.IncompatibleBases transform(bc2, bc1)

@test_throws bases.IncompatibleBases potentialoperator(bc bc, V)

end # testset

0 comments on commit 3df13c6

Please sign in to comment.