Skip to content

Commit

Permalink
Merge #2183
Browse files Browse the repository at this point in the history
2183: Remove dependence on OrdinaryDiffEq r=charleskawczynski a=charleskawczynski

This PR removes the dependence on OrdinaryDiffEq 💥.

I did have to add a function for computing indefinite integrals:
```julia
function column_indefinite_integral!(
    f::Function,
    ᶠintegral::Fields.ColumnField,
    x₀,
    ᶜzfield::Fields.ColumnField,
    average = (a, b) -> (a + b) / 2,
)
```
(I guess we could maybe get rid of `ᶜzfield` and extract it from `ᶠintegral`?)

Which returns a new `ColumnInterpolatableField` object:

```julia
struct ColumnInterpolatableField{F, D}
    f::F
    data::D
    function ColumnInterpolatableField(f::Fields.ColumnField)
        zdata = vec(parent(Fields.Fields.coordinate_field(f).z))
        fdata = vec(parent(f))
        `@assert` length(zdata) == length(fdata)
        data = Dierckx.Spline1D(zdata, fdata; k = 1)
        return new{typeof(f), typeof(data)}(f, data)
    end
end
(f::ColumnInterpolatableField)(z) = Spaces.undertype(axes(f.f))(f.data(z))
```

Ultimately, this replaces `ODE.solve`, which itself is not an issue, but `ODE.Tsit5()` (and every other ODE algo implementation) is (reasonably) not in SciMLBase.

Co-authored-by: Charles Kawczynski <kawczynski.charles@gmail.com>
  • Loading branch information
bors[bot] and charleskawczynski authored Oct 25, 2023
2 parents f9c1170 + f971f11 commit a037520
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 59 deletions.
2 changes: 1 addition & 1 deletion config/default_configs/default_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ max_newton_iters_ode:
help: "Maximum number of Newton's method iterations (only for ODE algorithms that use Newton's method)"
value: 1
ode_algo:
help: "ODE algorithm [`ARS343` (default), `SSP333`, `IMKG343a`, `ODE.Euler`, `ODE.IMEXEuler`, `ODE.Rosenbrock23`, etc.]"
help: "ODE algorithm [`ARS343` (default), `SSP333`, `IMKG343a`, etc.]"
value: "ARS343"
krylov_rtol:
help: "Relative tolerance of the Krylov method (only for ClimaTimeSteppers.jl; only used if `use_krylov_method` is `true`)"
Expand Down
6 changes: 6 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,9 @@ ClimaAtmos.InitialConditions.Soares
```@docs
ClimaAtmos.ImplicitEquationJacobian
```

### Helper

```@docs
ClimaAtmos.InitialConditions.ColumnInterpolatableField
```
2 changes: 1 addition & 1 deletion src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ NVTX.@annotate function turb_conv_affect_filter!(integrator)
TC.affect_filter!(edmf, grid, state, tc_params, t)
end

# We're lying to OrdinaryDiffEq.jl, in order to avoid
# We're lying to SciMLBase.jl, in order to avoid
# paying for an additional `∑tendencies!` call, which is required
# to support supplying a continuous representation of the
# solution.
Expand Down
1 change: 0 additions & 1 deletion src/dycore_equations_deprecated/sgs_flux_tendencies.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import OrdinaryDiffEq as ODE
import Logging
import TerminalLoggers
import LinearAlgebra as LA
Expand Down
1 change: 0 additions & 1 deletion src/initial_conditions/InitialConditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import ..Parameters as CAP
import ..TurbulenceConvection as TC
import Thermodynamics as TD
import AtmosphericProfilesLibrary as APL
import OrdinaryDiffEq as ODE
import SciMLBase
import Dierckx

Expand Down
98 changes: 84 additions & 14 deletions src/initial_conditions/initial_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,73 @@ perturb_coeff(p::Geometry.LatLongZPoint{FT}) where {FT} = sind(p.long)
perturb_coeff(p::Geometry.XZPoint{FT}) where {FT} = sin(p.x)
perturb_coeff(p::Geometry.XYZPoint{FT}) where {FT} = sin(p.x)

"""
ColumnInterpolatableField(::Fields.ColumnField)
A column field object that can be interpolated
in the z-coordinate. For example:
```julia
cif = ColumnInterpolatableField(column_field)
z = 1.0
column_field_at_z = cif(z)
```
!!! warn
This function allocates and is not GPU-compatible
so please avoid using this inside `step!` only use
this for initialization.
"""
struct ColumnInterpolatableField{F, D}
f::F
data::D
function ColumnInterpolatableField(f::Fields.ColumnField)
zdata = vec(parent(Fields.Fields.coordinate_field(f).z))
fdata = vec(parent(f))
data = Dierckx.Spline1D(zdata, fdata; k = 1)
return new{typeof(f), typeof(data)}(f, data)
end
end
(f::ColumnInterpolatableField)(z) = Spaces.undertype(axes(f.f))(f.data(z))

import ClimaComms
import ClimaCore.Domains as Domains
import ClimaCore.Meshes as Meshes
import ClimaCore.Geometry as Geometry
import ClimaCore.Operators as Operators
import ClimaCore.Topologies as Topologies
import ClimaCore.Spaces as Spaces

"""
column_indefinite_integral(f, ϕ₀, zspan; nelems = 100)
The column integral, returned as
an interpolate-able field.
"""
function column_indefinite_integral(
f::Function,
ϕ₀::FT,
zspan::Tuple{FT, FT};
nelems = 100, # sets resolution for integration
) where {FT <: Real}
# --- Make a space for integration:
z_domain = Domains.IntervalDomain(
Geometry.ZPoint(first(zspan)),
Geometry.ZPoint(last(zspan));
boundary_tags = (:bottom, :top),
)
z_mesh = Meshes.IntervalMesh(z_domain; nelems)
context = ClimaComms.SingletonCommsContext()
z_topology = Topologies.IntervalTopology(context, z_mesh)
cspace = Spaces.CenterFiniteDifferenceSpace(z_topology)
fspace = Spaces.FaceFiniteDifferenceSpace(z_topology)
# ---
zc = Fields.coordinate_field(cspace)
ᶠintegral = Fields.Field(FT, fspace)
Operators.column_integral_indefinite!(f, ᶠintegral, ϕ₀)
return ColumnInterpolatableField(ᶠintegral)
end

##
## Simple Profiles
##
Expand Down Expand Up @@ -534,23 +601,26 @@ function hydrostatic_pressure_profile(;
ts(p, z, T::FunctionOrSpline, θ::FunctionOrSpline, _) =
error("Only one of T and θ can be specified")
ts(p, z, T::FunctionOrSpline, ::Nothing, ::Nothing) =
TD.PhaseDry_pT(thermo_params, p, T(z))
TD.PhaseDry_pT(thermo_params, p, oftype(p, T(z)))
ts(p, z, ::Nothing, θ::FunctionOrSpline, ::Nothing) =
TD.PhaseDry_pθ(thermo_params, p, θ(z))
TD.PhaseDry_pθ(thermo_params, p, oftype(p, θ(z)))
ts(p, z, T::FunctionOrSpline, ::Nothing, q_tot::FunctionOrSpline) =
TD.PhaseEquil_pTq(thermo_params, p, T(z), q_tot(z))
TD.PhaseEquil_pTq(
thermo_params,
p,
oftype(p, T(z)),
oftype(p, q_tot(z)),
)
ts(p, z, ::Nothing, θ::FunctionOrSpline, q_tot::FunctionOrSpline) =
TD.PhaseEquil_pθq(thermo_params, p, θ(z), q_tot(z))
dp_dz(p, _, z) =
-grav * TD.air_density(thermo_params, ts(p, z, T, θ, q_tot))

prob = SciMLBase.ODEProblem(dp_dz, p_0, (FT(0), z_max))
return SciMLBase.solve(
prob,
ODE.Tsit5(),
reltol = 10eps(FT),
abstol = 10eps(FT),
)
TD.PhaseEquil_pθq(
thermo_params,
p,
oftype(p, θ(z)),
oftype(p, q_tot(z)),
)
dp_dz(p, z) = -grav * TD.air_density(thermo_params, ts(p, z, T, θ, q_tot))

return column_indefinite_integral(dp_dz, p_0, (FT(0), FT(z_max)))
end

"""
Expand Down
1 change: 0 additions & 1 deletion src/parameterized_tendencies/radiation/radiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import ClimaComms
import ClimaCore: Device, DataLayouts, Geometry, Spaces, Fields, Operators
import OrdinaryDiffEq as ODE
import Insolation
import Thermodynamics as TD
import .Parameters as CAP
Expand Down
10 changes: 5 additions & 5 deletions src/prognostic_equations/implicit/implicit_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ function ImplicitEquationJacobian(
)
end

# We only use A, but OrdinaryDiffEq.jl and ClimaTimeSteppers.jl require us to
# pass jac_prototype and then call similar(jac_prototype) to obtain A. This is a
# workaround to avoid unnecessary allocations.
# We only use A, but ClimaTimeSteppers.jl require us to
# pass jac_prototype and then call similar(jac_prototype) to
# obtain A. This is a workaround to avoid unnecessary allocations.
Base.similar(A::ImplicitEquationJacobian) = A

# This method specifies how to solve the equation E'(Y) * ΔY = E(Y) for ΔY.
Expand Down Expand Up @@ -197,12 +197,12 @@ function ldiv!(
x .= A.temp_x
end

# This function is used by OrdinaryDiffEq.jl instead of ldiv!.
# This function is used by DiffEqBase.jl instead of ldiv!.
linsolve!(::Type{Val{:init}}, f, u0; kwargs...) = _linsolve!
_linsolve!(x, A, b, update_matrix = false; kwargs...) = ldiv!(x, A, b)

# This method specifies how to compute E'(Y), which is referred to as "Wfact" in
# OrdinaryDiffEq.jl.
# DiffEqBase.jl.
function Wfact!(A, Y, p, dtγ, t)
NVTX.@range "Wfact!" color = colorant"green" begin
# Remove unnecessary values from p to avoid allocations in bycolumn.
Expand Down
1 change: 0 additions & 1 deletion src/solver/solve.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import ClimaTimeSteppers as CTS
import OrdinaryDiffEq as ODE

struct EfficiencyStats{TS <: Tuple, WT}
tspan::TS
Expand Down
36 changes: 2 additions & 34 deletions src/solver/type_getters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import ClimaAtmos.RRTMGPInterface as RRTMGPI
import ClimaAtmos as CA
import LinearAlgebra
import ClimaCore.Fields
import OrdinaryDiffEq as ODE
import ClimaTimeSteppers as CTS
import DiffEqCallbacks as DECB

Expand Down Expand Up @@ -355,29 +354,13 @@ is_explicit_CTS_algo_type(alg_or_tableau) =
is_imex_CTS_algo_type(alg_or_tableau) =
alg_or_tableau <: CTS.IMEXARKAlgorithmName

is_implicit_type(::typeof(ODE.IMEXEuler)) = true
is_implicit_type(alg_or_tableau) =
alg_or_tableau <: Union{
ODE.OrdinaryDiffEqImplicitAlgorithm,
ODE.OrdinaryDiffEqAdaptiveImplicitAlgorithm,
} || is_imex_CTS_algo_type(alg_or_tableau)

is_ordinary_diffeq_newton(::typeof(ODE.IMEXEuler)) = true
is_ordinary_diffeq_newton(alg_or_tableau) =
alg_or_tableau <: Union{
ODE.OrdinaryDiffEqNewtonAlgorithm,
ODE.OrdinaryDiffEqNewtonAdaptiveAlgorithm,
}
is_implicit_type(alg_or_tableau) = is_imex_CTS_algo_type(alg_or_tableau)

is_imex_CTS_algo(::CTS.IMEXAlgorithm) = true
is_imex_CTS_algo(::SciMLBase.AbstractODEAlgorithm) = false

is_implicit(::ODE.OrdinaryDiffEqImplicitAlgorithm) = true
is_implicit(::ODE.OrdinaryDiffEqAdaptiveImplicitAlgorithm) = true
is_implicit(ode_algo) = is_imex_CTS_algo(ode_algo)

is_rosenbrock(::ODE.Rosenbrock23) = true
is_rosenbrock(::ODE.Rosenbrock32) = true
is_rosenbrock(::SciMLBase.AbstractODEAlgorithm) = false
use_transform(ode_algo) =
!(is_imex_CTS_algo(ode_algo) || is_rosenbrock(ode_algo))
Expand Down Expand Up @@ -417,28 +400,13 @@ Returns the ode algorithm
=#
function ode_configuration(::Type{FT}, parsed_args) where {FT}
ode_name = parsed_args["ode_algo"]
alg_or_tableau = if startswith(ode_name, "ODE.")
@warn "apply_limiter flag is ignored for OrdinaryDiffEq algorithms"
getproperty(ODE, Symbol(split(ode_name, ".")[2]))
else
getproperty(CTS, Symbol(ode_name))
end
alg_or_tableau = getproperty(CTS, Symbol(ode_name))
@info "Using ODE config: `$alg_or_tableau`"

if is_explicit_CTS_algo_type(alg_or_tableau)
return CTS.ExplicitAlgorithm(alg_or_tableau())
elseif !is_implicit_type(alg_or_tableau)
return alg_or_tableau()
elseif is_ordinary_diffeq_newton(alg_or_tableau)
if parsed_args["max_newton_iters_ode"] == 1
error("OridinaryDiffEq requires at least 2 Newton iterations")
end
# κ like a relative tolerance; its default value in ODE is 0.01
nlsolve = ODE.NLNewton(;
κ = parsed_args["max_newton_iters_ode"] == 2 ? Inf : 0.01,
max_iter = parsed_args["max_newton_iters_ode"],
)
return alg_or_tableau(; linsolve = linsolve!, nlsolve)
elseif is_imex_CTS_algo_type(alg_or_tableau)
newtons_method = CTS.NewtonsMethod(;
max_iters = parsed_args["max_newton_iters_ode"],
Expand Down

0 comments on commit a037520

Please sign in to comment.