Skip to content

Commit

Permalink
Broadcasting for semiclassical objects (#404)
Browse files Browse the repository at this point in the history

Co-authored-by: Stefan Krastanov <github.acc@krastanov.org>
Co-authored-by: Stefan Krastanov <stefan@krastanov.org>
  • Loading branch information
3 people authored Aug 11, 2024
1 parent f730f1a commit beb0f37
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ RecursiveArrayTools = "2, 3"
Reexport = "0.2, 1.0"
StochasticDiffEq = "6"
WignerSymbols = "1, 2"
julia = "1.3"
julia = "1.10"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
1 change: 1 addition & 0 deletions src/QuantumOptics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module QuantumOptics
using Reexport
@reexport using QuantumOpticsBase
using SparseArrays, LinearAlgebra
import RecursiveArrayTools

export
ylm,
Expand Down
110 changes: 95 additions & 15 deletions src/semiclassical.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
module semiclassical

using QuantumOpticsBase
import Base: ==
import QuantumOpticsBase: IncompatibleBases
import Base: ==, isapprox, +, -, *, /
import ..timeevolution: integrate, recast!, jump, integrate_mcwf, jump_callback,
JumpRNGState, threshold, roll!, as_vector, QO_CHECKS
import LinearAlgebra: normalize, normalize!
import RecursiveArrayTools

using Random, LinearAlgebra
import OrdinaryDiffEq
Expand All @@ -31,26 +33,104 @@ mutable struct State{B,T,C}
new{B,T,C}(quantum, classical)
end
end

Base.length(state::State) = length(state.quantum) + length(state.classical)
Base.copy(state::State) = State(copy(state.quantum), copy(state.classical))
Base.eltype(state::State) = promote_type(eltype(state.quantum),eltype(state.classical))
normalize!(state::State) = (normalize!(state.quantum); state)
normalize(state::State) = State(normalize(state.quantum),copy(state.classical))

function ==(a::State, b::State)
QuantumOpticsBase.samebases(a.quantum, b.quantum) &&
length(a.classical)==length(b.classical) &&
(a.classical==b.classical) &&
(a.quantum==b.quantum)
end
State{B}(q::T, c::C) where {B,T<:QuantumState{B},C} = State(q,c)

# Standard interfaces
Base.zero(x::State) = State(zero(x.quantum), zero(x.classical))
Base.length(x::State) = length(x.quantum) + length(x.classical)
Base.axes(x::State) = (Base.OneTo(length(x)),)
Base.size(x::State) = size(x.quantum)
Base.ndims(x::Type{<:State{B,T,C}}) where {B,T<:QuantumState{B},C} = ndims(T)
Base.copy(x::State) = State(copy(x.quantum), copy(x.classical))
Base.copyto!(x::State, y::State) = (copyto!(x.quantum, y.quantum); copyto!(x.classical, y.classical); x)
Base.fill!(x::State, a) = (fill!(x.quantum, a), fill!(x.classical, a))
Base.eltype(x::State) = promote_type(eltype(x.quantum),eltype(x.classical))
Base.eltype(x::Type{<:State{B,T,C}}) where {B,T<:QuantumState{B},C} = promote_type(eltype(T), eltype(C))
Base.similar(x::State, ::Type{T} = eltype(x)) where {T} = State(similar(x.quantum, T), similar(x.classical, T))
Base.getindex(x::State, idx) = idx <= length(x.quantum) ? getindex(x.quantum, idx) : getindex(x.classical, idx-length(x.quantum))

normalize!(x::State) = (normalize!(x.quantum); x)
normalize(x::State) = State(normalize(x.quantum),copy(x.classical))
LinearAlgebra.norm(x::State) = LinearAlgebra.norm(x.quantum)

==(x::State{B}, y::State{B}) where {B} = (x.classical==y.classical) && (x.quantum==y.quantum)
==(x::State, y::State) = false

isapprox(x::State{B}, y::State{B}; kwargs...) where {B} = isapprox(x.quantum,y.quantum; kwargs...) && isapprox(x.classical,y.classical; kwargs...)
isapprox(x::State, y::State; kwargs...) = false

QuantumOpticsBase.expect(op, state::State) = expect(op, state.quantum)
QuantumOpticsBase.variance(op, state::State) = variance(op, state.quantum)
QuantumOpticsBase.ptrace(state::State, indices) = State(ptrace(state.quantum, indices), state.classical)

QuantumOpticsBase.dm(x::State) = State(dm(x.quantum), x.classical)

Base.broadcastable(x::State) = x

# Custom broadcasting style
struct StateStyle{B} <: Broadcast.BroadcastStyle end

# Style precedence rules
Broadcast.BroadcastStyle(::Type{<:State{B}}) where {B} = StateStyle{B}()
Broadcast.BroadcastStyle(::StateStyle{B1}, ::StateStyle{B2}) where {B1,B2} = throw(IncompatibleBases())
Broadcast.BroadcastStyle(::StateStyle{B}, ::Broadcast.DefaultArrayStyle{0}) where {B} = StateStyle{B}()
Broadcast.BroadcastStyle(::Broadcast.DefaultArrayStyle{0}, ::StateStyle{B}) where {B} = StateStyle{B}()

# Out-of-place broadcasting
@inline function Base.copy(bc::Broadcast.Broadcasted{<:StateStyle{B},Axes,F,Args}) where {B,Axes,F,Args<:Tuple}
bcf = Broadcast.flatten(bc)
# extract quantum object from broadcast container
qobj = find_quantum(bcf)
data_q = zeros(eltype(qobj), size(qobj)...)
Nq = length(qobj)
# allocate quantum data from broadcast container
@inbounds @simd for I in 1:Nq
data_q[I] = bcf[I]
end
# extract classical object from broadcast container
cobj = find_classical(bcf)
data_c = zeros(eltype(cobj), length(cobj))
Nc = length(cobj)
# allocate classical data from broadcast container
@inbounds @simd for I in 1:Nc
data_c[I] = bcf[I+Nq]
end
type = eval(nameof(typeof(qobj)))
return State{B}(type(basis(qobj), data_q), data_c)
end

for f [:find_quantum, :find_classical]
@eval ($f)(bc::Broadcast.Broadcasted) = ($f)(bc.args)
@eval ($f)(args::Tuple) = ($f)(($f)(args[1]), Base.tail(args))
@eval ($f)(x) = x
@eval ($f)(::Any, rest) = ($f)(rest)
end
find_quantum(x::State, rest) = x.quantum
find_classical(x::State, rest) = x.classical

# In-place broadcasting
@inline function Base.copyto!(dest::State{B}, bc::Broadcast.Broadcasted{<:StateStyle{B},Axes,F,Args}) where {B,Axes,F,Args<:Tuple}
axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))
bc′ = Base.Broadcast.preprocess(dest, bc)
# write broadcasted quantum data to dest
qobj = dest.quantum
@inbounds @simd for I in 1:length(qobj)
qobj.data[I] = bc′[I]
end
# write broadcasted classical data to dest
cobj = dest.classical
@inbounds @simd for I in 1:length(cobj)
cobj[I] = bc′[I+length(qobj)]
end
return dest
end
@inline Base.copyto!(dest::State{B1}, bc::Broadcast.Broadcasted{<:StateStyle{B2},Axes,F,Args}) where {B1,B2,Axes,F,Args<:Tuple} =
throw(IncompatibleBases())

Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::State, i) = Base.getindex(x, i)
RecursiveArrayTools.recursive_unitless_bottom_eltype(x::State) = eltype(x)
RecursiveArrayTools.recursivecopy!(dest::State, src::State) = copyto!(dest, src)
RecursiveArrayTools.recursivecopy(x::State) = copy(x)
RecursiveArrayTools.recursivefill!(x::State, a) = fill!(x, a)

"""
semiclassical.schroedinger_dynamic(tspan, state0, fquantum, fclassical[; fout, ...])
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ names = [

"test_timeevolution_abstractdata.jl",

"test_sciml_broadcast_interfaces.jl",
"test_ForwardDiff.jl"
]

Expand Down
25 changes: 25 additions & 0 deletions test/test_sciml_broadcast_interfaces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using Test
using QuantumOptics
using OrdinaryDiffEq

@testset "sciml interface" begin

# semiclassical ODE problem
b = SpinBasis(1//2)
psi0 = spindown(b)
u0 = ComplexF64[0.5, 0.75]
sc = semiclassical.State(psi0, u0)
t₀, t₁ = (0.0, pi)
σx = sigmax(b)

fquantum(t, q, u) = σx + cos(u[1])*identityoperator(σx)
fclassical!(du, u, q, t) = (du[1] = sin(u[2]); du[2] = 2*u[1])
f!(dstate, state, p, t) = semiclassical.dschroedinger_dynamic!(dstate, fquantum, fclassical!, state, t)
prob = ODEProblem(f!, sc, (t₀, t₁))

sol = solve(prob, DP5(); reltol = 1.0e-8, abstol = 1.0e-10, save_everystep=false)
tout, ψt = semiclassical.schroedinger_dynamic([t₀, t₁], sc, fquantum, fclassical!; reltol = 1.0e-8, abstol = 1.0e-10)

@test sol[end] ψt[end]

end
25 changes: 25 additions & 0 deletions test/test_semiclassical.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test
using QuantumOptics
using LinearAlgebra
using QuantumOpticsBase: IncompatibleBases

@testset "semiclassical" begin

Expand Down Expand Up @@ -175,4 +176,28 @@ after_jump = findlast(t-> !(t∈T), tout4)
@test ψt4[before_jump].quantum == ψ0.quantum
@test ψt4[after_jump].quantum == spindown(ba)fockstate(bf,0)

# Test broadcasting interface
b = FockBasis(10)
bn = FockBasis(20)
u0 = ComplexF64[0.7, 0.2]
psi = fockstate(b, 2)
psin = fockstate(bn, 2)
rho = dm(psi)

sc_ket = semiclassical.State(psi, u0)
sc_ketn = semiclassical.State(psin, u0)
sc_dm = semiclassical.State(rho, u0)

@test Base.size(sc_dm) == Base.size(rho)
@test (copy_sc = copy(sc_ket); Base.fill!(copy_sc, 0.0); copy_sc) == semiclassical.State(fill!(copy(psi), 0.0), fill!(copy(u0), 0.0))
@test Base.similar(sc_ket, Int) isa semiclassical.State
@test normalize!(copy(sc_ket)) == semiclassical.State(normalize!(copy(psi)), u0)
@test !(sc_ket == sc_ketn)
@test !(isapprox(sc_ket, sc_ketn))
@test sc_ket .* 1.0 == sc_ket
@test sc_dm .* 1.0 == sc_dm
@test sc_ket .+ 2.0 == semiclassical.State(psi .+ 2.0, u0 .+ 2.0)
@test sc_dm .+ 2.0 == semiclassical.State(rho .+ 2.0, u0 .+ 2.0)
@test_throws IncompatibleBases sc_ket .+ semiclassical.State(spinup(SpinBasis(10)), u0)

end # testsets

0 comments on commit beb0f37

Please sign in to comment.