Skip to content

Commit

Permalink
clean up radau tableau generation
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Nov 15, 2024
1 parent 26a5ad6 commit 2f0e93f
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 362 deletions.
11 changes: 6 additions & 5 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -497,13 +497,14 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UDerivativeWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
tTolType = constvalue(tTypeNoUnits)
num_stages = alg.min_stages
max = alg.max_stages
tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]
tabs = [RadauIIATableau(uToltype, tTolType, 3), RadauIIATableau(uToltype, tTolType, 5), RadauIIATableau(uToltype, tTolType, 7)]

i = 9
while i <= alg.max_stages
push!(tabs, adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), i))
push!(tabs, RadauIIATableau(uToltype, tTolType, i))
i += 2
end
cont = Vector{typeof(u)}(undef, max)
Expand Down Expand Up @@ -609,7 +610,7 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
fsalfirst = zero(rate_prototype)
fw = [zero(rate_prototype) for i in 1 : max]
ks = [zero(rate_prototype) for i in 1 : max]

k = ks[1]

J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
Expand Down Expand Up @@ -641,7 +642,7 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
atol = reltol isa Number ? reltol : zero(reltol)

AdaptiveRadauCache(u, uprev,
z, w, c_prime, dw1, ubuff, dw2, cubuff, dw, cont, derivatives,
z, w, c_prime, dw1, ubuff, dw2, cubuff, dw, cont, derivatives,
du1, fsalfirst, ks, k, fw,
J, W1, W2,
uf, tabs, κ, one(uToltype), 10000, tmp,
Expand Down
24 changes: 12 additions & 12 deletions lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ end

@muladd function perform_step!(integrator, cache::RadauIIA3ConstantCache)
@unpack t, dt, uprev, u, f, p = integrator
@unpack T11, T12, T21, T22, TI11, TI12, TI21, TI22 = cache.tab
@unpack T11, T12, T21, TI12, TI21, TI22 = cache.tab
@unpack c1, c2, α, β, e1, e2 = cache.tab
@unpack κ, cont1, cont2 = cache
@unpack internalnorm, abstol, reltol, adaptive = integrator.opts
Expand Down Expand Up @@ -153,7 +153,7 @@ end
ff2 = f(uprev + z2, p, t + c2 * dt)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 2)

fw1 = @. TI11 * ff1 + TI12 * ff2
fw1 = @. TI12 * ff2 #TI11 = 0
fw2 = @. TI21 * ff1 + TI22 * ff2

if mass_matrix isa UniformScaling
Expand Down Expand Up @@ -193,7 +193,7 @@ end

# transform `w` to `z`
z1 = @. T11 * w1 + T12 * w2
z2 = @. T21 * w1 + T22 * w2
z2 = @. T21 * w1 # T22 = 0

# check stopping criterion
iter > 1 &&= θ / (1 - θ))
Expand Down Expand Up @@ -226,7 +226,7 @@ end

@muladd function perform_step!(integrator, cache::RadauIIA3Cache, repeat_step = false)
@unpack t, dt, uprev, u, f, p, fsallast, fsalfirst = integrator
@unpack T11, T12, T21, T22, TI11, TI12, TI21, TI22 = cache.tab
@unpack T11, T12, T21, TI12, TI21, TI22 = cache.tab
@unpack c1, c2, α, β, e1, e2 = cache.tab
@unpack κ, cont1, cont2 = cache
@unpack z1, z2, w1, w2,
Expand Down Expand Up @@ -273,7 +273,7 @@ end
f(k2, tmp, p, t + c2 * dt)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 2)

@. fw1 = TI11 * fsallast + TI12 * k2
@. fw1 = TI12 * k2 # TI11=0
@. fw2 = TI21 * fsallast + TI22 * k2

if mass_matrix === I
Expand Down Expand Up @@ -332,7 +332,7 @@ end

# transform `w` to `z`
@. z1 = T11 * w1 + T12 * w2
@. z2 = T21 * w1 + T22 * w2
@. z2 = T21 * w1 #T22 = 0

# check stopping criterion
iter > 1 &&= θ / (1 - θ))
Expand Down Expand Up @@ -1493,7 +1493,7 @@ end
break
end
end

for i in 1 : num_stages
w[i] = @.. w[i] - z[i]
end
Expand All @@ -1513,7 +1513,7 @@ end
i += 2
end


# check stopping criterion
iter > 1 &&= θ / (1 - θ))
if η * ndw < κ && (iter > 1 || iszero(ndw) || !iszero(integrator.success_iter))
Expand All @@ -1534,7 +1534,7 @@ end
cache.iter = iter

u = @.. uprev + z[num_stages]

if adaptive
tmp = 0
for i in 1 : num_stages
Expand Down Expand Up @@ -1638,7 +1638,7 @@ end
@.. z[i] = cont[num_stages] * (c_prime[i] - c[1] + 1) + cont[num_stages - 1]
j = num_stages - 2
while j > 0
@.. z[i] *= (c_prime[i] - c[num_stages - j] + 1)
@.. z[i] *= (c_prime[i] - c[num_stages - j] + 1)
@.. z[i] += cont[j]
j = j - 1
end
Expand Down Expand Up @@ -1682,7 +1682,7 @@ end
Mw = w
elseif mass_matrix isa UniformScaling
for i in 1 : num_stages
mul!(z[i], mass_matrix.λ, w[i])
mul!(z[i], mass_matrix.λ, w[i])
end
Mw = z
else
Expand Down Expand Up @@ -1784,7 +1784,7 @@ end
@.. broadcast=false u=uprev + z[num_stages]

step_limiter!(u, integrator, p, t + dt)

if adaptive
utilde = w[2]
@.. tmp = 0
Expand Down
Loading

0 comments on commit 2f0e93f

Please sign in to comment.