Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement the OrdinaryDiffEq interface for Kets #16

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ version = "v0.2.7"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

[compat]
julia = "1.3"
FFTW = "1.2"
Adapt = "1, 2"
RecursiveArrayTools = "2.11"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
47 changes: 26 additions & 21 deletions src/operators_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -331,40 +331,45 @@ struct OperatorStyle{BL<:Basis,BR<:Basis} <: DataOperatorStyle{BL,BR} end
Broadcast.BroadcastStyle(::Type{<:Operator{BL,BR}}) where {BL<:Basis,BR<:Basis} = OperatorStyle{BL,BR}()
Broadcast.BroadcastStyle(::OperatorStyle{B1,B2}, ::OperatorStyle{B3,B4}) where {B1<:Basis,B2<:Basis,B3<:Basis,B4<:Basis} = throw(IncompatibleBases())

# Broadcast with scalars (of use in ODE solvers checking for tolerances, e.g. `.* reltol .+ abstol`)
Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {Bl<:Basis, Br<:Basis, T<:OperatorStyle{Bl,Br}} = T()

# Out-of-place broadcasting
@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:OperatorStyle{BL,BR},Axes,F,Args<:Tuple}
bcf = Broadcast.flatten(bc)
bl,br = find_basis(bcf.args)
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
return Operator{BL,BR}(bl, br, copy(bc_))
T = find_dType(bcf)
data = zeros(T, length(bl), length(br))
@inbounds @simd for I in eachindex(bcf)
data[I] = bcf[I]
end
return Operator{BL,BR}(bl, br, data)
end
find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r)

const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)}
function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:DataOperator}}, axes)
args_ = Tuple(a.data for a=args)
return Broadcast.Broadcasted(f, args_, axes)
end
function Broadcasted_restrict_f(f, args::Tuple{Vararg{<:DataOperator}}, axes)
throw(error("Cannot broadcast function `$f` on type `$(eltype(args))`"))
end
find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r)
find_dType(a::DataOperator, rest) = eltype(a)
Base.getindex(a::DataOperator, idx) = getindex(a.data, idx)
Base.iterate(a::DataOperator) = iterate(a.data)
Base.iterate(a::DataOperator, idx) = iterate(a.data, idx)

# In-place broadcasting
@inline function Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:DataOperatorStyle{BL,BR},Axes,F,Args}
axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
# Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match
if bc.f === identity && isa(bc.args, Tuple{<:DataOperator{BL,BR}}) # only a single input argument to broadcast!
A = bc.args[1]
if axes(dest) == axes(A)
return copyto!(dest, A)
end
bc′ = Base.Broadcast.preprocess(dest, bc)
dest′ = dest.data
@inbounds @simd for I in eachindex(bc′)
dest′[I] = bc′[I]
end
# Get the underlying data fields of operators and broadcast them as arrays
bcf = Broadcast.flatten(bc)
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
copyto!(dest.data, bc_)
return dest
end
@inline Base.copyto!(A::DataOperator{BL,BR},B::DataOperator{BL,BR}) where {BL<:Basis,BR<:Basis} = (copyto!(A.data,B.data); A)
@inline Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:DataOperatorStyle,Axes,F,Args} =
throw(IncompatibleBases())

# A few more standard interfaces: These do not necessarily make sense for a StateVector, but enable transparent use of DifferentialEquations.jl
Base.eltype(::Type{Operator{Bl,Br,A}}) where {Bl,Br,N,A<:AbstractMatrix{N}} = N # ODE init
Base.any(f::Function, ρ::Operator; kwargs...) = any(f, ρ.data; kwargs...) # ODE nan checks
Base.all(f::Function, ρ::Operator; kwargs...) = all(f, ρ.data; kwargs...)
Broadcast.similar(ρ::Operator, t) = typeof(ρ)(ρ.basis_l, ρ.basis_r, copy(ρ.data))
using RecursiveArrayTools
RecursiveArrayTools.recursivecopy!(dst::Operator{Bl,Br,A},src::Operator{Bl,Br,A}) where {Bl,Br,A} = copy!(dst.data,src.data) # ODE in-place equations
86 changes: 45 additions & 41 deletions src/states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,72 +209,76 @@ Broadcast.BroadcastStyle(::Type{<:Bra{B}}) where {B<:Basis} = BraStyle{B}()
Broadcast.BroadcastStyle(::KetStyle{B1}, ::KetStyle{B2}) where {B1<:Basis,B2<:Basis} = throw(IncompatibleBases())
Broadcast.BroadcastStyle(::BraStyle{B1}, ::BraStyle{B2}) where {B1<:Basis,B2<:Basis} = throw(IncompatibleBases())

# Broadcast with scalars (of use in ODE solvers checking for tolerances, e.g. `.* reltol .+ abstol`)
Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B<:Basis, T<:KetStyle{B}} = T()
Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B<:Basis, T<:BraStyle{B}} = T()

# Out-of-place broadcasting
@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:KetStyle{B},Axes,F,Args<:Tuple}
bcf = Broadcast.flatten(bc)
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
b = find_basis(bcf)
return Ket{B}(b, copy(bc_))
T = find_dType(bcf)
data = zeros(T, length(b))
@inbounds @simd for I in eachindex(bcf)
data[I] = bcf[I]
end
return Ket{B}(b, data)
end
@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:BraStyle{B},Axes,F,Args<:Tuple}
bcf = Broadcast.flatten(bc)
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
b = find_basis(bcf)
return Bra{B}(b, copy(bc_))
end
find_basis(bc::Broadcast.Broadcasted) = find_basis(bc.args)
find_basis(args::Tuple) = find_basis(find_basis(args[1]), Base.tail(args))
find_basis(x) = x
find_basis(a::StateVector, rest) = a.basis
find_basis(::Any, rest) = find_basis(rest)

const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)}
function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:T}}, axes) where T<:StateVector
args_ = Tuple(a.data for a=args)
return Broadcast.Broadcasted(f, args_, axes)
T = find_dType(bcf)
data = zeros(T, length(b))
@inbounds @simd for I in eachindex(bcf)
data[I] = bcf[I]
end
return Bra{B}(b, data)
end
function Broadcasted_restrict_f(f, args::Tuple{Vararg{<:T}}, axes) where T<:StateVector
throw(error("Cannot broadcast function `$f` on type `$T`"))
for f ∈ [:find_basis,:find_dType]
@eval ($f)(bc::Broadcast.Broadcasted) = ($f)(bc.args)
@eval ($f)(args::Tuple) = ($f)(($f)(args[1]), Base.tail(args))
@eval ($f)(x) = x
@eval ($f)(::Any, rest) = ($f)(rest)
end

find_basis(a::StateVector, rest) = a.basis
find_dType(a::StateVector, rest) = eltype(a)
Base.getindex(st::StateVector, idx) = getindex(st.data, idx)

# In-place broadcasting for Kets
@inline function Base.copyto!(dest::Ket{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:KetStyle{B},Axes,F,Args}
axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
# Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match
if bc.f === identity && isa(bc.args, Tuple{<:Ket{B}}) # only a single input argument to broadcast!
A = bc.args[1]
if axes(dest) == axes(A)
return copyto!(dest, A)
end
axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))
bc′ = Base.Broadcast.preprocess(dest, bc)
dest′ = dest.data
@inbounds @simd for I in eachindex(bc′)
dest′[I] = bc′[I]
end
# Get the underlying data fields of kets and broadcast them as arrays
bcf = Broadcast.flatten(bc)
args_ = Tuple(a.data for a=bcf.args)
bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf))
copyto!(dest.data, bc_)
return dest
end
@inline Base.copyto!(dest::Ket{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1<:Basis,B2<:Basis,Style<:KetStyle{B2},Axes,F,Args} =
throw(IncompatibleBases())

# In-place broadcasting for Bras
@inline function Base.copyto!(dest::Bra{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:BraStyle{B},Axes,F,Args}
axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
# Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match
if bc.f === identity && isa(bc.args, Tuple{<:Bra{B}}) # only a single input argument to broadcast!
A = bc.args[1]
if axes(dest) == axes(A)
return copyto!(dest, A)
end
axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))
bc′ = Base.Broadcast.preprocess(dest, bc)
dest′ = dest.data
@inbounds @simd for I in eachindex(bc′)
dest′[I] = bc′[I]
end
# Get the underlying data fields of bras and broadcast them as arrays
bcf = Broadcast.flatten(bc)
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
copyto!(dest.data, bc_)
return dest
end
@inline Base.copyto!(dest::Bra{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1<:Basis,B2<:Basis,Style<:BraStyle{B2},Axes,F,Args} =
throw(IncompatibleBases())

@inline Base.copyto!(A::T,B::T) where T<:StateVector = (copyto!(A.data,B.data); A)

# A few more standard interfaces: These do not necessarily make sense for a StateVector, but enable transparent use of DifferentialEquations.jl
Base.eltype(::Type{Ket{B,A}}) where {B,N,A<:AbstractVector{N}} = N # ODE init
Base.eltype(::Type{Bra{B,A}}) where {B,N,A<:AbstractVector{N}} = N
Base.zero(k::StateVector) = typeof(k)(k.basis, zero(k.data)) # ODE init
Base.any(f::Function, x::StateVector; kwargs...) = any(f, x.data; kwargs...) # ODE nan checks
Base.all(f::Function, x::StateVector; kwargs...) = all(f, x.data; kwargs...)
Broadcast.similar(k::StateVector, t) = typeof(k)(k.basis, copy(k.data))
using RecursiveArrayTools
RecursiveArrayTools.recursivecopy!(dst::Ket{B,A},src::Ket{B,A}) where {B,A} = copy!(dst.data,src.data) # ODE in-place equations
RecursiveArrayTools.recursivecopy!(dst::Bra{B,A},src::Bra{B,A}) where {B,A} = copy!(dst.data,src.data)
2 changes: 1 addition & 1 deletion src/superoperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ end
# end
find_basis(a::SuperOperator, rest) = (a.basis_l, a.basis_r)

const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)}
const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*),typeof(/)}
function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:SuperOperator}}, axes)
args_ = Tuple(a.data for a=args)
return Broadcast.Broadcasted(f, args_, axes)
Expand Down
1 change: 0 additions & 1 deletion test/test_abstractdata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,6 @@ op1 .= op1_ .+ 3 * op1_
bf = FockBasis(3)
op3 = randtestoperator(bf)
@test_throws QuantumOpticsBase.IncompatibleBases op1 .+ op3
@test_throws ErrorException cos.(op1)

####################
# Test lazy tensor #
Expand Down
6 changes: 5 additions & 1 deletion test/test_operators_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,10 @@ op1 .= op1_ .+ 3 * op1_
bf = FockBasis(3)
op3 = randoperator(bf)
@test_throws QuantumOpticsBase.IncompatibleBases op1 .+ op3
@test_throws ErrorException cos.(op1)
@test op3 * 2 == op3 .+ op3
z = zero(op3)
z .= op3 .* 3
@test z == op3 .* 2 .+ op3
@test_broken all(z .== op3 .* 2 .+ op3)

end # testset
1 change: 0 additions & 1 deletion test/test_operators_sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,5 @@ op3 = sprandop(FockBasis(1),FockBasis(2))
op_ = copy(op1)
op_ .+= op1
@test op_ == 2*op1
@test_throws ErrorException cos.(op_)

end # testset
15 changes: 13 additions & 2 deletions test/test_states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,18 @@ psi_ .+= psi123
bra_ = copy(bra123)
bra_ .= 3*bra123
@test bra_ == 3*dagger(psi123)
@test_throws ErrorException cos.(psi_)
@test_throws ErrorException cos.(bra_)
@test bra_ .* 2 == bra_ .+ bra_
@test bra_ * 2 == bra_ .+ bra_
z = zero(bra_)
z .= bra_ .* 2
@test_broken all(z .== bra_ .+ bra_)
@test z == bra_ .+ bra_
ket_ = bra_'
@test ket_ .* 2 == ket_ .+ ket_
@test ket_ * 2 == ket_ .+ ket_
z = zero(ket_)
z .= ket_ .* 2
@test_broken all(z .== ket_.+ ket_)
@test z == ket_ .+ ket_

end # testset