Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up Radau #2539

Merged
merged 8 commits into from
Nov 19, 2024
8 changes: 4 additions & 4 deletions lib/OrdinaryDiffEqFIRK/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,13 @@ struct AdaptiveRadau{CS, AD, F, P, FDT, ST, CJ, Tol, C1, C2, StepLimiter} <:
new_W_γdt_cutoff::C2
controller::Symbol
step_limiter!::StepLimiter
min_stages::Int
max_stages::Int
min_order::Int
max_order::Int
end

function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, min_stages = 3, max_stages = 7,
diff_type = Val{:forward}, min_order = 5, max_order = 13,
linsolve = nothing, precs = DEFAULT_PRECS,
extrapolant = :dense, fast_convergence_cutoff = 1 // 5,
new_W_γdt_cutoff = 1 // 5,
Expand All @@ -187,6 +187,6 @@ function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = Val{true}(),
fast_convergence_cutoff,
new_W_γdt_cutoff,
controller,
step_limiter!, min_stages, max_stages)
step_limiter!, min_order, max_order)
end

9 changes: 6 additions & 3 deletions lib/OrdinaryDiffEqFIRK/src/controllers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ function step_accept_controller!(integrator, controller::PredictiveController, a
cache.step = step + 1
hist_iter = hist_iter * 0.8 + iter * 0.2
cache.hist_iter = hist_iter
max_stages = (alg.max_order - 1) ÷ 4 * 2 + 1
min_stages = (alg.min_order - 1) ÷ 4 * 2 + 1
if (step > 10)
if (hist_iter < 2.6 && num_stages < alg.max_stages)
if (hist_iter < 2.6 && num_stages <= max_stages)
cache.num_stages += 2
cache.step = 1
cache.hist_iter = iter
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > alg.min_stages)
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages >= min_stages)
cache.num_stages -= 2
cache.step = 1
cache.hist_iter = iter
Expand All @@ -44,8 +46,9 @@ function step_reject_controller!(integrator, controller::PredictiveController, a
cache.step = step + 1
hist_iter = hist_iter * 0.8 + iter * 0.2
cache.hist_iter = hist_iter
min_stages = (alg.min_order - 1) ÷ 4 * 2 + 1
if (step > 10)
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > alg.min_stages)
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages >= min_stages)
cache.num_stages -= 2
cache.step = 1
cache.hist_iter = iter
Expand Down
50 changes: 40 additions & 10 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,10 @@ mutable struct RadauIIA9Cache{uType, cuType, uNoUnitsType, rateType, JType, W1Ty
tmp4::uType
tmp5::uType
tmp6::uType
tmp7::uType
tmp8::uType
tmp9::uType
tmp10::uType
atmp::uNoUnitsType
jac_config::JC
linsolve1::F1
Expand Down Expand Up @@ -440,6 +444,10 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
tmp4 = zero(u)
tmp5 = zero(u)
tmp6 = zero(u)
tmp7 = zero(u)
tmp8 = zero(u)
tmp9 = zero(u)
tmp10 = zero(u)
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1)
Expand Down Expand Up @@ -469,7 +477,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
du1, fsalfirst, k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5,
J, W1, W2, W3,
uf, tab, κ, one(uToltype), 10000,
tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config,
tmp, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, atmp, jac_config,
linsolve1, linsolve2, linsolve3, rtol, atol, dt, dt,
Convergence, alg.step_limiter!)
end
Expand Down Expand Up @@ -497,17 +505,26 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UDerivativeWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
num_stages = alg.min_stages
max = alg.max_stages

max_order = alg.max_order
min_order = alg.min_order
max = (max_order - 1) ÷ 4 * 2 + 1
min = (min_order - 1) ÷ 4 * 2 + 1
if (alg.min_order < 5)
error("min_order choice $min_order below 5 is not compatible with the algorithm")
elseif (max < min)
error("max_order $max_order is below min_order $min_order")
end
num_stages = min

tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]

i = 9
while i <= alg.max_stages
while i <= max
push!(tabs, adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), i))
i += 2
end
cont = Vector{typeof(u)}(undef, max)
for i in 1: max
for i in 1:max
cont[i] = zero(u)
end

Expand All @@ -525,6 +542,8 @@ mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType,
z::Vector{uType}
w::Vector{uType}
c_prime::Vector{tType}
αdt::Vector{tType}
βdt::Vector{tType}
dw1::uType
ubuff::uType
dw2::Vector{cuType}
Expand Down Expand Up @@ -568,8 +587,16 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
uf = UJacobianWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)

max = alg.max_stages
num_stages = alg.min_stages
max_order = alg.max_order
min_order = alg.min_order
max = (max_order - 1) ÷ 4 * 2 + 1
min = (min_order - 1) ÷ 4 * 2 + 1
if (alg.min_order < 5)
error("min_order choice $min_order below 5 is not compatible with the algorithm")
elseif (max < min)
error("max_order $max_order is below min_order $min_order")
end
num_stages = min

tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]
i = 9
Expand All @@ -583,9 +610,12 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
z = Vector{typeof(u)}(undef, max)
w = Vector{typeof(u)}(undef, max)
for i in 1 : max
z[i] = w[i] = zero(u)
z[i] = zero(u)
w[i] = zero(u)
end

αdt = [zero(t) for i in 1:max]
βdt = [zero(t) for i in 1:max]
c_prime = Vector{typeof(t)}(undef, max) #time stepping
for i in 1 : max
c_prime[i] = zero(t)
Expand Down Expand Up @@ -641,7 +671,7 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
atol = reltol isa Number ? reltol : zero(reltol)

AdaptiveRadauCache(u, uprev,
z, w, c_prime, dw1, ubuff, dw2, cubuff, dw, cont, derivatives,
z, w, c_prime, αdt, βdt, dw1, ubuff, dw2, cubuff, dw, cont, derivatives,
du1, fsalfirst, ks, k, fw,
J, W1, W2,
uf, tabs, κ, one(uToltype), 10000, tmp,
Expand Down
88 changes: 44 additions & 44 deletions lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1032,7 +1032,7 @@ end
@unpack dw1, ubuff, dw23, dw45, cubuff1, cubuff2 = cache
@unpack k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5 = cache
@unpack J, W1, W2, W3 = cache
@unpack tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config, linsolve1, linsolve2, linsolve3, rtol, atol, step_limiter! = cache
@unpack tmp, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, atmp, jac_config, linsolve1, linsolve2, linsolve3, rtol, atol, step_limiter! = cache
@unpack internalnorm, abstol, reltol, adaptive = integrator.opts
alg = unwrap_alg(integrator, true)
@unpack maxiters = alg
Expand Down Expand Up @@ -1087,30 +1087,30 @@ end
c2′ = c2 * c5′
c3′ = c3 * c5′
c4′ = c4 * c5′
z1 = @.. c1′ * (cont1 +
@.. z1 = c1′ * (cont1 +
(c1′-c4m1) * (cont2 +
(c1′ - c3m1) * (cont3 +
(c1′ - c2m1) * (cont4 + (c1′ - c1m1) * cont5))))
z2 = @.. c2′ * (cont1 +
@.. z2 = c2′ * (cont1 +
(c2′-c4m1) * (cont2 +
(c2′ - c3m1) * (cont3 +
(c2′ - c2m1) * (cont4 + (c2′ - c1m1) * cont5))))
z3 = @.. c3′ * (cont1 +
@.. z3 = c3′ * (cont1 +
(c3′-c4m1) * (cont2 +
(c3′ - c3m1) * (cont3 +
(c3′ - c2m1) * (cont4 + (c3′ - c1m1) * cont5))))
z4 = @.. c4′ * (cont1 +
@.. z4 = c4′ * (cont1 +
(c4′-c4m1) * (cont2 +
(c4′ - c3m1) * (cont3 +
(c4′ - c2m1) * (cont4 + (c4′ - c1m1) * cont5))))
z5 = @.. c5′ * (cont1 +
@.. z5 = c5′ * (cont1 +
(c5′-c4m1) * (cont2 +
(c5′ - c3m1) * (cont3 + (c5′ - c2m1) * (cont4 + (c5′ - c1m1) * cont5))))
w1 = @.. broadcast=false TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5
w2 = @.. broadcast=false TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5
w3 = @.. broadcast=false TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5
w4 = @.. broadcast=false TI41*z1+TI42*z2+TI43*z3+TI44*z4+TI45*z5
w5 = @.. broadcast=false TI51*z1+TI52*z2+TI53*z3+TI54*z4+TI55*z5
@.. w1 = TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5
@.. w2 = TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5
@.. w3 = TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5
@.. w4 = TI41*z1+TI42*z2+TI43*z3+TI44*z4+TI45*z5
@.. w5 = TI51*z1+TI52*z2+TI53*z3+TI54*z4+TI55*z5
end

# Newton iteration
Expand Down Expand Up @@ -1328,21 +1328,21 @@ end
if integrator.EEst <= oneunit(integrator.EEst)
cache.dtprev = dt
if alg.extrapolant != :constant
cache.cont1 = @.. (z4 - z5) / c4m1 # first derivative on [c4, 1]
tmp1 = @.. (z3 - z4) / c3mc4 # first derivative on [c3, c4]
cache.cont2 = @.. (tmp1 - cache.cont1) / c3m1 # second derivative on [c3, 1]
tmp2 = @.. (z2 - z3) / c2mc3 # first derivative on [c2, c3]
tmp3 = @.. (tmp2 - tmp1) / c2mc4 # second derivative on [c2, c4]
cache.cont3 = @.. (tmp3 - cache.cont2) / c2m1 # third derivative on [c2, 1]
tmp4 = @.. (z1 - z2) / c1mc2 # first derivative on [c1, c2]
tmp5 = @.. (tmp4 - tmp2) / c1mc3 # second derivative on [c1, c3]
tmp6 = @.. (tmp5 - tmp3) / c1mc4 # third derivative on [c1, c4]
cache.cont4 = @.. (tmp6 - cache.cont3) / c1m1 #fourth derivative on [c1, 1]
tmp7 = @.. z1 / c1 #first derivative on [0, c1]
tmp8 = @.. (tmp4 - tmp7) / c2 #second derivative on [0, c2]
tmp9 = @.. (tmp5 - tmp8) / c3 #third derivative on [0, c3]
tmp10 = @.. (tmp6 - tmp9) / c4 #fourth derivative on [0,c4]
cache.cont5 = @.. cache.cont4 - tmp10 #fifth derivative on [0,1]
@.. cache.cont1 = (z4 - z5) / c4m1 # first derivative on [c4, 1]
@.. tmp = (z3 - z4) / c3mc4 # first derivative on [c3, c4]
@.. cache.cont2 = (tmp - cache.cont1) / c3m1 # second derivative on [c3, 1]
@.. tmp2 = (z2 - z3) / c2mc3 # first derivative on [c2, c3]
@.. tmp3 = (tmp2 - tmp) / c2mc4 # second derivative on [c2, c4]
@.. cache.cont3 = (tmp3 - cache.cont2) / c2m1 # third derivative on [c2, 1]
@.. tmp4 = (z1 - z2) / c1mc2 # first derivative on [c1, c2]
@.. tmp5 = (tmp4 - tmp2) / c1mc3 # second derivative on [c1, c3]
@.. tmp6 = (tmp5 - tmp3) / c1mc4 # third derivative on [c1, c4]
@.. cache.cont4 = (tmp6 - cache.cont3) / c1m1 #fourth derivative on [c1, 1]
@.. tmp7 = z1 / c1 #first derivative on [0, c1]
@.. tmp8 = (tmp4 - tmp7) / c2 #second derivative on [0, c2]
@.. tmp9 = (tmp5 - tmp8) / c3 #third derivative on [0, c3]
@.. tmp10 = (tmp6 - tmp9) / c4 #fourth derivative on [0,c4]
@.. cache.cont5 = cache.cont4 - tmp10 #fifth derivative on [0,1]
end
end

Expand Down Expand Up @@ -1437,7 +1437,7 @@ end
for i in 1 : num_stages
z[i] = f(uprev + z[i], p, t + c[i] * dt)
end
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 5)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, num_stages)

#fw = TI * ff
fw = Vector{typeof(u)}(undef, num_stages)
Expand Down Expand Up @@ -1598,7 +1598,7 @@ end
@unpack num_stages, tabs = cache
tab = tabs[(num_stages - 1) ÷ 2]
@unpack T, TI, γ, α, β, c, e = tab
@unpack κ, cont, derivatives, z, w, c_prime = cache
@unpack κ, cont, derivatives, z, w, c_prime, αdt, βdt= cache
@unpack dw1, ubuff, dw2, cubuff, dw = cache
@unpack ks, k, fw, J, W1, W2 = cache
@unpack tmp, atmp, jac_config, linsolve1, linsolve2, rtol, atol, step_limiter! = cache
Expand All @@ -1608,13 +1608,18 @@ end
mass_matrix = integrator.f.mass_matrix

# precalculations
γdt, αdt, βdt = γ / dt, α ./ dt, β ./ dt
γdt = γ / dt
for i in 1 : (num_stages - 1) ÷ 2
αdt[i] = α[i]/dt
βdt[i] = β[i]/dt
end

(new_jac = do_newJ(integrator, alg, cache, repeat_step)) &&
(calc_J!(J, integrator, cache); cache.W_γdt = dt)
if (new_W = do_newW(integrator, alg, new_jac, cache.W_γdt))
@inbounds for II in CartesianIndices(J)
W1[II] = -γdt * mass_matrix[Tuple(II)...] + J[II]
for i in 1 :(num_stages - 1) ÷ 2
for i in 1 : (num_stages - 1) ÷ 2
W2[i][II] = -(αdt[i] + βdt[i] * im) * mass_matrix[Tuple(II)...] + J[II]
end
end
Expand Down Expand Up @@ -1668,7 +1673,7 @@ end
@.. tmp = uprev + z[i]
f(ks[i], tmp, p, t + c[i] * dt)
end
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 5)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, num_stages)

#mul!(fw, TI, ks)
for i in 1:num_stages
Expand All @@ -1695,15 +1700,12 @@ end
@.. ubuff = fw[1] - γdt * Mw[1]
needfactor = iter == 1 && new_W

linsolve1 = cache.linsolve1
if needfactor
linres = dolinsolve(integrator, linsolve1; A = W1, b = _vec(ubuff), linu = _vec(dw1))
cache.linsolve1 = dolinsolve(integrator, linsolve1; A = W1, b = _vec(ubuff), linu = _vec(dw1)).cache
else
linres = dolinsolve(integrator, linsolve1; A = nothing, b = _vec(ubuff), linu = _vec(dw1))
cache.linsolve1 = dolinsolve(integrator, linsolve1; A = nothing, b = _vec(ubuff), linu = _vec(dw1)).cache
end

cache.linsolve1 = linres.cache

for i in 1 :(num_stages - 1) ÷ 2
@.. cubuff[i]=complex(
fw[2 * i] - αdt[i] * Mw[2 * i] + βdt[i] * Mw[2 * i + 1], fw[2 * i + 1] - βdt[i] * Mw[2 * i] - αdt[i] * Mw[2 * i + 1])
Expand Down Expand Up @@ -1750,12 +1752,12 @@ end
# transform `w` to `z`
#mul!(z, T, w)
for i in 1:num_stages - 1
z[i] = zero(u)
@.. z[i] = zero(u)
for j in 1:num_stages
@.. z[i] += T[i,j] * w[j]
end
end
z[num_stages] = T[num_stages, 1] * w[1]
@.. z[num_stages] = T[num_stages, 1] * w[1]
i = 2
while i < num_stages
@.. z[num_stages] += w[i]
Expand Down Expand Up @@ -1796,9 +1798,8 @@ end
@.. broadcast=false ubuff=integrator.fsalfirst + tmp

if alg.smooth_est
linres = dolinsolve(integrator, linres.cache; b = _vec(ubuff),
linu = _vec(utilde))
cache.linsolve1 = linres.cache
cache.linsolve1 = dolinsolve(integrator, linsolve1; b = _vec(ubuff),
linu = _vec(utilde)).cache
integrator.stats.nsolve += 1
end

Expand All @@ -1816,9 +1817,8 @@ end
@.. broadcast=false ubuff=fsallast + tmp

if alg.smooth_est
linres = dolinsolve(integrator, linres.cache; b = _vec(ubuff),
linu = _vec(utilde))
cache.linsolve1 = linres.cache
cache.linsolve1 = dolinsolve(integrator, linsolve1; b = _vec(ubuff),
linu = _vec(utilde)).cache
integrator.stats.nsolve += 1
end

Expand Down
6 changes: 3 additions & 3 deletions lib/OrdinaryDiffEqFIRK/test/ode_firk_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ sim21 = test_convergence(1 ./ 2 .^ (2.5:-1:0.5), prob_ode_2Dlinear, RadauIIA9())
prob_ode_linear_big = remake(prob_ode_linear, u0 = big.(prob_ode_linear.u0), tspan = big.(prob_ode_linear.tspan))
prob_ode_2Dlinear_big = remake(prob_ode_2Dlinear, u0 = big.(prob_ode_2Dlinear.u0), tspan = big.(prob_ode_2Dlinear.tspan))

for i in [3, 5, 7], prob in [prob_ode_linear_big, prob_ode_2Dlinear_big]
for i in [5, 9, 13], prob in [prob_ode_linear_big, prob_ode_2Dlinear_big]
dts = 1 ./ 2 .^ (4.25:-1:0.25)
sim21 = test_convergence(dts, prob, AdaptiveRadau(min_stages = i, max_stages = i))
@test sim21.𝒪est[:final]≈ (2 * i - 1) atol=testTol
sim21 = test_convergence(dts, prob, AdaptiveRadau(min_order = i, max_order = i))
@test sim21.𝒪est[:final]≈ i atol=testTol
end

# test adaptivity
Expand Down
6 changes: 3 additions & 3 deletions lib/OrdinaryDiffEqFIRKGenerator/test/ode_firk_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ testTol = 0.5
prob_ode_linear_big = remake(prob_ode_linear, u0 = big.(prob_ode_linear.u0), tspan = big.(prob_ode_linear.tspan))
prob_ode_2Dlinear_big = remake(prob_ode_2Dlinear, u0 = big.(prob_ode_2Dlinear.u0), tspan = big.(prob_ode_2Dlinear.tspan))

for i in [9], prob in [prob_ode_linear_big, prob_ode_2Dlinear_big]
for i in [17, 21], prob in [prob_ode_linear_big, prob_ode_2Dlinear_big]
dts = 1 ./ 2 .^ (4.25:-1:0.25)
sim21 = test_convergence(dts, prob, AdaptiveRadau(min_stages = i, max_stages = i))
@test sim21.𝒪est[:final]≈ (2 * i - 1) atol=testTol
sim21 = test_convergence(dts, prob, AdaptiveRadau(min_order = i, max_order = i))
@test sim21.𝒪est[:final]≈ i atol=testTol
end
Loading