Skip to content

Commit

Permalink
Use SciMLBase over OrdinaryDiffEq where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Sep 7, 2023
1 parent b0d14a2 commit 21c2578
Show file tree
Hide file tree
Showing 11 changed files with 44 additions and 39 deletions.
2 changes: 1 addition & 1 deletion examples/hybrid/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using Statistics: mean
import ClimaAtmos.Parameters as CAP
import Thermodynamics as TD
import ClimaComms
using OrdinaryDiffEq
using SciMLBase
using PrettyTables
using DiffEqCallbacks
using JLD2
Expand Down
6 changes: 3 additions & 3 deletions perf/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ integrator = CA.get_integrator(config)

(; parsed_args) = config

import OrdinaryDiffEq as ODE
import SciMLBase
import ClimaTimeSteppers as CTS
ODE.step!(integrator) # compile first
SciMLBase.step!(integrator) # compile first

(; sol, u, p, dt, t) = integrator

Expand Down Expand Up @@ -41,7 +41,7 @@ trials["implicit_tendency!"] = get_trial(implicit_fun(integrator), implicit_args
trials["remaining_tendency!"] = get_trial(remaining_fun(integrator), remaining_args(integrator), "remaining_tendency!");
trials["additional_tendency!"] = get_trial(CA.additional_tendency!, (X, u, p, t), "additional_tendency!");
trials["hyperdiffusion_tendency!"] = get_trial(CA.hyperdiffusion_tendency!, (X, u, p, t), "hyperdiffusion_tendency!");
trials["step!"] = get_trial(ODE.step!, (integrator, ), "step!");
trials["step!"] = get_trial(SciMLBase.step!, (integrator, ), "step!");
#! format: on

table_summary = OrderedCollections.OrderedDict()
Expand Down
12 changes: 6 additions & 6 deletions perf/flame.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ integrator = CA.get_integrator(config)
# The callbacks flame graph is very expensive, so only do 2 steps.
@info "running step"

import OrdinaryDiffEq
OrdinaryDiffEq.step!(integrator) # compile first
import SciMLBase
SciMLBase.step!(integrator) # compile first
CA.call_all_callbacks!(integrator) # compile callbacks
import Profile, ProfileCanvas
(; output_dir, job_id) = integrator.p.simulation
Expand All @@ -18,7 +18,7 @@ mkpath(output_dir)

@info "collect profile"
Profile.clear()
prof = Profile.@profile OrdinaryDiffEq.step!(integrator)
prof = Profile.@profile SciMLBase.step!(integrator)
results = Profile.fetch()
Profile.clear()

Expand All @@ -32,7 +32,7 @@ ProfileCanvas.html_file(joinpath(output_dir, "flame.html"), results)
# use new allocation profiler
@info "collecting allocations"
Profile.Allocs.clear()
Profile.Allocs.@profile sample_rate = 0.01 OrdinaryDiffEq.step!(integrator)
Profile.Allocs.@profile sample_rate = 0.01 SciMLBase.step!(integrator)
results = Profile.Allocs.fetch()
Profile.Allocs.clear()
profile = ProfileCanvas.view_allocs(results)
Expand All @@ -49,8 +49,8 @@ buffer = occursin("threaded", job_id) ? 1.4 : 1


## old allocation profiler (TODO: remove this)
allocs = @allocated OrdinaryDiffEq.step!(integrator)
@timev OrdinaryDiffEq.step!(integrator)
allocs = @allocated SciMLBase.step!(integrator)
@timev SciMLBase.step!(integrator)
@info "`allocs ($job_id)`: $(allocs)"

allocs_limit = Dict()
Expand Down
6 changes: 3 additions & 3 deletions perf/jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ integrator = CA.get_integrator(config)

import JET

import OrdinaryDiffEq
OrdinaryDiffEq.step!(integrator) # Make sure no errors
JET.@test_opt OrdinaryDiffEq.step!(integrator)
import SciMLBase
SciMLBase.step!(integrator) # Make sure no errors
JET.@test_opt SciMLBase.step!(integrator)
4 changes: 2 additions & 2 deletions perf/jet_report_nfailures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ macro n_failures(ex)
)
end

import OrdinaryDiffEq
import SciMLBase
import ClimaAtmos as CA
n = Dict()
Y = integrator.u;
Expand All @@ -20,7 +20,7 @@ t = integrator.t;
Yₜ = similar(Y);
ref_Y = similar(Y);
#! format: off
n["step!"] = @n_failures OrdinaryDiffEq.step!(integrator);
n["step!"] = @n_failures SciMLBase.step!(integrator);
n["limited_tendency!"] = @n_failures CA.limited_tendency!(Yₜ, Y, p, t);
n["horizontal_advection_tendency!"] = @n_failures CA.horizontal_advection_tendency!(Yₜ, Y, p, t);
n["horizontal_tracer_advection_tendency!"] = @n_failures CA.horizontal_tracer_advection_tendency!(Yₜ, Y, p, t);
Expand Down
6 changes: 3 additions & 3 deletions perf/jet_test_nfailures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ integrator = CA.get_integrator(config)

import JET

import OrdinaryDiffEq
OrdinaryDiffEq.step!(integrator) # Make sure no errors
import SciMLBase
SciMLBase.step!(integrator) # Make sure no errors

# Suggested in: https://github.com/aviatesk/JET.jl/issues/455
macro n_failures(ex)
Expand All @@ -20,7 +20,7 @@ end

using Test
@testset "Test N-jet failures" begin
n = @n_failures OrdinaryDiffEq.step!(integrator)
n = @n_failures SciMLBase.step!(integrator)
# This test is intended to provide some friction when we
# add code to our tendency function that results in degraded
# inference. By increasing this counter, we acknowledge that
Expand Down
16 changes: 8 additions & 8 deletions src/callbacks/callback_helpers.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import DiffEqCallbacks
import SciMLBase
#####
##### Callback helpers
#####
Expand All @@ -7,7 +7,7 @@ function call_every_n_steps(f!, n = 1; skip_first = false, call_at_end = false)
previous_step = Ref(0)
@assert n Inf "Adding callback that never gets called!"
cb! = AtmosCallback(f!, EveryNSteps(n))
return ODE.DiscreteCallback(
return SciMLBase.DiscreteCallback(
(u, t, integrator) ->
(previous_step[] += 1) % n == 0 ||
(call_at_end && t == integrator.sol.prob.tspan[2]),
Expand All @@ -31,7 +31,7 @@ function call_every_dt(f!, dt; skip_first = false, call_at_end = false)
next_t[] = min(next_t[], t_end)
end
end
return ODE.DiscreteCallback(
return SciMLBase.DiscreteCallback(
(u, t, integrator) -> t >= next_t[],
affect!;
initialize = (cb, u, t, integrator) -> begin
Expand All @@ -56,36 +56,36 @@ function callback_from_affect(affect!)
end
error("Callback not found in $(affect!)")
end
function atmos_callbacks(cbs::ODE.CallbackSet)
function atmos_callbacks(cbs::SciMLBase.CallbackSet)
all_cbs = [cbs.continuous_callbacks..., cbs.discrete_callbacks...]
callback_objs = map(cb -> callback_from_affect(cb.affect!), all_cbs)
filter!(x -> !(x isa DiffEqCallbacks.SavedValues), callback_objs)
return callback_objs
end

n_measured_calls(integrator) = n_measured_calls(integrator.callback)
n_measured_calls(cbs::ODE.CallbackSet) =
n_measured_calls(cbs::SciMLBase.CallbackSet) =
map(x -> x.n_measured_calls, atmos_callbacks(cbs))

n_expected_calls(integrator) = n_expected_calls(
integrator.callback,
integrator.dt,
integrator.sol.prob.tspan,
)
n_expected_calls(cbs::ODE.CallbackSet, dt, tspan) =
n_expected_calls(cbs::SciMLBase.CallbackSet, dt, tspan) =
map(x -> n_expected_calls(x, dt, tspan), atmos_callbacks(cbs))

n_steps_per_cycle(integrator) =
n_steps_per_cycle(integrator.callback, integrator.dt)
function n_steps_per_cycle(cbs::ODE.CallbackSet, dt)
function n_steps_per_cycle(cbs::SciMLBase.CallbackSet, dt)
nspc = n_steps_per_cycle_per_cb(cbs, dt)
return isempty(nspc) ? 1 : lcm(nspc)
end

n_steps_per_cycle_per_cb(integrator) =
n_steps_per_cycle_per_cb(integrator.callback, integrator.dt)

function n_steps_per_cycle_per_cb(cbs::ODE.CallbackSet, dt)
function n_steps_per_cycle_per_cb(cbs::SciMLBase.CallbackSet, dt)
return map(atmos_callbacks(cbs)) do cb
cbf = callback_frequency(cb)
if cbf isa EveryΔt
Expand Down
4 changes: 2 additions & 2 deletions src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import ClimaCore.Fields
import ClimaComms
import ClimaCore as CC
import ClimaCore.Spaces
import OrdinaryDiffEq as ODE
import SciMLBase
import ClimaAtmos.Parameters as CAP
import DiffEqCallbacks as DEQ
import ClimaCore: InputOutput
Expand Down Expand Up @@ -70,7 +70,7 @@ function turb_conv_affect_filter!(integrator)
# paying for an additional `∑tendencies!` call, which is required
# to support supplying a continuous representation of the
# solution.
ODE.u_modified!(integrator, false)
SciMLBase.u_modified!(integrator, false)
end

function rrtmgp_model_callback!(integrator)
Expand Down
9 changes: 7 additions & 2 deletions src/initial_conditions/initial_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,13 @@ function hydrostatic_pressure_profile(;
dp_dz(p, _, z) =
-grav * TD.air_density(thermo_params, ts(p, z, T, θ, q_tot))

prob = ODE.ODEProblem(dp_dz, p_0, (FT(0), z_max))
return ODE.solve(prob, ODE.Tsit5(), reltol = 10eps(FT), abstol = 10eps(FT))
prob = SciMLBase.ODEProblem(dp_dz, p_0, (FT(0), z_max))
return SciMLBase.solve(
prob,
ODE.Tsit5(),
reltol = 10eps(FT),
abstol = 10eps(FT),
)
end

"""
Expand Down
10 changes: 5 additions & 5 deletions src/solver/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ walltime_in_days(es::EfficiencyStats) = es.walltime * (1 / (24 * 3600)) #=second
function timed_solve!(integrator)
walltime = @elapsed begin
s = @timed_str begin
sol = ODE.solve!(integrator)
sol = SciMLBase.solve!(integrator)
end
end
@info "solve!: $s"
Expand Down Expand Up @@ -52,7 +52,7 @@ function solve_atmos!(integrator)
@info "Running" job_id = p.simulation.job_id output_dir =
p.simulation.output_dir tspan
comms_ctx = ClimaComms.context(axes(integrator.u.c))
ODE.step!(integrator)
SciMLBase.step!(integrator)
precompile_callbacks(integrator)
GC.gc()
try
Expand Down Expand Up @@ -111,7 +111,7 @@ for the flags outlined in a table.
"""
function benchmark_step!(integrator, Y₀, n_steps = 10)
for i in 1:n_steps
ODE.step!(integrator)
SciMLBase.step!(integrator)
integrator.u .= Y₀ # temporary hack to simplify performance benchmark.
end
return nothing
Expand All @@ -131,7 +131,7 @@ into account.
function cycle!(integrator; n_cycles = 1)
n_steps = n_steps_per_cycle(integrator) * n_cycles
for i in 1:n_steps
ODE.step!(integrator)
SciMLBase.step!(integrator)
end
return nothing
end
Expand All @@ -150,7 +150,7 @@ Precompiles `step!` and all callbacks
in the `integrator`.
"""
function precompile_atmos(integrator)
B = Base.precompile(ODE.step!, (typeof(integrator),))
B = Base.precompile(SciMLBase.step!, (typeof(integrator),))
@assert B
precompile_callbacks(integrator)
return nothing
Expand Down
8 changes: 4 additions & 4 deletions src/solver/type_getters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,19 +359,19 @@ is_ordinary_diffeq_newton(alg_or_tableau) =
}

is_imex_CTS_algo(::CTS.IMEXAlgorithm) = true
is_imex_CTS_algo(::DiffEqBase.AbstractODEAlgorithm) = false
is_imex_CTS_algo(::SciMLBase.AbstractODEAlgorithm) = false

is_implicit(::ODE.OrdinaryDiffEqImplicitAlgorithm) = true
is_implicit(::ODE.OrdinaryDiffEqAdaptiveImplicitAlgorithm) = true
is_implicit(ode_algo) = is_imex_CTS_algo(ode_algo)

is_rosenbrock(::ODE.Rosenbrock23) = true
is_rosenbrock(::ODE.Rosenbrock32) = true
is_rosenbrock(::DiffEqBase.AbstractODEAlgorithm) = false
is_rosenbrock(::SciMLBase.AbstractODEAlgorithm) = false
use_transform(ode_algo) =
!(is_imex_CTS_algo(ode_algo) || is_rosenbrock(ode_algo))

additional_integrator_kwargs(::DiffEqBase.AbstractODEAlgorithm) = (;
additional_integrator_kwargs(::SciMLBase.AbstractODEAlgorithm) = (;
adaptive = false,
progress = isinteractive(),
progress_steps = isinteractive() ? 1 : 1000,
Expand All @@ -382,7 +382,7 @@ additional_integrator_kwargs(::CTS.DistributedODEAlgorithm) = (;
# TODO: enable progress bars in ClimaTimeSteppers
)

is_cts_algo(::DiffEqBase.AbstractODEAlgorithm) = false
is_cts_algo(::SciMLBase.AbstractODEAlgorithm) = false
is_cts_algo(::CTS.DistributedODEAlgorithm) = true

jacobi_flags(::TotalEnergy) = (; ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode = :no_∂ᶜp∂ᶜK)
Expand Down

0 comments on commit 21c2578

Please sign in to comment.