Skip to content

Commit

Permalink
Use SciMLBase over OrdinaryDiffEq where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Sep 7, 2023
1 parent b0d14a2 commit 8128865
Show file tree
Hide file tree
Showing 21 changed files with 49 additions and 50 deletions.
3 changes: 1 addition & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ CloudMicrophysics = "6a9e3e04-43cd-43ba-94b9-e8782df3c71b"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Dierckx = "39dd38d3-220a-591b-8e3c-4c3a8c710a94"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand All @@ -38,6 +37,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
RRTMGP = "a01a1ee8-cea4-48fc-987c-fc7878d79da1"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RootSolvers = "7181ea78-2dcb-4de3-ab41-2b8ab5a31e74"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -59,7 +59,6 @@ ClimaTimeSteppers = "0.7"
CloudMicrophysics = "0.13"
Colors = "0.12"
Dierckx = "0.5"
DiffEqBase = "6"
DiffEqCallbacks = "2"
Distributions = "0.25"
DocStringExtensions = "0.8, 0.9"
Expand Down
2 changes: 1 addition & 1 deletion docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.9.3"
manifest_format = "2.0"
project_hash = "484a0d15ba3f1cc8d5c8863c4fe1999899c42df7"
project_hash = "813a39123d5089f49d5cde44ee40847faefe369c"

[[deps.ADTypes]]
git-tree-sha1 = "a4c8e0f8c09d4aa708289c1a5fc23e2d1970017a"
Expand Down
1 change: 0 additions & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
CLIMAParameters = "6eacf6c3-8458-43b9-ae03-caf5306d3d53"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ClimaAtmos = "b2c96348-7fb7-4fe0-8da9-78d88439e717"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Expand Down
2 changes: 1 addition & 1 deletion examples/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.9.3"
manifest_format = "2.0"
project_hash = "bbd263a7239c44698f2ead1e0abc80d97dea3133"
project_hash = "29934b734692261a625d91c36cff961d40b7b683"

[[deps.ADTypes]]
git-tree-sha1 = "a4c8e0f8c09d4aa708289c1a5fc23e2d1970017a"
Expand Down
1 change: 0 additions & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ CloudMicrophysics = "6a9e3e04-43cd-43ba-94b9-e8782df3c71b"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Dierckx = "39dd38d3-220a-591b-8e3c-4c3a8c710a94"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand Down
2 changes: 1 addition & 1 deletion examples/hybrid/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using Statistics: mean
import ClimaAtmos.Parameters as CAP
import Thermodynamics as TD
import ClimaComms
using OrdinaryDiffEq
using SciMLBase
using PrettyTables
using DiffEqCallbacks
using JLD2
Expand Down
2 changes: 1 addition & 1 deletion perf/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.9.3"
manifest_format = "2.0"
project_hash = "9556cf181fd1b65b5db679b93b4882215a7f787a"
project_hash = "967da1a008eca13e2322097c9c4f6c2a0c6d67ad"

[[deps.ADTypes]]
git-tree-sha1 = "a4c8e0f8c09d4aa708289c1a5fc23e2d1970017a"
Expand Down
1 change: 0 additions & 1 deletion perf/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ CloudMicrophysics = "6a9e3e04-43cd-43ba-94b9-e8782df3c71b"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Dierckx = "39dd38d3-220a-591b-8e3c-4c3a8c710a94"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
6 changes: 3 additions & 3 deletions perf/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ integrator = CA.get_integrator(config)

(; parsed_args) = config

import OrdinaryDiffEq as ODE
import SciMLBase
import ClimaTimeSteppers as CTS
ODE.step!(integrator) # compile first
SciMLBase.step!(integrator) # compile first

(; sol, u, p, dt, t) = integrator

Expand Down Expand Up @@ -41,7 +41,7 @@ trials["implicit_tendency!"] = get_trial(implicit_fun(integrator), implicit_args
trials["remaining_tendency!"] = get_trial(remaining_fun(integrator), remaining_args(integrator), "remaining_tendency!");
trials["additional_tendency!"] = get_trial(CA.additional_tendency!, (X, u, p, t), "additional_tendency!");
trials["hyperdiffusion_tendency!"] = get_trial(CA.hyperdiffusion_tendency!, (X, u, p, t), "hyperdiffusion_tendency!");
trials["step!"] = get_trial(ODE.step!, (integrator, ), "step!");
trials["step!"] = get_trial(SciMLBase.step!, (integrator, ), "step!");
#! format: on

table_summary = OrderedCollections.OrderedDict()
Expand Down
12 changes: 6 additions & 6 deletions perf/flame.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ integrator = CA.get_integrator(config)
# The callbacks flame graph is very expensive, so only do 2 steps.
@info "running step"

import OrdinaryDiffEq
OrdinaryDiffEq.step!(integrator) # compile first
import SciMLBase
SciMLBase.step!(integrator) # compile first
CA.call_all_callbacks!(integrator) # compile callbacks
import Profile, ProfileCanvas
(; output_dir, job_id) = integrator.p.simulation
Expand All @@ -18,7 +18,7 @@ mkpath(output_dir)

@info "collect profile"
Profile.clear()
prof = Profile.@profile OrdinaryDiffEq.step!(integrator)
prof = Profile.@profile SciMLBase.step!(integrator)
results = Profile.fetch()
Profile.clear()

Expand All @@ -32,7 +32,7 @@ ProfileCanvas.html_file(joinpath(output_dir, "flame.html"), results)
# use new allocation profiler
@info "collecting allocations"
Profile.Allocs.clear()
Profile.Allocs.@profile sample_rate = 0.01 OrdinaryDiffEq.step!(integrator)
Profile.Allocs.@profile sample_rate = 0.01 SciMLBase.step!(integrator)
results = Profile.Allocs.fetch()
Profile.Allocs.clear()
profile = ProfileCanvas.view_allocs(results)
Expand All @@ -49,8 +49,8 @@ buffer = occursin("threaded", job_id) ? 1.4 : 1


## old allocation profiler (TODO: remove this)
allocs = @allocated OrdinaryDiffEq.step!(integrator)
@timev OrdinaryDiffEq.step!(integrator)
allocs = @allocated SciMLBase.step!(integrator)
@timev SciMLBase.step!(integrator)
@info "`allocs ($job_id)`: $(allocs)"

allocs_limit = Dict()
Expand Down
6 changes: 3 additions & 3 deletions perf/jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ integrator = CA.get_integrator(config)

import JET

import OrdinaryDiffEq
OrdinaryDiffEq.step!(integrator) # Make sure no errors
JET.@test_opt OrdinaryDiffEq.step!(integrator)
import SciMLBase
SciMLBase.step!(integrator) # Make sure no errors
JET.@test_opt SciMLBase.step!(integrator)
4 changes: 2 additions & 2 deletions perf/jet_report_nfailures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ macro n_failures(ex)
)
end

import OrdinaryDiffEq
import SciMLBase
import ClimaAtmos as CA
n = Dict()
Y = integrator.u;
Expand All @@ -20,7 +20,7 @@ t = integrator.t;
Yₜ = similar(Y);
ref_Y = similar(Y);
#! format: off
n["step!"] = @n_failures OrdinaryDiffEq.step!(integrator);
n["step!"] = @n_failures SciMLBase.step!(integrator);
n["limited_tendency!"] = @n_failures CA.limited_tendency!(Yₜ, Y, p, t);
n["horizontal_advection_tendency!"] = @n_failures CA.horizontal_advection_tendency!(Yₜ, Y, p, t);
n["horizontal_tracer_advection_tendency!"] = @n_failures CA.horizontal_tracer_advection_tendency!(Yₜ, Y, p, t);
Expand Down
6 changes: 3 additions & 3 deletions perf/jet_test_nfailures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ integrator = CA.get_integrator(config)

import JET

import OrdinaryDiffEq
OrdinaryDiffEq.step!(integrator) # Make sure no errors
import SciMLBase
SciMLBase.step!(integrator) # Make sure no errors

# Suggested in: https://github.com/aviatesk/JET.jl/issues/455
macro n_failures(ex)
Expand All @@ -20,7 +20,7 @@ end

using Test
@testset "Test N-jet failures" begin
n = @n_failures OrdinaryDiffEq.step!(integrator)
n = @n_failures SciMLBase.step!(integrator)
# This test is intended to provide some friction when we
# add code to our tendency function that results in degraded
# inference. By increasing this counter, we acknowledge that
Expand Down
16 changes: 8 additions & 8 deletions src/callbacks/callback_helpers.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import DiffEqCallbacks
import SciMLBase
#####
##### Callback helpers
#####
Expand All @@ -7,7 +7,7 @@ function call_every_n_steps(f!, n = 1; skip_first = false, call_at_end = false)
previous_step = Ref(0)
@assert n Inf "Adding callback that never gets called!"
cb! = AtmosCallback(f!, EveryNSteps(n))
return ODE.DiscreteCallback(
return SciMLBase.DiscreteCallback(
(u, t, integrator) ->
(previous_step[] += 1) % n == 0 ||
(call_at_end && t == integrator.sol.prob.tspan[2]),
Expand All @@ -31,7 +31,7 @@ function call_every_dt(f!, dt; skip_first = false, call_at_end = false)
next_t[] = min(next_t[], t_end)
end
end
return ODE.DiscreteCallback(
return SciMLBase.DiscreteCallback(
(u, t, integrator) -> t >= next_t[],
affect!;
initialize = (cb, u, t, integrator) -> begin
Expand All @@ -56,36 +56,36 @@ function callback_from_affect(affect!)
end
error("Callback not found in $(affect!)")
end
function atmos_callbacks(cbs::ODE.CallbackSet)
function atmos_callbacks(cbs::SciMLBase.CallbackSet)
all_cbs = [cbs.continuous_callbacks..., cbs.discrete_callbacks...]
callback_objs = map(cb -> callback_from_affect(cb.affect!), all_cbs)
filter!(x -> !(x isa DiffEqCallbacks.SavedValues), callback_objs)
return callback_objs
end

n_measured_calls(integrator) = n_measured_calls(integrator.callback)
n_measured_calls(cbs::ODE.CallbackSet) =
n_measured_calls(cbs::SciMLBase.CallbackSet) =
map(x -> x.n_measured_calls, atmos_callbacks(cbs))

n_expected_calls(integrator) = n_expected_calls(
integrator.callback,
integrator.dt,
integrator.sol.prob.tspan,
)
n_expected_calls(cbs::ODE.CallbackSet, dt, tspan) =
n_expected_calls(cbs::SciMLBase.CallbackSet, dt, tspan) =
map(x -> n_expected_calls(x, dt, tspan), atmos_callbacks(cbs))

n_steps_per_cycle(integrator) =
n_steps_per_cycle(integrator.callback, integrator.dt)
function n_steps_per_cycle(cbs::ODE.CallbackSet, dt)
function n_steps_per_cycle(cbs::SciMLBase.CallbackSet, dt)
nspc = n_steps_per_cycle_per_cb(cbs, dt)
return isempty(nspc) ? 1 : lcm(nspc)
end

n_steps_per_cycle_per_cb(integrator) =
n_steps_per_cycle_per_cb(integrator.callback, integrator.dt)

function n_steps_per_cycle_per_cb(cbs::ODE.CallbackSet, dt)
function n_steps_per_cycle_per_cb(cbs::SciMLBase.CallbackSet, dt)
return map(atmos_callbacks(cbs)) do cb
cbf = callback_frequency(cb)
if cbf isa EveryΔt
Expand Down
4 changes: 2 additions & 2 deletions src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import ClimaCore.Fields
import ClimaComms
import ClimaCore as CC
import ClimaCore.Spaces
import OrdinaryDiffEq as ODE
import SciMLBase
import ClimaAtmos.Parameters as CAP
import DiffEqCallbacks as DEQ
import ClimaCore: InputOutput
Expand Down Expand Up @@ -70,7 +70,7 @@ function turb_conv_affect_filter!(integrator)
# paying for an additional `∑tendencies!` call, which is required
# to support supplying a continuous representation of the
# solution.
ODE.u_modified!(integrator, false)
SciMLBase.u_modified!(integrator, false)
end

function rrtmgp_model_callback!(integrator)
Expand Down
1 change: 1 addition & 0 deletions src/initial_conditions/InitialConditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import ..TurbulenceConvection as TC
import Thermodynamics as TD
import AtmosphericProfilesLibrary as APL
import OrdinaryDiffEq as ODE
import SciMLBase
import Dierckx

include("local_state.jl")
Expand Down
9 changes: 7 additions & 2 deletions src/initial_conditions/initial_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,13 @@ function hydrostatic_pressure_profile(;
dp_dz(p, _, z) =
-grav * TD.air_density(thermo_params, ts(p, z, T, θ, q_tot))

prob = ODE.ODEProblem(dp_dz, p_0, (FT(0), z_max))
return ODE.solve(prob, ODE.Tsit5(), reltol = 10eps(FT), abstol = 10eps(FT))
prob = SciMLBase.ODEProblem(dp_dz, p_0, (FT(0), z_max))
return SciMLBase.solve(
prob,
ODE.Tsit5(),
reltol = 10eps(FT),
abstol = 10eps(FT),
)
end

"""
Expand Down
10 changes: 5 additions & 5 deletions src/solver/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ walltime_in_days(es::EfficiencyStats) = es.walltime * (1 / (24 * 3600)) #=second
function timed_solve!(integrator)
walltime = @elapsed begin
s = @timed_str begin
sol = ODE.solve!(integrator)
sol = SciMLBase.solve!(integrator)
end
end
@info "solve!: $s"
Expand Down Expand Up @@ -52,7 +52,7 @@ function solve_atmos!(integrator)
@info "Running" job_id = p.simulation.job_id output_dir =
p.simulation.output_dir tspan
comms_ctx = ClimaComms.context(axes(integrator.u.c))
ODE.step!(integrator)
SciMLBase.step!(integrator)
precompile_callbacks(integrator)
GC.gc()
try
Expand Down Expand Up @@ -111,7 +111,7 @@ for the flags outlined in a table.
"""
function benchmark_step!(integrator, Y₀, n_steps = 10)
for i in 1:n_steps
ODE.step!(integrator)
SciMLBase.step!(integrator)
integrator.u .= Y₀ # temporary hack to simplify performance benchmark.
end
return nothing
Expand All @@ -131,7 +131,7 @@ into account.
function cycle!(integrator; n_cycles = 1)
n_steps = n_steps_per_cycle(integrator) * n_cycles
for i in 1:n_steps
ODE.step!(integrator)
SciMLBase.step!(integrator)
end
return nothing
end
Expand All @@ -150,7 +150,7 @@ Precompiles `step!` and all callbacks
in the `integrator`.
"""
function precompile_atmos(integrator)
B = Base.precompile(ODE.step!, (typeof(integrator),))
B = Base.precompile(SciMLBase.step!, (typeof(integrator),))
@assert B
precompile_callbacks(integrator)
return nothing
Expand Down
9 changes: 4 additions & 5 deletions src/solver/type_getters.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using Dates: DateTime, @dateformat_str
using NCDatasets
using Dierckx
using DiffEqBase
using ImageFiltering
using Interpolations
import ClimaCore: InputOutput, Meshes, Spaces
Expand Down Expand Up @@ -359,19 +358,19 @@ is_ordinary_diffeq_newton(alg_or_tableau) =
}

is_imex_CTS_algo(::CTS.IMEXAlgorithm) = true
is_imex_CTS_algo(::DiffEqBase.AbstractODEAlgorithm) = false
is_imex_CTS_algo(::SciMLBase.AbstractODEAlgorithm) = false

is_implicit(::ODE.OrdinaryDiffEqImplicitAlgorithm) = true
is_implicit(::ODE.OrdinaryDiffEqAdaptiveImplicitAlgorithm) = true
is_implicit(ode_algo) = is_imex_CTS_algo(ode_algo)

is_rosenbrock(::ODE.Rosenbrock23) = true
is_rosenbrock(::ODE.Rosenbrock32) = true
is_rosenbrock(::DiffEqBase.AbstractODEAlgorithm) = false
is_rosenbrock(::SciMLBase.AbstractODEAlgorithm) = false
use_transform(ode_algo) =
!(is_imex_CTS_algo(ode_algo) || is_rosenbrock(ode_algo))

additional_integrator_kwargs(::DiffEqBase.AbstractODEAlgorithm) = (;
additional_integrator_kwargs(::SciMLBase.AbstractODEAlgorithm) = (;
adaptive = false,
progress = isinteractive(),
progress_steps = isinteractive() ? 1 : 1000,
Expand All @@ -382,7 +381,7 @@ additional_integrator_kwargs(::CTS.DistributedODEAlgorithm) = (;
# TODO: enable progress bars in ClimaTimeSteppers
)

is_cts_algo(::DiffEqBase.AbstractODEAlgorithm) = false
is_cts_algo(::SciMLBase.AbstractODEAlgorithm) = false
is_cts_algo(::CTS.DistributedODEAlgorithm) = true

jacobi_flags(::TotalEnergy) = (; ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode = :no_∂ᶜp∂ᶜK)
Expand Down
1 change: 0 additions & 1 deletion src/utils/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import ClimaComms
import ClimaCore: Spaces, Topologies, Fields, Geometry
import LinearAlgebra: norm_sqr
import DiffEqBase
import JLD2

is_energy_var(symbol) = symbol in (:ρθ, :ρe_tot, :ρaθ, :ρae_tot)
Expand Down
Loading

0 comments on commit 8128865

Please sign in to comment.