Skip to content

Commit

Permalink
Merge pull request #2539 from Shreyas-Ekanathan/master
Browse files Browse the repository at this point in the history
Speed up Radau
  • Loading branch information
ChrisRackauckas authored Nov 19, 2024
2 parents 34a49c1 + b6b86de commit 52bcf16
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 67 deletions.
8 changes: 4 additions & 4 deletions lib/OrdinaryDiffEqFIRK/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,13 @@ struct AdaptiveRadau{CS, AD, F, P, FDT, ST, CJ, Tol, C1, C2, StepLimiter} <:
new_W_γdt_cutoff::C2
controller::Symbol
step_limiter!::StepLimiter
min_stages::Int
max_stages::Int
min_order::Int
max_order::Int
end

function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, min_stages = 3, max_stages = 7,
diff_type = Val{:forward}, min_order = 5, max_order = 13,
linsolve = nothing, precs = DEFAULT_PRECS,
extrapolant = :dense, fast_convergence_cutoff = 1 // 5,
new_W_γdt_cutoff = 1 // 5,
Expand All @@ -187,6 +187,6 @@ function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = Val{true}(),
fast_convergence_cutoff,
new_W_γdt_cutoff,
controller,
step_limiter!, min_stages, max_stages)
step_limiter!, min_order, max_order)
end

9 changes: 6 additions & 3 deletions lib/OrdinaryDiffEqFIRK/src/controllers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ function step_accept_controller!(integrator, controller::PredictiveController, a
cache.step = step + 1
hist_iter = hist_iter * 0.8 + iter * 0.2
cache.hist_iter = hist_iter
max_stages = (alg.max_order - 1) ÷ 4 * 2 + 1
min_stages = (alg.min_order - 1) ÷ 4 * 2 + 1
if (step > 10)
if (hist_iter < 2.6 && num_stages < alg.max_stages)
if (hist_iter < 2.6 && num_stages <= max_stages)
cache.num_stages += 2
cache.step = 1
cache.hist_iter = iter
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > alg.min_stages)
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages >= min_stages)
cache.num_stages -= 2
cache.step = 1
cache.hist_iter = iter
Expand All @@ -44,8 +46,9 @@ function step_reject_controller!(integrator, controller::PredictiveController, a
cache.step = step + 1
hist_iter = hist_iter * 0.8 + iter * 0.2
cache.hist_iter = hist_iter
min_stages = (alg.min_order - 1) ÷ 4 * 2 + 1
if (step > 10)
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > alg.min_stages)
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages >= min_stages)
cache.num_stages -= 2
cache.step = 1
cache.hist_iter = iter
Expand Down
50 changes: 40 additions & 10 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,10 @@ mutable struct RadauIIA9Cache{uType, cuType, uNoUnitsType, rateType, JType, W1Ty
tmp4::uType
tmp5::uType
tmp6::uType
tmp7::uType
tmp8::uType
tmp9::uType
tmp10::uType
atmp::uNoUnitsType
jac_config::JC
linsolve1::F1
Expand Down Expand Up @@ -440,6 +444,10 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
tmp4 = zero(u)
tmp5 = zero(u)
tmp6 = zero(u)
tmp7 = zero(u)
tmp8 = zero(u)
tmp9 = zero(u)
tmp10 = zero(u)
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1)
Expand Down Expand Up @@ -469,7 +477,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
du1, fsalfirst, k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5,
J, W1, W2, W3,
uf, tab, κ, one(uToltype), 10000,
tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config,
tmp, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, atmp, jac_config,
linsolve1, linsolve2, linsolve3, rtol, atol, dt, dt,
Convergence, alg.step_limiter!)
end
Expand Down Expand Up @@ -497,17 +505,26 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UDerivativeWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
num_stages = alg.min_stages
max = alg.max_stages

max_order = alg.max_order
min_order = alg.min_order
max = (max_order - 1) ÷ 4 * 2 + 1
min = (min_order - 1) ÷ 4 * 2 + 1
if (alg.min_order < 5)
error("min_order choice $min_order below 5 is not compatible with the algorithm")
elseif (max < min)
error("max_order $max_order is below min_order $min_order")
end
num_stages = min

tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]

i = 9
while i <= alg.max_stages
while i <= max
push!(tabs, adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), i))
i += 2
end
cont = Vector{typeof(u)}(undef, max)
for i in 1: max
for i in 1:max
cont[i] = zero(u)
end

Expand All @@ -525,6 +542,8 @@ mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType,
z::Vector{uType}
w::Vector{uType}
c_prime::Vector{tType}
αdt::Vector{tType}
βdt::Vector{tType}
dw1::uType
ubuff::uType
dw2::Vector{cuType}
Expand Down Expand Up @@ -568,8 +587,16 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
uf = UJacobianWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)

max = alg.max_stages
num_stages = alg.min_stages
max_order = alg.max_order
min_order = alg.min_order
max = (max_order - 1) ÷ 4 * 2 + 1
min = (min_order - 1) ÷ 4 * 2 + 1
if (alg.min_order < 5)
error("min_order choice $min_order below 5 is not compatible with the algorithm")
elseif (max < min)
error("max_order $max_order is below min_order $min_order")
end
num_stages = min

tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]
i = 9
Expand All @@ -583,9 +610,12 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
z = Vector{typeof(u)}(undef, max)
w = Vector{typeof(u)}(undef, max)
for i in 1 : max
z[i] = w[i] = zero(u)
z[i] = zero(u)
w[i] = zero(u)
end

αdt = [zero(t) for i in 1:max]
βdt = [zero(t) for i in 1:max]
c_prime = Vector{typeof(t)}(undef, max) #time stepping
for i in 1 : max
c_prime[i] = zero(t)
Expand Down Expand Up @@ -641,7 +671,7 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
atol = reltol isa Number ? reltol : zero(reltol)

AdaptiveRadauCache(u, uprev,
z, w, c_prime, dw1, ubuff, dw2, cubuff, dw, cont, derivatives,
z, w, c_prime, αdt, βdt, dw1, ubuff, dw2, cubuff, dw, cont, derivatives,
du1, fsalfirst, ks, k, fw,
J, W1, W2,
uf, tabs, κ, one(uToltype), 10000, tmp,
Expand Down
88 changes: 44 additions & 44 deletions lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1032,7 +1032,7 @@ end
@unpack dw1, ubuff, dw23, dw45, cubuff1, cubuff2 = cache
@unpack k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5 = cache
@unpack J, W1, W2, W3 = cache
@unpack tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config, linsolve1, linsolve2, linsolve3, rtol, atol, step_limiter! = cache
@unpack tmp, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, atmp, jac_config, linsolve1, linsolve2, linsolve3, rtol, atol, step_limiter! = cache
@unpack internalnorm, abstol, reltol, adaptive = integrator.opts
alg = unwrap_alg(integrator, true)
@unpack maxiters = alg
Expand Down Expand Up @@ -1087,30 +1087,30 @@ end
c2′ = c2 * c5′
c3′ = c3 * c5′
c4′ = c4 * c5′
z1 = @.. c1′ * (cont1 +
@.. z1 = c1′ * (cont1 +
(c1′-c4m1) * (cont2 +
(c1′ - c3m1) * (cont3 +
(c1′ - c2m1) * (cont4 + (c1′ - c1m1) * cont5))))
z2 = @.. c2′ * (cont1 +
@.. z2 = c2′ * (cont1 +
(c2′-c4m1) * (cont2 +
(c2′ - c3m1) * (cont3 +
(c2′ - c2m1) * (cont4 + (c2′ - c1m1) * cont5))))
z3 = @.. c3′ * (cont1 +
@.. z3 = c3′ * (cont1 +
(c3′-c4m1) * (cont2 +
(c3′ - c3m1) * (cont3 +
(c3′ - c2m1) * (cont4 + (c3′ - c1m1) * cont5))))
z4 = @.. c4′ * (cont1 +
@.. z4 = c4′ * (cont1 +
(c4′-c4m1) * (cont2 +
(c4′ - c3m1) * (cont3 +
(c4′ - c2m1) * (cont4 + (c4′ - c1m1) * cont5))))
z5 = @.. c5′ * (cont1 +
@.. z5 = c5′ * (cont1 +
(c5′-c4m1) * (cont2 +
(c5′ - c3m1) * (cont3 + (c5′ - c2m1) * (cont4 + (c5′ - c1m1) * cont5))))
w1 = @.. broadcast=false TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5
w2 = @.. broadcast=false TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5
w3 = @.. broadcast=false TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5
w4 = @.. broadcast=false TI41*z1+TI42*z2+TI43*z3+TI44*z4+TI45*z5
w5 = @.. broadcast=false TI51*z1+TI52*z2+TI53*z3+TI54*z4+TI55*z5
@.. w1 = TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5
@.. w2 = TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5
@.. w3 = TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5
@.. w4 = TI41*z1+TI42*z2+TI43*z3+TI44*z4+TI45*z5
@.. w5 = TI51*z1+TI52*z2+TI53*z3+TI54*z4+TI55*z5
end

# Newton iteration
Expand Down Expand Up @@ -1328,21 +1328,21 @@ end
if integrator.EEst <= oneunit(integrator.EEst)
cache.dtprev = dt
if alg.extrapolant != :constant
cache.cont1 = @.. (z4 - z5) / c4m1 # first derivative on [c4, 1]
tmp1 = @.. (z3 - z4) / c3mc4 # first derivative on [c3, c4]
cache.cont2 = @.. (tmp1 - cache.cont1) / c3m1 # second derivative on [c3, 1]
tmp2 = @.. (z2 - z3) / c2mc3 # first derivative on [c2, c3]
tmp3 = @.. (tmp2 - tmp1) / c2mc4 # second derivative on [c2, c4]
cache.cont3 = @.. (tmp3 - cache.cont2) / c2m1 # third derivative on [c2, 1]
tmp4 = @.. (z1 - z2) / c1mc2 # first derivative on [c1, c2]
tmp5 = @.. (tmp4 - tmp2) / c1mc3 # second derivative on [c1, c3]
tmp6 = @.. (tmp5 - tmp3) / c1mc4 # third derivative on [c1, c4]
cache.cont4 = @.. (tmp6 - cache.cont3) / c1m1 #fourth derivative on [c1, 1]
tmp7 = @.. z1 / c1 #first derivative on [0, c1]
tmp8 = @.. (tmp4 - tmp7) / c2 #second derivative on [0, c2]
tmp9 = @.. (tmp5 - tmp8) / c3 #third derivative on [0, c3]
tmp10 = @.. (tmp6 - tmp9) / c4 #fourth derivative on [0,c4]
cache.cont5 = @.. cache.cont4 - tmp10 #fifth derivative on [0,1]
@.. cache.cont1 = (z4 - z5) / c4m1 # first derivative on [c4, 1]
@.. tmp = (z3 - z4) / c3mc4 # first derivative on [c3, c4]
@.. cache.cont2 = (tmp - cache.cont1) / c3m1 # second derivative on [c3, 1]
@.. tmp2 = (z2 - z3) / c2mc3 # first derivative on [c2, c3]
@.. tmp3 = (tmp2 - tmp) / c2mc4 # second derivative on [c2, c4]
@.. cache.cont3 = (tmp3 - cache.cont2) / c2m1 # third derivative on [c2, 1]
@.. tmp4 = (z1 - z2) / c1mc2 # first derivative on [c1, c2]
@.. tmp5 = (tmp4 - tmp2) / c1mc3 # second derivative on [c1, c3]
@.. tmp6 = (tmp5 - tmp3) / c1mc4 # third derivative on [c1, c4]
@.. cache.cont4 = (tmp6 - cache.cont3) / c1m1 #fourth derivative on [c1, 1]
@.. tmp7 = z1 / c1 #first derivative on [0, c1]
@.. tmp8 = (tmp4 - tmp7) / c2 #second derivative on [0, c2]
@.. tmp9 = (tmp5 - tmp8) / c3 #third derivative on [0, c3]
@.. tmp10 = (tmp6 - tmp9) / c4 #fourth derivative on [0,c4]
@.. cache.cont5 = cache.cont4 - tmp10 #fifth derivative on [0,1]
end
end

Expand Down Expand Up @@ -1437,7 +1437,7 @@ end
for i in 1 : num_stages
z[i] = f(uprev + z[i], p, t + c[i] * dt)
end
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 5)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, num_stages)

#fw = TI * ff
fw = Vector{typeof(u)}(undef, num_stages)
Expand Down Expand Up @@ -1598,7 +1598,7 @@ end
@unpack num_stages, tabs = cache
tab = tabs[(num_stages - 1) ÷ 2]
@unpack T, TI, γ, α, β, c, e = tab
@unpack κ, cont, derivatives, z, w, c_prime = cache
@unpack κ, cont, derivatives, z, w, c_prime, αdt, βdt= cache
@unpack dw1, ubuff, dw2, cubuff, dw = cache
@unpack ks, k, fw, J, W1, W2 = cache
@unpack tmp, atmp, jac_config, linsolve1, linsolve2, rtol, atol, step_limiter! = cache
Expand All @@ -1608,13 +1608,18 @@ end
mass_matrix = integrator.f.mass_matrix

# precalculations
γdt, αdt, βdt = γ / dt, α ./ dt, β ./ dt
γdt = γ / dt
for i in 1 : (num_stages - 1) ÷ 2
αdt[i] = α[i]/dt
βdt[i] = β[i]/dt
end

(new_jac = do_newJ(integrator, alg, cache, repeat_step)) &&
(calc_J!(J, integrator, cache); cache.W_γdt = dt)
if (new_W = do_newW(integrator, alg, new_jac, cache.W_γdt))
@inbounds for II in CartesianIndices(J)
W1[II] = -γdt * mass_matrix[Tuple(II)...] + J[II]
for i in 1 :(num_stages - 1) ÷ 2
for i in 1 : (num_stages - 1) ÷ 2
W2[i][II] = -(αdt[i] + βdt[i] * im) * mass_matrix[Tuple(II)...] + J[II]
end
end
Expand Down Expand Up @@ -1668,7 +1673,7 @@ end
@.. tmp = uprev + z[i]
f(ks[i], tmp, p, t + c[i] * dt)
end
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 5)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, num_stages)

#mul!(fw, TI, ks)
for i in 1:num_stages
Expand All @@ -1695,15 +1700,12 @@ end
@.. ubuff = fw[1] - γdt * Mw[1]
needfactor = iter == 1 && new_W

linsolve1 = cache.linsolve1
if needfactor
linres = dolinsolve(integrator, linsolve1; A = W1, b = _vec(ubuff), linu = _vec(dw1))
cache.linsolve1 = dolinsolve(integrator, linsolve1; A = W1, b = _vec(ubuff), linu = _vec(dw1)).cache
else
linres = dolinsolve(integrator, linsolve1; A = nothing, b = _vec(ubuff), linu = _vec(dw1))
cache.linsolve1 = dolinsolve(integrator, linsolve1; A = nothing, b = _vec(ubuff), linu = _vec(dw1)).cache
end

cache.linsolve1 = linres.cache

for i in 1 :(num_stages - 1) ÷ 2
@.. cubuff[i]=complex(
fw[2 * i] - αdt[i] * Mw[2 * i] + βdt[i] * Mw[2 * i + 1], fw[2 * i + 1] - βdt[i] * Mw[2 * i] - αdt[i] * Mw[2 * i + 1])
Expand Down Expand Up @@ -1750,12 +1752,12 @@ end
# transform `w` to `z`
#mul!(z, T, w)
for i in 1:num_stages - 1
z[i] = zero(u)
@.. z[i] = zero(u)
for j in 1:num_stages
@.. z[i] += T[i,j] * w[j]
end
end
z[num_stages] = T[num_stages, 1] * w[1]
@.. z[num_stages] = T[num_stages, 1] * w[1]
i = 2
while i < num_stages
@.. z[num_stages] += w[i]
Expand Down Expand Up @@ -1796,9 +1798,8 @@ end
@.. broadcast=false ubuff=integrator.fsalfirst + tmp

if alg.smooth_est
linres = dolinsolve(integrator, linres.cache; b = _vec(ubuff),
linu = _vec(utilde))
cache.linsolve1 = linres.cache
cache.linsolve1 = dolinsolve(integrator, linsolve1; b = _vec(ubuff),
linu = _vec(utilde)).cache
integrator.stats.nsolve += 1
end

Expand All @@ -1816,9 +1817,8 @@ end
@.. broadcast=false ubuff=fsallast + tmp

if alg.smooth_est
linres = dolinsolve(integrator, linres.cache; b = _vec(ubuff),
linu = _vec(utilde))
cache.linsolve1 = linres.cache
cache.linsolve1 = dolinsolve(integrator, linsolve1; b = _vec(ubuff),
linu = _vec(utilde)).cache
integrator.stats.nsolve += 1
end

Expand Down
6 changes: 3 additions & 3 deletions lib/OrdinaryDiffEqFIRK/test/ode_firk_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ sim21 = test_convergence(1 ./ 2 .^ (2.5:-1:0.5), prob_ode_2Dlinear, RadauIIA9())
prob_ode_linear_big = remake(prob_ode_linear, u0 = big.(prob_ode_linear.u0), tspan = big.(prob_ode_linear.tspan))
prob_ode_2Dlinear_big = remake(prob_ode_2Dlinear, u0 = big.(prob_ode_2Dlinear.u0), tspan = big.(prob_ode_2Dlinear.tspan))

for i in [3, 5, 7], prob in [prob_ode_linear_big, prob_ode_2Dlinear_big]
for i in [5, 9, 13], prob in [prob_ode_linear_big, prob_ode_2Dlinear_big]
dts = 1 ./ 2 .^ (4.25:-1:0.25)
sim21 = test_convergence(dts, prob, AdaptiveRadau(min_stages = i, max_stages = i))
@test sim21.𝒪est[:final] (2 * i - 1) atol=testTol
sim21 = test_convergence(dts, prob, AdaptiveRadau(min_order = i, max_order = i))
@test sim21.𝒪est[:final] i atol=testTol
end

# test adaptivity
Expand Down
6 changes: 3 additions & 3 deletions lib/OrdinaryDiffEqFIRKGenerator/test/ode_firk_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ testTol = 0.5
prob_ode_linear_big = remake(prob_ode_linear, u0 = big.(prob_ode_linear.u0), tspan = big.(prob_ode_linear.tspan))
prob_ode_2Dlinear_big = remake(prob_ode_2Dlinear, u0 = big.(prob_ode_2Dlinear.u0), tspan = big.(prob_ode_2Dlinear.tspan))

for i in [9], prob in [prob_ode_linear_big, prob_ode_2Dlinear_big]
for i in [17, 21], prob in [prob_ode_linear_big, prob_ode_2Dlinear_big]
dts = 1 ./ 2 .^ (4.25:-1:0.25)
sim21 = test_convergence(dts, prob, AdaptiveRadau(min_stages = i, max_stages = i))
@test sim21.𝒪est[:final] (2 * i - 1) atol=testTol
sim21 = test_convergence(dts, prob, AdaptiveRadau(min_order = i, max_order = i))
@test sim21.𝒪est[:final] i atol=testTol
end

0 comments on commit 52bcf16

Please sign in to comment.