Skip to content

Commit

Permalink
Fix bug in MCWF that prevented saving before/after jumps (#208)
Browse files Browse the repository at this point in the history
* Fix bug in MCWF that prevented saving before/after jumps

* Add test for displays
  • Loading branch information
david-pl authored Mar 20, 2018
1 parent b7fbf03 commit efe570c
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 34 deletions.
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ Compat 0.52.0
OrdinaryDiffEq 3.1.0
DiffEqCallbacks 1.0
StochasticDiffEq 3.0.0
RecursiveArrayTools
140 changes: 108 additions & 32 deletions src/mcwf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ using ...operators_dense, ...operators_sparse
using ..timeevolution
using ...operators_lazysum, ...operators_lazytensor, ...operators_lazyproduct
import OrdinaryDiffEq
# TODO: Remove imports
import DiffEqCallbacks, RecursiveArrayTools.copyat_or_push!
import ..recast!
Base.@pure pure_inference(fout,T) = Core.Inference.return_type(fout, T)

const DecayRates = Union{Vector{Float64}, Matrix{Float64}, Void}

Expand All @@ -26,7 +30,7 @@ function mcwf_h(tspan, psi0::Ket, H::Operator, J::Vector;
check_mcwf(psi0, H, J, Jdagger, rates)
f(t, psi, dpsi) = dmcwf_h(psi, H, J, Jdagger, dpsi, tmp, rates)
j(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, rates)
return integrate_mcwf(f, j, tspan, psi0, seed; fout=fout,
return integrate_mcwf(f, j, tspan, psi0, seed, fout;
display_beforeevent=display_beforeevent,
display_afterevent=display_afterevent,
kwargs...)
Expand All @@ -50,7 +54,7 @@ function mcwf_nh(tspan, psi0::Ket, Hnh::Operator, J::Vector;
check_mcwf(psi0, Hnh, J, J, nothing)
f(t, psi, dpsi) = dmcwf_nh(psi, Hnh, dpsi)
j(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, nothing)
return integrate_mcwf(f, j, tspan, psi0, seed; fout=fout,
return integrate_mcwf(f, j, tspan, psi0, seed, fout;
display_beforeevent=display_beforeevent,
display_afterevent=display_afterevent,
kwargs...)
Expand Down Expand Up @@ -101,8 +105,8 @@ function mcwf(tspan, psi0::Ket, H::Operator, J::Vector;
tmp = copy(psi0)
dmcwf_h_(t, psi, dpsi) = dmcwf_h(psi, H, J, Jdagger, dpsi, tmp, rates)
j_h(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, rates)
return integrate_mcwf(dmcwf_h_, j_h, tspan, psi0, seed;
fout=fout,
return integrate_mcwf(dmcwf_h_, j_h, tspan, psi0, seed,
fout;
display_beforeevent=display_beforeevent,
display_afterevent=display_afterevent,
kwargs...)
Expand All @@ -119,8 +123,8 @@ function mcwf(tspan, psi0::Ket, H::Operator, J::Vector;
end
dmcwf_nh_(t, psi, dpsi) = dmcwf_nh(psi, Hnh, dpsi)
j_nh(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, rates)
return integrate_mcwf(dmcwf_nh_, j_nh, tspan, psi0, seed;
fout=fout,
return integrate_mcwf(dmcwf_nh_, j_nh, tspan, psi0, seed,
fout;
display_beforeevent=display_beforeevent,
display_afterevent=display_afterevent,
kwargs...)
Expand Down Expand Up @@ -160,8 +164,8 @@ function mcwf_dynamic(tspan, psi0::Ket, f::Function;
tmp = copy(psi0)
dmcwf_(t, psi, dpsi) = dmcwf_h_dynamic(t, psi, f, rates, dpsi, tmp)
j_(rng, t, psi, psi_new) = jump_dynamic(rng, t, psi, f, psi_new, rates)
integrate_mcwf(dmcwf_, j_, tspan, psi0, seed;
fout=fout,
integrate_mcwf(dmcwf_, j_, tspan, psi0, seed,
fout;
display_beforeevent=display_beforeevent,
display_afterevent=display_afterevent,
kwargs...)
Expand All @@ -180,8 +184,8 @@ function mcwf_nh_dynamic(tspan, psi0::Ket, f::Function;
kwargs...)
dmcwf_(t, psi, dpsi) = dmcwf_nh_dynamic(t, psi, f, dpsi)
j_(rng, t, psi, psi_new) = jump_dynamic(rng, t, psi, f, psi_new, rates)
integrate_mcwf(dmcwf_, j_, tspan, psi0, seed;
fout=fout,
integrate_mcwf(dmcwf_, j_, tspan, psi0, seed,
fout;
display_beforeevent=display_beforeevent,
display_afterevent=display_afterevent,
kwargs...)
Expand Down Expand Up @@ -243,39 +247,111 @@ Integrate a single Monte Carlo wave function trajectory.
* `kwargs`: Further arguments are passed on to the ode solver.
"""
function integrate_mcwf(dmcwf::Function, jumpfun::Function, tspan,
psi0::Ket, seed; fout=nothing,
psi0::Ket, seed, fout::Function;
display_beforeevent=false, display_afterevent=false,
#TODO: Remove kwargs
save_everystep=false, callback=nothing,
alg=OrdinaryDiffEq.DP5(),
kwargs...)

tmp = copy(psi0)
as_ket(x::Vector{Complex128}) = Ket(psi0.basis, x)
as_vector(psi::Ket) = psi.data
rng = MersenneTwister(convert(UInt, seed))
jumpnorm = Ref(rand(rng))
djumpnorm(x::Vector{Complex128}, t, integrator) = norm(as_ket(x))^2 - (1-jumpnorm[])
function dojump(integrator)
x = integrator.u
t = integrator.t
jumpfun(rng, t, as_ket(x), tmp)
x .= tmp.data
jumpnorm[] = rand(rng)
end
cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump,
save_positions = (display_beforeevent,display_afterevent))

function fout_(t, x::Ket)
if fout==nothing
psi = copy(x)
psi /= norm(psi)
return psi
else
return fout(t, x)

if !display_beforeevent && !display_afterevent
function dojump(integrator)
x = integrator.u
t = integrator.t
jumpfun(rng, t, as_ket(x), tmp)
x .= tmp.data
jumpnorm[] = rand(rng)
end
cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump,
save_positions = (display_beforeevent,display_afterevent))


return timeevolution.integrate(float(tspan), dmcwf, as_vector(psi0),
copy(psi0), copy(psi0), fout;
callback = cb,
kwargs...)
else
# Temporary workaround until proper tooling for saving
# TODO: Replace by proper call to timeevolution.integrate
function fout_(x::Vector{Complex128}, t::Float64, integrator)
recast!(x, state)
fout(t, state)
end

state = copy(psi0)
dstate = copy(psi0)
out_type = pure_inference(fout, Tuple{eltype(tspan),typeof(state)})
out = DiffEqCallbacks.SavedValues(Float64,out_type)
scb = DiffEqCallbacks.SavingCallback(fout_,out,saveat=tspan,
save_everystep=save_everystep,
save_start = false)

function dojump_display(integrator)
x = integrator.u
t = integrator.t

affect! = scb.affect!
if display_beforeevent
affect!.saveiter += 1
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter,
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
end

jumpfun(rng, t, as_ket(x), tmp)

if display_afterevent
affect!.saveiter += 1
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter,
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
end

x .= tmp.data
jumpnorm[] = rand(rng)
end

cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump_display,
save_positions = (false,false))
full_cb = OrdinaryDiffEq.CallbackSet(callback,cb,scb)

function df_(dx::Vector{Complex128}, x::Vector{Complex128}, p, t)
recast!(x, state)
recast!(dx, dstate)
dmcwf(t, state, dstate)
recast!(dstate, dx)
end

prob = OrdinaryDiffEq.ODEProblem{true}(df_, as_vector(psi0),(tspan[1],tspan[end]))

sol = OrdinaryDiffEq.solve(
prob,
alg;
reltol = 1.0e-6,
abstol = 1.0e-8,
save_everystep = false, save_start = false,
save_end = false,
callback=full_cb, kwargs...)
return out.t, out.saveval
end
end

timeevolution.integrate(float(tspan), dmcwf, as_vector(psi0),
copy(psi0), copy(psi0), fout_;
callback = cb,
kwargs...)
function integrate_mcwf(dmcwf::Function, jumpfun::Function, tspan,
psi0::Ket, seed, fout::Void;
kwargs...)
function fout_(t, x)
psi = copy(x)
psi /= norm(psi)
return psi
end
integrate_mcwf(dmcwf, jumpfun, tspan, psi0, seed, fout_; kwargs...)
end

"""
Expand Down
5 changes: 3 additions & 2 deletions test/test_timeevolution_mcwf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ timeevolution.mcwf(T, Ψ₀, H, J; seed=UInt(2), reltol=1e-6, fout=fout)
tout, Ψt = timeevolution.mcwf(T, Ψ₀, Hlazy, J; seed=UInt(1), reltol=1e-6)
@test norm(Ψt[end] - Ψ) < 1e-5

tout, Ψt = timeevolution.mcwf(T, Ψ₀, H, Jlazy; seed=UInt(1), reltol=1e-6)
@test norm(Ψt[end] - Ψ) < 1e-5
tout, Ψt2 = timeevolution.mcwf(T, Ψ₀, H, Jlazy; seed=UInt(1), reltol=1e-6, display_beforeevent=true, display_afterevent=true)
@test norm(Ψt2[end] - Ψ) < 1e-5
@test length(Ψt2) > length(Ψt)

tout, Ψt = timeevolution.mcwf(T, Ψ₀, H, Jlazy./[sqrt(γ), sqrt(κ)]; seed=UInt(1), rates=[γ, κ], reltol=1e-6)
@test norm(Ψt[end] - Ψ) < 1e-5
Expand Down

0 comments on commit efe570c

Please sign in to comment.