Skip to content

Commit

Permalink
Remove use of OrdinaryDiffEq
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Oct 3, 2023
1 parent df6bb7f commit 7947a38
Show file tree
Hide file tree
Showing 8 changed files with 5 additions and 44 deletions.
2 changes: 1 addition & 1 deletion config/default_configs/default_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ max_newton_iters_ode:
help: "Maximum number of Newton's method iterations (only for ODE algorithms that use Newton's method)"
value: 1
ode_algo:
help: "ODE algorithm [`ARS343` (default), `SSP333`, `IMKG343a`, `ODE.Euler`, `ODE.IMEXEuler`, `ODE.Rosenbrock23`, etc.]"
help: "ODE algorithm [`ARS343` (default), `SSP333`, `IMKG343a`, etc.]"
value: "ARS343"
krylov_rtol:
help: "Relative tolerance of the Krylov method (only for ClimaTimeSteppers.jl; only used if `use_krylov_method` is `true`)"
Expand Down
2 changes: 1 addition & 1 deletion src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ NVTX.@annotate function turb_conv_affect_filter!(integrator)
TC.affect_filter!(edmf, grid, state, tc_params, t)
end

# We're lying to OrdinaryDiffEq.jl, in order to avoid
# We're lying to SciMLBase.jl, in order to avoid
# paying for an additional `∑tendencies!` call, which is required
# to support supplying a continuous representation of the
# solution.
Expand Down
1 change: 0 additions & 1 deletion src/dycore_equations_deprecated/sgs_flux_tendencies.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using LinearAlgebra
import OrdinaryDiffEq as ODE
import Logging
import TerminalLoggers
import LinearAlgebra as LA
Expand Down
1 change: 0 additions & 1 deletion src/initial_conditions/InitialConditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import ..Parameters as CAP
import ..TurbulenceConvection as TC
import Thermodynamics as TD
import AtmosphericProfilesLibrary as APL
import OrdinaryDiffEq as ODE
import SciMLBase
import Dierckx

Expand Down
1 change: 0 additions & 1 deletion src/parameterized_tendencies/radiation/radiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import ClimaComms
import ClimaCore: Device, DataLayouts, Geometry, Spaces, Fields, Operators
import OrdinaryDiffEq as ODE
import Insolation
import Thermodynamics as TD
import .Parameters as CAP
Expand Down
2 changes: 1 addition & 1 deletion src/prognostic_equations/implicit/schur_complement_W.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ Note: The matrix S = A𝕄ρ Aρ𝕄 + A𝕄𝔼 A𝔼𝕄 + A𝕄𝕄 - I is th
the large -I block in A.
=#

# Function required by OrdinaryDiffEq.jl
# Function required by ClimaTimeSteppers.jl
linsolve!(::Type{Val{:init}}, f, u0; kwargs...) = _linsolve!
_linsolve!(x, A, b, update_matrix = false; kwargs...) =
LinearAlgebra.ldiv!(x, A, b)
Expand Down
1 change: 0 additions & 1 deletion src/solver/solve.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import ClimaTimeSteppers as CTS
import OrdinaryDiffEq as ODE

struct EfficiencyStats{TS <: Tuple, WT}
tspan::TS
Expand Down
39 changes: 2 additions & 37 deletions src/solver/type_getters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import ClimaAtmos.RRTMGPInterface as RRTMGPI
import ClimaAtmos as CA
import LinearAlgebra
import ClimaCore.Fields
import OrdinaryDiffEq as ODE
import ClimaTimeSteppers as CTS
import DiffEqCallbacks as DECB

Expand Down Expand Up @@ -340,29 +339,13 @@ is_explicit_CTS_algo_type(alg_or_tableau) =
is_imex_CTS_algo_type(alg_or_tableau) =
alg_or_tableau <: CTS.IMEXARKAlgorithmName

is_implicit_type(::typeof(ODE.IMEXEuler)) = true
is_implicit_type(alg_or_tableau) =
alg_or_tableau <: Union{
ODE.OrdinaryDiffEqImplicitAlgorithm,
ODE.OrdinaryDiffEqAdaptiveImplicitAlgorithm,
} || is_imex_CTS_algo_type(alg_or_tableau)

is_ordinary_diffeq_newton(::typeof(ODE.IMEXEuler)) = true
is_ordinary_diffeq_newton(alg_or_tableau) =
alg_or_tableau <: Union{
ODE.OrdinaryDiffEqNewtonAlgorithm,
ODE.OrdinaryDiffEqNewtonAdaptiveAlgorithm,
}
is_implicit_type(alg_or_tableau) = is_imex_CTS_algo_type(alg_or_tableau)

is_imex_CTS_algo(::CTS.IMEXAlgorithm) = true
is_imex_CTS_algo(::SciMLBase.AbstractODEAlgorithm) = false

is_implicit(::ODE.OrdinaryDiffEqImplicitAlgorithm) = true
is_implicit(::ODE.OrdinaryDiffEqAdaptiveImplicitAlgorithm) = true
is_implicit(ode_algo) = is_imex_CTS_algo(ode_algo)

is_rosenbrock(::ODE.Rosenbrock23) = true
is_rosenbrock(::ODE.Rosenbrock32) = true
is_rosenbrock(::SciMLBase.AbstractODEAlgorithm) = false
use_transform(ode_algo) =
!(is_imex_CTS_algo(ode_algo) || is_rosenbrock(ode_algo))
Expand Down Expand Up @@ -409,28 +392,13 @@ Returns the ode algorithm
=#
function ode_configuration(::Type{FT}, parsed_args) where {FT}
ode_name = parsed_args["ode_algo"]
alg_or_tableau = if startswith(ode_name, "ODE.")
@warn "apply_limiter flag is ignored for OrdinaryDiffEq algorithms"
getproperty(ODE, Symbol(split(ode_name, ".")[2]))
else
getproperty(CTS, Symbol(ode_name))
end
alg_or_tableau = getproperty(CTS, Symbol(ode_name))
@info "Using ODE config: `$alg_or_tableau`"

if is_explicit_CTS_algo_type(alg_or_tableau)
return CTS.ExplicitAlgorithm(alg_or_tableau())
elseif !is_implicit_type(alg_or_tableau)
return alg_or_tableau()
elseif is_ordinary_diffeq_newton(alg_or_tableau)
if parsed_args["max_newton_iters_ode"] == 1
error("OridinaryDiffEq requires at least 2 Newton iterations")
end
# κ like a relative tolerance; its default value in ODE is 0.01
nlsolve = ODE.NLNewton(;
κ = parsed_args["max_newton_iters_ode"] == 2 ? Inf : 0.01,
max_iter = parsed_args["max_newton_iters_ode"],
)
return alg_or_tableau(; linsolve = linsolve!, nlsolve)
elseif is_imex_CTS_algo_type(alg_or_tableau)
newtons_method = CTS.NewtonsMethod(;
max_iters = parsed_args["max_newton_iters_ode"],
Expand Down Expand Up @@ -477,9 +445,6 @@ function get_callbacks(parsed_args, simulation, atmos, params)
(; dt) = simulation

callbacks = ()
if startswith(parsed_args["ode_algo"], "ODE.")
callbacks = (callbacks..., call_every_n_steps(dss_callback!))
end
dt_save_to_disk = time_to_seconds(parsed_args["dt_save_to_disk"])
if !(dt_save_to_disk == Inf)
callbacks = (
Expand Down

0 comments on commit 7947a38

Please sign in to comment.