Skip to content

Commit

Permalink
Don't tuplify time-dependent ops if types are homogenous (#401)
Browse files Browse the repository at this point in the history
  • Loading branch information
amilsted committed Jul 10, 2024
1 parent 1b634ef commit 6f3b6b4
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 13 deletions.
52 changes: 39 additions & 13 deletions src/time_dependent_operators.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
# Convert storage of heterogeneous stuff to tuples for maximal compilation
# and to avoid runtime dispatch.
_tuplify(o::TimeDependentSum) = TimeDependentSum(Tuple, o)
_tuplify(o::LazySum) = LazySum(eltype(o.factors), o.factors, (o.operators...,))
function _tuplify(o::TimeDependentSum)
if isconcretetype(eltype(o.coefficients)) && isconcretetype(eltype(o.static_op.operators))
# No need to tuplify is types are concrete.
# We will save on compile time this way.
return o
end
return TimeDependentSum(Tuple, o)
end
function _tuplify(o::LazySum)
if isconcretetype(eltype(o.factors)) && isconcretetype(eltype(o.operators))
return o
end
return LazySum(eltype(o.factors), o.factors, (o.operators...,))
end
_tuplify(o::AbstractVector{T}) where T = isconcretetype(T) ? o : (o...,)
_tuplify(o::Tuple) = o
_tuplify(o::AbstractOperator) = o

"""
Expand All @@ -23,6 +37,7 @@ function _tdopdagger(o::TimeDependentSum)
# that requires that the original operator sticks around and is always
# updated first (though this is checked).
# Copies and conjugates the coefficients from the original op.
# TODO: Make an Adjoint wrapper for TimeDependentSum instead?
o_ls = QuantumOpticsBase.static_operator(o)
facs = o_ls.factors
c1 = (t)->(@assert current_time(o) == t; conj(facs[1]))
Expand All @@ -43,13 +58,18 @@ operators.
"""
function master_h_dynamic_function(H::AbstractTimeDependentOperator, Js)
Htup = _tuplify(H)
Js_tup = ((_tuplify(J) for J in Js)...,)

Jdags_tup = _tdopdagger.(Js_tup)
function _getfunc(Hop, Jops, Jdops)
return (@inline _tdop_master_wrapper_1(t, _) = (set_time!(Hop, t), set_time!.(Jops, t), set_time!.(Jdops, t)))
Js_tup = _tuplify(map(_tuplify, Js))
Jdags_tup = map(_tdopdagger, Js_tup)

return let Hop = Htup, Jops = Js_tup, Jdops = Jdags_tup
function _tdop_master_wrapper_1(t, _)
f = Base.Fix2(set_time!, t)
foreach(f, Jops)
foreach(f, Jdops)
set_time!(Hop, t)
return Hop, Jops, Jdops
end
end
return _getfunc(Htup, Js_tup, Jdags_tup)
end

"""
Expand All @@ -64,15 +84,21 @@ where `Hnh` is represents the non-Hermitian Hamiltonian and `Js` are the
"""
function master_nh_dynamic_function(Hnh::AbstractTimeDependentOperator, Js)
Hnhtup = _tuplify(Hnh)
Js_tup = ((_tuplify(J) for J in Js)...,)
Js_tup = _tuplify(map(_tuplify, Js))

Jdags_tup = _tdopdagger.(Js_tup)
Jdags_tup = map(_tdopdagger, Js_tup)
Htdagup = _tdopdagger(Hnhtup)

function _getfunc(Hop, Hdop, Jops, Jdops)
return (@inline _tdop_master_wrapper_2(t, _) = (set_time!(Hop, t), set_time!(Hdop, t), set_time!.(Jops, t), set_time!.(Jdops, t)))
return let Hop = Hnhtup, Hdop = Htdagup, Jops = Js_tup, Jdops = Jdags_tup
function _tdop_master_wrapper_2(t, _)
f = Base.Fix2(set_time!, t)
foreach(f, Jops)
foreach(f, Jdops)
set_time!(Hop, t)
set_time!(Hdop, t)
return Hop, Hdop, Jops, Jdops
end
end
return _getfunc(Hnhtup, Htdagup, Js_tup, Jdags_tup)
end

"""
Expand Down
35 changes: 35 additions & 0 deletions test/test_timeevolution_tdops.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
using Test
using QuantumOptics

function test_settime(op)
t = current_time(op)
set_time!(op, randn())
set_time!(op, t)
return nothing
end

@testset "time-dependent operators" begin

b = FockBasis(7)
Expand All @@ -9,8 +16,26 @@ a = destroy(b)

H0 = number(b)
Hd = (a + a')

# function and op types homogeneous
H = TimeDependentSum(cos=>H0, cos=>Hd)
Htup = timeevolution._tuplify(H)
@test Htup === H
test_settime(Htup)
@test (@allocated(test_settime)) == 0

# op types not homogeneous
H = TimeDependentSum(cos=>H0, cos=>dense(Hd), cos=>LazySum(H0), cos=>LazySum(dense(Hd)))
Htup = timeevolution._tuplify(H)
@test Htup !== H
test_settime(Htup)
@test (@allocated(test_settime)) == 0

H = TimeDependentSum(1.0=>H0, cos=>Hd)

# function types not homogeneous
@test timeevolution._tuplify(H) !== H

ts = [0.0, 0.4]
ts_half = 0.5 * ts

Expand Down Expand Up @@ -54,6 +79,8 @@ ts_out, rhos = timeevolution.master_dynamic(ts, psi0, H, Js)
ts_out2, rhos2 = timeevolution.master_dynamic(ts, psi0, fman)
@test rhos[end].data rhos2[end].data

set_time!(H, 0.0)
set_time!.(Js, 0.0)
Hnh = H - 0.5im * sum(J' * J for J in Js)

_getf = (H0, Hd, a) -> (t,_) -> (
Expand Down Expand Up @@ -90,4 +117,12 @@ allocs1 = @allocated timeevolution.master_nh_dynamic(ts, psi0, Hnh, Js)
allocs2 = @allocated timeevolution.master_nh_dynamic(ts_half, psi0, Hnh, Js)
@test allocs1 == allocs2

Jstup = (Js...,)
ts_out2, rhos2 = timeevolution.master_nh_dynamic(ts, psi0, Hnh, Jstup)
@test rhos[end].data rhos2[end].data

allocs1 = @allocated timeevolution.master_nh_dynamic(ts, psi0, Hnh, Jstup)
allocs2 = @allocated timeevolution.master_nh_dynamic(ts_half, psi0, Hnh, Jstup)
@test allocs1 == allocs2

end

0 comments on commit 6f3b6b4

Please sign in to comment.