From c3621cadbcd2a17918fd22ca7ed0b004e2e17857 Mon Sep 17 00:00:00 2001 From: Stefan Krastanov Date: Sat, 3 Apr 2021 17:33:36 -0400 Subject: [PATCH] Implement the OrdinaryDiffEq interface for Kets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The following now works ```julia using QuantumOptics using DifferentialEquations ℋ = SpinBasis(1//2) σx = sigmax(ℋ) ↓ = s = spindown(ℋ) schrod(ψ,p,t) = im * σx * ψ t₀, t₁ = (0.0, pi) Δt = 0.1 prob = ODEProblem(schrod, ↓, (t₀, t₁)) sol = solve(prob,Tsit5()) ``` It works for Bras as well. It works for in-place operations, however there are spurrious allocations due to inefficient broadcasting that ruin the performance. --- src/operators_dense.jl | 2 +- src/states.jl | 26 +++++++++++++++++++++++--- src/superoperators.jl | 2 +- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/src/operators_dense.jl b/src/operators_dense.jl index 80a6bee2..de077558 100644 --- a/src/operators_dense.jl +++ b/src/operators_dense.jl @@ -340,7 +340,7 @@ Broadcast.BroadcastStyle(::OperatorStyle{B1,B2}, ::OperatorStyle{B3,B4}) where { end find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r) -const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)} +const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*),typeof(/)} function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:DataOperator}}, axes) args_ = Tuple(a.data for a=args) return Broadcast.Broadcasted(f, args_, axes) diff --git a/src/states.jl b/src/states.jl index df604f95..e1a23d07 100644 --- a/src/states.jl +++ b/src/states.jl @@ -228,7 +228,7 @@ find_basis(x) = x find_basis(a::StateVector, rest) = a.basis find_basis(::Any, rest) = find_basis(rest) -const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)} +const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*),typeof(/)} # `:/` was added for use with scalars in the DifferentialEquations interface function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:T}}, axes) where T<:StateVector args_ = Tuple(a.data for a=args) return Broadcast.Broadcasted(f, args_, axes) @@ -237,6 +237,15 @@ function Broadcasted_restrict_f(f, args::Tuple{Vararg{<:T}}, axes) where T<:Stat throw(error("Cannot broadcast function `$f` on type `$T`")) end +# Broadcast with scalars (of use in ODE solvers checking for tolerances, e.g. `.* reltol .+ abstol`) +Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B<:Basis, T<:QuantumOpticsBase.KetStyle{B}} = T() +Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B<:Basis, T<:QuantumOpticsBase.BraStyle{B}} = T() +getdata(arg::StateVector) = arg.data +getdata(arg) = arg +function Broadcasted_restrict_f(f, args, axes) + args_ = Tuple(getdata(a) for a=args) + return Broadcast.Broadcasted(f, args_, axes) +end # In-place broadcasting for Kets @inline function Base.copyto!(dest::Ket{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:KetStyle{B},Axes,F,Args} @@ -250,8 +259,8 @@ end end # Get the underlying data fields of kets and broadcast them as arrays bcf = Broadcast.flatten(bc) - args_ = Tuple(a.data for a=bcf.args) - bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf)) + args_ = Tuple(getdata(a) for a=bcf.args) + bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf)) copyto!(dest.data, bc_) return dest end @@ -278,3 +287,14 @@ end throw(IncompatibleBases()) @inline Base.copyto!(A::T,B::T) where T<:StateVector = (copyto!(A.data,B.data); A) + +# A few more standard interfaces: These do not necessarily make sense for a StateVector, but enable transparent use of DifferentialEquations.jl +Base.eltype(::Type{Ket{B,A}}) where {B,N,A<:AbstractVector{N}} = N # ODE init +Base.eltype(::Type{Bra{B,A}}) where {B,N,A<:AbstractVector{N}} = N +Base.zero(k::StateVector) = typeof(k)(k.basis, zero(k.data)) # ODE init +Base.any(f::Function, x::StateVector; kwargs...) = any(f, x.data; kwargs...) # ODE nan checks +Base.all(f::Function, x::StateVector; kwargs...) = all(f, x.data; kwargs...) +Broadcast.similar(k::StateVector, t) = typeof(k)(k.basis, copy(k.data)) +using RecursiveArrayTools +RecursiveArrayTools.recursivecopy!(dst::Ket{B,A},src::Ket{B,A}) where {B,A} = copy!(dst.data,src.data) # ODE in-place equations +RecursiveArrayTools.recursivecopy!(dst::Bra{B,A},src::Bra{B,A}) where {B,A} = copy!(dst.data,src.data) \ No newline at end of file diff --git a/src/superoperators.jl b/src/superoperators.jl index 0ef51fac..4518b527 100644 --- a/src/superoperators.jl +++ b/src/superoperators.jl @@ -232,7 +232,7 @@ end # end find_basis(a::SuperOperator, rest) = (a.basis_l, a.basis_r) -const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)} +const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*),typeof(/)} function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:SuperOperator}}, axes) args_ = Tuple(a.data for a=args) return Broadcast.Broadcasted(f, args_, axes)