Skip to content

Commit

Permalink
Merge pull request #2556 from mmesiti/multithreaded-ABM
Browse files Browse the repository at this point in the history
Added threading to ABM algorithms using @..
  • Loading branch information
ChrisRackauckas authored Dec 20, 2024
2 parents 845ee9d + c56813f commit fdc341a
Show file tree
Hide file tree
Showing 6 changed files with 242 additions and 144 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ get_fsalfirstlast(cache::ABMMutableCache, u) = (cache.fsalfirst, cache.k)
function get_fsalfirstlast(cache::ABMVariableCoefficientMutableCache, u)
(cache.fsalfirst, cache.k4)
end
@cache mutable struct AB3Cache{uType, rateType} <: ABMMutableCache
@cache mutable struct AB3Cache{uType, rateType, Thread} <: ABMMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand All @@ -14,6 +14,7 @@ end
k::rateType
tmp::uType
step::Int
thread::Thread
end

@cache mutable struct AB3ConstantCache{rateType} <: OrdinaryDiffEqConstantCache
Expand All @@ -32,7 +33,7 @@ function alg_cache(alg::AB3, u, rate_prototype, ::Type{uEltypeNoUnits},
ralk2 = zero(rate_prototype)
k = zero(rate_prototype)
tmp = zero(u)
AB3Cache(u, uprev, fsalfirst, k2, k3, ralk2, k, tmp, 1)
AB3Cache(u, uprev, fsalfirst, k2, k3, ralk2, k, tmp, 1, alg.thread)
end

function alg_cache(alg::AB3, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand All @@ -44,7 +45,7 @@ function alg_cache(alg::AB3, u, rate_prototype, ::Type{uEltypeNoUnits},
AB3ConstantCache(k2, k3, 1)
end

@cache mutable struct ABM32Cache{uType, rateType} <: ABMMutableCache
@cache mutable struct ABM32Cache{uType, rateType, Thread} <: ABMMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand All @@ -54,6 +55,7 @@ end
k::rateType
tmp::uType
step::Int
thread::Thread
end

@cache mutable struct ABM32ConstantCache{rateType} <: OrdinaryDiffEqConstantCache
Expand All @@ -72,7 +74,7 @@ function alg_cache(alg::ABM32, u, rate_prototype, ::Type{uEltypeNoUnits},
ralk2 = zero(rate_prototype)
k = zero(rate_prototype)
tmp = zero(u)
ABM32Cache(u, uprev, fsalfirst, k2, k3, ralk2, k, tmp, 1)
ABM32Cache(u, uprev, fsalfirst, k2, k3, ralk2, k, tmp, 1, alg.thread)
end

function alg_cache(alg::ABM32, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand All @@ -84,7 +86,7 @@ function alg_cache(alg::ABM32, u, rate_prototype, ::Type{uEltypeNoUnits},
ABM32ConstantCache(k2, k3, 1)
end

@cache mutable struct AB4Cache{uType, rateType} <: ABMMutableCache
@cache mutable struct AB4Cache{uType, rateType, Thread} <: ABMMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand All @@ -98,6 +100,7 @@ end
t3::rateType
t4::rateType
step::Int
thread::Thread
end

@cache mutable struct AB4ConstantCache{rateType} <: OrdinaryDiffEqConstantCache
Expand All @@ -121,7 +124,7 @@ function alg_cache(alg::AB4, u, rate_prototype, ::Type{uEltypeNoUnits},
t2 = zero(rate_prototype)
t3 = zero(rate_prototype)
t4 = zero(rate_prototype)
AB4Cache(u, uprev, fsalfirst, k2, k3, k4, ralk2, k, tmp, t2, t3, t4, 1)
AB4Cache(u, uprev, fsalfirst, k2, k3, k4, ralk2, k, tmp, t2, t3, t4, 1, alg.thread)
end

function alg_cache(alg::AB4, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand All @@ -134,7 +137,7 @@ function alg_cache(alg::AB4, u, rate_prototype, ::Type{uEltypeNoUnits},
AB4ConstantCache(k2, k3, k4, 1)
end

@cache mutable struct ABM43Cache{uType, rateType} <: ABMMutableCache
@cache mutable struct ABM43Cache{uType, rateType, Thread} <: ABMMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand All @@ -151,6 +154,7 @@ end
t6::rateType
t7::rateType
step::Int
thread::Thread
end

@cache mutable struct ABM43ConstantCache{rateType} <: OrdinaryDiffEqConstantCache
Expand All @@ -177,7 +181,8 @@ function alg_cache(alg::ABM43, u, rate_prototype, ::Type{uEltypeNoUnits},
t5 = zero(rate_prototype)
t6 = zero(rate_prototype)
t7 = zero(rate_prototype)
ABM43Cache(u, uprev, fsalfirst, k2, k3, k4, ralk2, k, tmp, t2, t3, t4, t5, t6, t7, 1)
ABM43Cache(u, uprev, fsalfirst, k2, k3, k4, ralk2, k,
tmp, t2, t3, t4, t5, t6, t7, 1, alg.thread)
end

function alg_cache(alg::ABM43, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand All @@ -190,7 +195,7 @@ function alg_cache(alg::ABM43, u, rate_prototype, ::Type{uEltypeNoUnits},
ABM43ConstantCache(k2, k3, k4, 1)
end

@cache mutable struct AB5Cache{uType, rateType} <: ABMMutableCache
@cache mutable struct AB5Cache{uType, rateType, Thread} <: ABMMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand All @@ -204,6 +209,7 @@ end
t3::rateType
t4::rateType
step::Int
thread::Thread
end

@cache mutable struct AB5ConstantCache{rateType} <: OrdinaryDiffEqConstantCache
Expand All @@ -228,7 +234,7 @@ function alg_cache(alg::AB5, u, rate_prototype, ::Type{uEltypeNoUnits},
t2 = zero(rate_prototype)
t3 = zero(rate_prototype)
t4 = zero(rate_prototype)
AB5Cache(u, uprev, fsalfirst, k2, k3, k4, k5, k, tmp, t2, t3, t4, 1)
AB5Cache(u, uprev, fsalfirst, k2, k3, k4, k5, k, tmp, t2, t3, t4, 1, alg.thread)
end

function alg_cache(alg::AB5, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand All @@ -242,7 +248,7 @@ function alg_cache(alg::AB5, u, rate_prototype, ::Type{uEltypeNoUnits},
AB5ConstantCache(k2, k3, k4, k5, 1)
end

@cache mutable struct ABM54Cache{uType, rateType} <: ABMMutableCache
@cache mutable struct ABM54Cache{uType, rateType, Thread} <: ABMMutableCache
u::uType
uprev::uType
fsalfirst::rateType
Expand All @@ -260,6 +266,7 @@ end
t7::rateType
t8::rateType
step::Int
thread::Thread
end

@cache mutable struct ABM54ConstantCache{rateType} <: OrdinaryDiffEqConstantCache
Expand Down Expand Up @@ -288,7 +295,8 @@ function alg_cache(alg::ABM54, u, rate_prototype, ::Type{uEltypeNoUnits},
t6 = zero(rate_prototype)
t7 = zero(rate_prototype)
t8 = zero(rate_prototype)
ABM54Cache(u, uprev, fsalfirst, k2, k3, k4, k5, k, tmp, t2, t3, t4, t5, t6, t7, t8, 1)
ABM54Cache(u, uprev, fsalfirst, k2, k3, k4, k5, k, tmp,
t2, t3, t4, t5, t6, t7, t8, 1, alg.thread)
end

function alg_cache(alg::ABM54, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -317,7 +325,7 @@ end
end

@cache mutable struct VCAB3Cache{uType, rateType, TabType, bs3Type, tArrayType, cArrayType,
uNoUnitsType, coefType, dtArrayType} <:
uNoUnitsType, coefType, dtArrayType, Thread} <:
ABMVariableCoefficientMutableCache
u::uType
uprev::uType
Expand All @@ -337,6 +345,7 @@ end
utilde::uType
tab::TabType
step::Int
thread::Thread
end

function alg_cache(alg::VCAB3, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -395,7 +404,7 @@ function alg_cache(alg::VCAB3, u, rate_prototype, ::Type{uEltypeNoUnits},
tmp = zero(u)
utilde = zero(u)
VCAB3Cache(u, uprev, fsalfirst, bs3cache, k4, ϕstar_nm1, dts, c, g, ϕ_n, ϕstar_n, β,
order, atmp, tmp, utilde, tab, 1)
order, atmp, tmp, utilde, tab, 1, alg.thread)
end

@cache mutable struct VCAB4ConstantCache{rk4constcache, tArrayType, rArrayType, cArrayType,
Expand All @@ -413,7 +422,7 @@ end
end

@cache mutable struct VCAB4Cache{uType, rateType, rk4cacheType, tArrayType, cArrayType,
uNoUnitsType, coefType, dtArrayType} <:
uNoUnitsType, coefType, dtArrayType, Thread} <:
ABMVariableCoefficientMutableCache
u::uType
uprev::uType
Expand All @@ -432,6 +441,7 @@ end
tmp::uType
utilde::uType
step::Int
thread::Thread
end

function alg_cache(alg::VCAB4, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -489,7 +499,7 @@ function alg_cache(alg::VCAB4, u, rate_prototype, ::Type{uEltypeNoUnits},
tmp = zero(u)
utilde = zero(u)
VCAB4Cache(u, uprev, fsalfirst, rk4cache, k4, ϕstar_nm1, dts, c, g, ϕ_n, ϕstar_n, β,
order, atmp, tmp, utilde, 1)
order, atmp, tmp, utilde, 1, alg.thread)
end

# VCAB5
Expand All @@ -509,7 +519,7 @@ end
end

@cache mutable struct VCAB5Cache{uType, rateType, rk4cacheType, tArrayType, cArrayType,
uNoUnitsType, coefType, dtArrayType} <:
uNoUnitsType, coefType, dtArrayType, Thread} <:
ABMVariableCoefficientMutableCache
u::uType
uprev::uType
Expand All @@ -528,6 +538,7 @@ end
tmp::uType
utilde::uType
step::Int
thread::Thread
end

function alg_cache(alg::VCAB5, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -585,7 +596,7 @@ function alg_cache(alg::VCAB5, u, rate_prototype, ::Type{uEltypeNoUnits},
tmp = zero(u)
utilde = zero(u)
VCAB5Cache(u, uprev, fsalfirst, rk4cache, k4, ϕstar_nm1, dts, c, g, ϕ_n, ϕstar_n, β,
order, atmp, tmp, utilde, 1)
order, atmp, tmp, utilde, 1, alg.thread)
end

# VCABM3
Expand All @@ -607,7 +618,7 @@ end

@cache mutable struct VCABM3Cache{
uType, rateType, TabType, bs3Type, tArrayType, cArrayType,
uNoUnitsType, coefType, dtArrayType} <:
uNoUnitsType, coefType, dtArrayType, Thread} <:
ABMVariableCoefficientMutableCache
u::uType
uprev::uType
Expand All @@ -628,6 +639,7 @@ end
utilde::uType
tab::TabType
step::Int
thread::Thread
end

function alg_cache(alg::VCABM3, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -691,7 +703,7 @@ function alg_cache(alg::VCABM3, u, rate_prototype, ::Type{uEltypeNoUnits},
tmp = zero(u)
utilde = zero(u)
VCABM3Cache(u, uprev, fsalfirst, bs3cache, k4, ϕstar_nm1, dts, c, g, ϕ_n, ϕ_np1,
ϕstar_n, β, order, atmp, tmp, utilde, tab, 1)
ϕstar_n, β, order, atmp, tmp, utilde, tab, 1, alg.thread)
end

# VCABM4
Expand All @@ -713,7 +725,7 @@ end
end

@cache mutable struct VCABM4Cache{uType, rateType, rk4cacheType, tArrayType, cArrayType,
uNoUnitsType, coefType, dtArrayType} <:
uNoUnitsType, coefType, dtArrayType, Thread} <:
ABMVariableCoefficientMutableCache
u::uType
uprev::uType
Expand All @@ -733,6 +745,7 @@ end
tmp::uType
utilde::uType
step::Int
thread::Thread
end

function alg_cache(alg::VCABM4, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -796,7 +809,7 @@ function alg_cache(alg::VCABM4, u, rate_prototype, ::Type{uEltypeNoUnits},
tmp = zero(u)
utilde = zero(u)
VCABM4Cache(u, uprev, fsalfirst, rk4cache, k4, ϕstar_nm1, dts, c, g, ϕ_n, ϕ_np1,
ϕstar_n, β, order, atmp, tmp, utilde, 1)
ϕstar_n, β, order, atmp, tmp, utilde, 1, alg.thread)
end

# VCABM5
Expand All @@ -818,7 +831,7 @@ end
end

@cache mutable struct VCABM5Cache{uType, rateType, rk4cacheType, tArrayType, cArrayType,
uNoUnitsType, coefType, dtArrayType} <:
uNoUnitsType, coefType, dtArrayType, Thread} <:
ABMVariableCoefficientMutableCache
u::uType
uprev::uType
Expand All @@ -838,6 +851,7 @@ end
tmp::uType
utilde::uType
step::Int
thread::Thread
end

function alg_cache(alg::VCABM5, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -901,7 +915,7 @@ function alg_cache(alg::VCABM5, u, rate_prototype, ::Type{uEltypeNoUnits},
tmp = zero(u)
utilde = zero(u)
VCABM5Cache(u, uprev, fsalfirst, rk4cache, k4, ϕstar_nm1, dts, c, g, ϕ_n, ϕ_np1,
ϕstar_n, β, order, atmp, tmp, utilde, 1)
ϕstar_n, β, order, atmp, tmp, utilde, 1, alg.thread)
end

# VCABM
Expand All @@ -924,7 +938,7 @@ end
end

@cache mutable struct VCABMCache{uType, rateType, dtType, tArrayType, cArrayType,
uNoUnitsType, coefType, dtArrayType} <:
uNoUnitsType, coefType, dtArrayType, Thread} <:
ABMVariableCoefficientMutableCache
u::uType
uprev::uType
Expand Down Expand Up @@ -952,6 +966,7 @@ end
atmpm2::uNoUnitsType
atmpp1::uNoUnitsType
step::Int
thread::Thread
end

function alg_cache(alg::VCABM, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -1023,5 +1038,5 @@ function alg_cache(alg::VCABM, u, rate_prototype, ::Type{uEltypeNoUnits},
VCABMCache(
u, uprev, fsalfirst, k4, ϕstar_nm1, dts, c, g, ϕ_n, ϕ_np1, ϕstar_n, β, order,
max_order, atmp, tmp, ξ, ξ0, utilde, utildem1, utildem2, utildep1, atmpm1,
atmpm2, atmpp1, 1)
atmpm2, atmpp1, 1, alg.thread)
end
Loading

0 comments on commit fdc341a

Please sign in to comment.