diff --git a/REQUIRE b/REQUIRE index 26c883d4..82a72e08 100644 --- a/REQUIRE +++ b/REQUIRE @@ -3,3 +3,4 @@ Compat 0.52.0 OrdinaryDiffEq 3.1.0 DiffEqCallbacks 1.0 StochasticDiffEq 3.0.0 +RecursiveArrayTools diff --git a/src/mcwf.jl b/src/mcwf.jl index 098fbf90..6ffd4c4f 100644 --- a/src/mcwf.jl +++ b/src/mcwf.jl @@ -7,6 +7,10 @@ using ...operators_dense, ...operators_sparse using ..timeevolution using ...operators_lazysum, ...operators_lazytensor, ...operators_lazyproduct import OrdinaryDiffEq +# TODO: Remove imports +import DiffEqCallbacks, RecursiveArrayTools.copyat_or_push! +import ..recast! +Base.@pure pure_inference(fout,T) = Core.Inference.return_type(fout, T) const DecayRates = Union{Vector{Float64}, Matrix{Float64}, Void} @@ -26,7 +30,7 @@ function mcwf_h(tspan, psi0::Ket, H::Operator, J::Vector; check_mcwf(psi0, H, J, Jdagger, rates) f(t, psi, dpsi) = dmcwf_h(psi, H, J, Jdagger, dpsi, tmp, rates) j(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, rates) - return integrate_mcwf(f, j, tspan, psi0, seed; fout=fout, + return integrate_mcwf(f, j, tspan, psi0, seed, fout; display_beforeevent=display_beforeevent, display_afterevent=display_afterevent, kwargs...) @@ -50,7 +54,7 @@ function mcwf_nh(tspan, psi0::Ket, Hnh::Operator, J::Vector; check_mcwf(psi0, Hnh, J, J, nothing) f(t, psi, dpsi) = dmcwf_nh(psi, Hnh, dpsi) j(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, nothing) - return integrate_mcwf(f, j, tspan, psi0, seed; fout=fout, + return integrate_mcwf(f, j, tspan, psi0, seed, fout; display_beforeevent=display_beforeevent, display_afterevent=display_afterevent, kwargs...) @@ -101,8 +105,8 @@ function mcwf(tspan, psi0::Ket, H::Operator, J::Vector; tmp = copy(psi0) dmcwf_h_(t, psi, dpsi) = dmcwf_h(psi, H, J, Jdagger, dpsi, tmp, rates) j_h(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, rates) - return integrate_mcwf(dmcwf_h_, j_h, tspan, psi0, seed; - fout=fout, + return integrate_mcwf(dmcwf_h_, j_h, tspan, psi0, seed, + fout; display_beforeevent=display_beforeevent, display_afterevent=display_afterevent, kwargs...) @@ -119,8 +123,8 @@ function mcwf(tspan, psi0::Ket, H::Operator, J::Vector; end dmcwf_nh_(t, psi, dpsi) = dmcwf_nh(psi, Hnh, dpsi) j_nh(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, rates) - return integrate_mcwf(dmcwf_nh_, j_nh, tspan, psi0, seed; - fout=fout, + return integrate_mcwf(dmcwf_nh_, j_nh, tspan, psi0, seed, + fout; display_beforeevent=display_beforeevent, display_afterevent=display_afterevent, kwargs...) @@ -160,8 +164,8 @@ function mcwf_dynamic(tspan, psi0::Ket, f::Function; tmp = copy(psi0) dmcwf_(t, psi, dpsi) = dmcwf_h_dynamic(t, psi, f, rates, dpsi, tmp) j_(rng, t, psi, psi_new) = jump_dynamic(rng, t, psi, f, psi_new, rates) - integrate_mcwf(dmcwf_, j_, tspan, psi0, seed; - fout=fout, + integrate_mcwf(dmcwf_, j_, tspan, psi0, seed, + fout; display_beforeevent=display_beforeevent, display_afterevent=display_afterevent, kwargs...) @@ -180,8 +184,8 @@ function mcwf_nh_dynamic(tspan, psi0::Ket, f::Function; kwargs...) dmcwf_(t, psi, dpsi) = dmcwf_nh_dynamic(t, psi, f, dpsi) j_(rng, t, psi, psi_new) = jump_dynamic(rng, t, psi, f, psi_new, rates) - integrate_mcwf(dmcwf_, j_, tspan, psi0, seed; - fout=fout, + integrate_mcwf(dmcwf_, j_, tspan, psi0, seed, + fout; display_beforeevent=display_beforeevent, display_afterevent=display_afterevent, kwargs...) @@ -243,39 +247,111 @@ Integrate a single Monte Carlo wave function trajectory. * `kwargs`: Further arguments are passed on to the ode solver. """ function integrate_mcwf(dmcwf::Function, jumpfun::Function, tspan, - psi0::Ket, seed; fout=nothing, + psi0::Ket, seed, fout::Function; display_beforeevent=false, display_afterevent=false, + #TODO: Remove kwargs + save_everystep=false, callback=nothing, + alg=OrdinaryDiffEq.DP5(), kwargs...) + tmp = copy(psi0) as_ket(x::Vector{Complex128}) = Ket(psi0.basis, x) as_vector(psi::Ket) = psi.data rng = MersenneTwister(convert(UInt, seed)) jumpnorm = Ref(rand(rng)) djumpnorm(x::Vector{Complex128}, t, integrator) = norm(as_ket(x))^2 - (1-jumpnorm[]) - function dojump(integrator) - x = integrator.u - t = integrator.t - jumpfun(rng, t, as_ket(x), tmp) - x .= tmp.data - jumpnorm[] = rand(rng) - end - cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump, - save_positions = (display_beforeevent,display_afterevent)) - - function fout_(t, x::Ket) - if fout==nothing - psi = copy(x) - psi /= norm(psi) - return psi - else - return fout(t, x) + + if !display_beforeevent && !display_afterevent + function dojump(integrator) + x = integrator.u + t = integrator.t + jumpfun(rng, t, as_ket(x), tmp) + x .= tmp.data + jumpnorm[] = rand(rng) + end + cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump, + save_positions = (display_beforeevent,display_afterevent)) + + + return timeevolution.integrate(float(tspan), dmcwf, as_vector(psi0), + copy(psi0), copy(psi0), fout; + callback = cb, + kwargs...) + else + # Temporary workaround until proper tooling for saving + # TODO: Replace by proper call to timeevolution.integrate + function fout_(x::Vector{Complex128}, t::Float64, integrator) + recast!(x, state) + fout(t, state) end + + state = copy(psi0) + dstate = copy(psi0) + out_type = pure_inference(fout, Tuple{eltype(tspan),typeof(state)}) + out = DiffEqCallbacks.SavedValues(Float64,out_type) + scb = DiffEqCallbacks.SavingCallback(fout_,out,saveat=tspan, + save_everystep=save_everystep, + save_start = false) + + function dojump_display(integrator) + x = integrator.u + t = integrator.t + + affect! = scb.affect! + if display_beforeevent + affect!.saveiter += 1 + copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t) + copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter, + affect!.save_func(integrator.u, integrator.t, integrator),Val{false}) + end + + jumpfun(rng, t, as_ket(x), tmp) + + if display_afterevent + affect!.saveiter += 1 + copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t) + copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter, + affect!.save_func(integrator.u, integrator.t, integrator),Val{false}) + end + + x .= tmp.data + jumpnorm[] = rand(rng) + end + + cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump_display, + save_positions = (false,false)) + full_cb = OrdinaryDiffEq.CallbackSet(callback,cb,scb) + + function df_(dx::Vector{Complex128}, x::Vector{Complex128}, p, t) + recast!(x, state) + recast!(dx, dstate) + dmcwf(t, state, dstate) + recast!(dstate, dx) + end + + prob = OrdinaryDiffEq.ODEProblem{true}(df_, as_vector(psi0),(tspan[1],tspan[end])) + + sol = OrdinaryDiffEq.solve( + prob, + alg; + reltol = 1.0e-6, + abstol = 1.0e-8, + save_everystep = false, save_start = false, + save_end = false, + callback=full_cb, kwargs...) + return out.t, out.saveval end +end - timeevolution.integrate(float(tspan), dmcwf, as_vector(psi0), - copy(psi0), copy(psi0), fout_; - callback = cb, - kwargs...) +function integrate_mcwf(dmcwf::Function, jumpfun::Function, tspan, + psi0::Ket, seed, fout::Void; + kwargs...) + function fout_(t, x) + psi = copy(x) + psi /= norm(psi) + return psi + end + integrate_mcwf(dmcwf, jumpfun, tspan, psi0, seed, fout_; kwargs...) end """ diff --git a/test/test_timeevolution_mcwf.jl b/test/test_timeevolution_mcwf.jl index 0bca19f9..8e58d4b6 100644 --- a/test/test_timeevolution_mcwf.jl +++ b/test/test_timeevolution_mcwf.jl @@ -70,8 +70,9 @@ timeevolution.mcwf(T, Ψ₀, H, J; seed=UInt(2), reltol=1e-6, fout=fout) tout, Ψt = timeevolution.mcwf(T, Ψ₀, Hlazy, J; seed=UInt(1), reltol=1e-6) @test norm(Ψt[end] - Ψ) < 1e-5 -tout, Ψt = timeevolution.mcwf(T, Ψ₀, H, Jlazy; seed=UInt(1), reltol=1e-6) -@test norm(Ψt[end] - Ψ) < 1e-5 +tout, Ψt2 = timeevolution.mcwf(T, Ψ₀, H, Jlazy; seed=UInt(1), reltol=1e-6, display_beforeevent=true, display_afterevent=true) +@test norm(Ψt2[end] - Ψ) < 1e-5 +@test length(Ψt2) > length(Ψt) tout, Ψt = timeevolution.mcwf(T, Ψ₀, H, Jlazy./[sqrt(γ), sqrt(κ)]; seed=UInt(1), rates=[γ, κ], reltol=1e-6) @test norm(Ψt[end] - Ψ) < 1e-5