Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use SciMLBase over OrdinaryDiffEq where possible #2074

Merged
merged 3 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ CloudMicrophysics = "6a9e3e04-43cd-43ba-94b9-e8782df3c71b"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Dierckx = "39dd38d3-220a-591b-8e3c-4c3a8c710a94"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand All @@ -38,6 +37,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
RRTMGP = "a01a1ee8-cea4-48fc-987c-fc7878d79da1"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RootSolvers = "7181ea78-2dcb-4de3-ab41-2b8ab5a31e74"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -59,7 +59,6 @@ ClimaTimeSteppers = "0.7"
CloudMicrophysics = "0.13"
Colors = "0.12"
Dierckx = "0.5"
DiffEqBase = "6"
DiffEqCallbacks = "2"
Distributions = "0.25"
DocStringExtensions = "0.8, 0.9"
Expand All @@ -76,6 +75,7 @@ OrdinaryDiffEq = "5, 6"
Pkg = "1.8"
RRTMGP = "0.9"
RootSolvers = "0.2, 0.3, 0.4"
SciMLBase = "1"
StaticArrays = "1"
StatsBase = "0.33"
SurfaceFluxes = "0.7"
Expand Down
12 changes: 6 additions & 6 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.9.3"
manifest_format = "2.0"
project_hash = "484a0d15ba3f1cc8d5c8863c4fe1999899c42df7"
project_hash = "7fc686bc0a71a5f83e48b1285d03c4910b039279"

[[deps.ADTypes]]
git-tree-sha1 = "a4c8e0f8c09d4aa708289c1a5fc23e2d1970017a"
Expand Down Expand Up @@ -247,7 +247,7 @@ uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.16.0"

[[deps.ClimaAtmos]]
deps = ["ArgParse", "ArtifactWrappers", "Artifacts", "AtmosphericProfilesLibrary", "CLIMAParameters", "CUDA", "ClimaComms", "ClimaCore", "ClimaTimeSteppers", "CloudMicrophysics", "Colors", "Dates", "Dierckx", "DiffEqBase", "DiffEqCallbacks", "Distributions", "DocStringExtensions", "FastGaussQuadrature", "ImageFiltering", "Insolation", "Interpolations", "IntervalSets", "JLD2", "LambertW", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "OrdinaryDiffEq", "Pkg", "Printf", "RRTMGP", "Random", "RootSolvers", "StaticArrays", "Statistics", "StatsBase", "SurfaceFluxes", "TerminalLoggers", "Test", "Thermodynamics", "YAML"]
deps = ["ArgParse", "ArtifactWrappers", "Artifacts", "AtmosphericProfilesLibrary", "CLIMAParameters", "CUDA", "ClimaComms", "ClimaCore", "ClimaTimeSteppers", "CloudMicrophysics", "Colors", "Dates", "Dierckx", "DiffEqCallbacks", "Distributions", "DocStringExtensions", "FastGaussQuadrature", "ImageFiltering", "Insolation", "Interpolations", "IntervalSets", "JLD2", "LambertW", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "OrdinaryDiffEq", "Pkg", "Printf", "RRTMGP", "Random", "RootSolvers", "SciMLBase", "StaticArrays", "Statistics", "StatsBase", "SurfaceFluxes", "TerminalLoggers", "Test", "Thermodynamics", "YAML"]
path = ".."
uuid = "b2c96348-7fb7-4fe0-8da9-78d88439e717"
version = "0.16.0"
Expand Down Expand Up @@ -697,9 +697,9 @@ version = "0.16.16"

[[deps.HDF5_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"]
git-tree-sha1 = "10c72358aaaa5cd6bc7cc39b95e6eadf92f5a336"
git-tree-sha1 = "38c8874692d48d5440d5752d6c74b0c6b0b60739"
uuid = "0234f1f7-429e-5d53-9886-15a909be8d59"
version = "1.14.2+0"
version = "1.14.2+1"

[[deps.HostCPUFeatures]]
deps = ["BitTwiddlingConvenienceFunctions", "IfElse", "Libdl", "Static"]
Expand Down Expand Up @@ -1574,9 +1574,9 @@ version = "1.9.0"

[[deps.StatsAPI]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7"
git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed"
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.6.0"
version = "1.7.0"

[[deps.StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
CLIMAParameters = "6eacf6c3-8458-43b9-ae03-caf5306d3d53"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ClimaAtmos = "b2c96348-7fb7-4fe0-8da9-78d88439e717"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
RootSolvers = "7181ea78-2dcb-4de3-ab41-2b8ab5a31e74"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"
Expand Down
8 changes: 4 additions & 4 deletions examples/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.9.3"
manifest_format = "2.0"
project_hash = "bbd263a7239c44698f2ead1e0abc80d97dea3133"
project_hash = "29934b734692261a625d91c36cff961d40b7b683"

[[deps.ADTypes]]
git-tree-sha1 = "a4c8e0f8c09d4aa708289c1a5fc23e2d1970017a"
Expand Down Expand Up @@ -286,7 +286,7 @@ uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.16.0"

[[deps.ClimaAtmos]]
deps = ["ArgParse", "ArtifactWrappers", "Artifacts", "AtmosphericProfilesLibrary", "CLIMAParameters", "CUDA", "ClimaComms", "ClimaCore", "ClimaTimeSteppers", "CloudMicrophysics", "Colors", "Dates", "Dierckx", "DiffEqBase", "DiffEqCallbacks", "Distributions", "DocStringExtensions", "FastGaussQuadrature", "ImageFiltering", "Insolation", "Interpolations", "IntervalSets", "JLD2", "LambertW", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "OrdinaryDiffEq", "Pkg", "Printf", "RRTMGP", "Random", "RootSolvers", "StaticArrays", "Statistics", "StatsBase", "SurfaceFluxes", "TerminalLoggers", "Test", "Thermodynamics", "YAML"]
deps = ["ArgParse", "ArtifactWrappers", "Artifacts", "AtmosphericProfilesLibrary", "CLIMAParameters", "CUDA", "ClimaComms", "ClimaCore", "ClimaTimeSteppers", "CloudMicrophysics", "Colors", "Dates", "Dierckx", "DiffEqCallbacks", "Distributions", "DocStringExtensions", "FastGaussQuadrature", "ImageFiltering", "Insolation", "Interpolations", "IntervalSets", "JLD2", "LambertW", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "OrdinaryDiffEq", "Pkg", "Printf", "RRTMGP", "Random", "RootSolvers", "SciMLBase", "StaticArrays", "Statistics", "StatsBase", "SurfaceFluxes", "TerminalLoggers", "Test", "Thermodynamics", "YAML"]
path = ".."
uuid = "b2c96348-7fb7-4fe0-8da9-78d88439e717"
version = "0.16.0"
Expand Down Expand Up @@ -2349,9 +2349,9 @@ version = "1.9.0"

[[deps.StatsAPI]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7"
git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed"
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.6.0"
version = "1.7.0"

[[deps.StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
Expand Down
1 change: 0 additions & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ CloudMicrophysics = "6a9e3e04-43cd-43ba-94b9-e8782df3c71b"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Dierckx = "39dd38d3-220a-591b-8e3c-4c3a8c710a94"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand Down
4 changes: 2 additions & 2 deletions examples/hybrid/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ using Statistics: mean
import ClimaAtmos.Parameters as CAP
import Thermodynamics as TD
import ClimaComms
using OrdinaryDiffEq
using SciMLBase
using PrettyTables
using DiffEqCallbacks
import DiffEqCallbacks as DECB
using JLD2
using NCDatasets
using ClimaTimeSteppers
Expand Down
12 changes: 6 additions & 6 deletions perf/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.9.3"
manifest_format = "2.0"
project_hash = "9556cf181fd1b65b5db679b93b4882215a7f787a"
project_hash = "967da1a008eca13e2322097c9c4f6c2a0c6d67ad"

[[deps.ADTypes]]
git-tree-sha1 = "a4c8e0f8c09d4aa708289c1a5fc23e2d1970017a"
Expand Down Expand Up @@ -297,7 +297,7 @@ uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.16.0"

[[deps.ClimaAtmos]]
deps = ["ArgParse", "ArtifactWrappers", "Artifacts", "AtmosphericProfilesLibrary", "CLIMAParameters", "CUDA", "ClimaComms", "ClimaCore", "ClimaTimeSteppers", "CloudMicrophysics", "Colors", "Dates", "Dierckx", "DiffEqBase", "DiffEqCallbacks", "Distributions", "DocStringExtensions", "FastGaussQuadrature", "ImageFiltering", "Insolation", "Interpolations", "IntervalSets", "JLD2", "LambertW", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "OrdinaryDiffEq", "Pkg", "Printf", "RRTMGP", "Random", "RootSolvers", "StaticArrays", "Statistics", "StatsBase", "SurfaceFluxes", "TerminalLoggers", "Test", "Thermodynamics", "YAML"]
deps = ["ArgParse", "ArtifactWrappers", "Artifacts", "AtmosphericProfilesLibrary", "CLIMAParameters", "CUDA", "ClimaComms", "ClimaCore", "ClimaTimeSteppers", "CloudMicrophysics", "Colors", "Dates", "Dierckx", "DiffEqCallbacks", "Distributions", "DocStringExtensions", "FastGaussQuadrature", "ImageFiltering", "Insolation", "Interpolations", "IntervalSets", "JLD2", "LambertW", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "OrdinaryDiffEq", "Pkg", "Printf", "RRTMGP", "Random", "RootSolvers", "SciMLBase", "StaticArrays", "Statistics", "StatsBase", "SurfaceFluxes", "TerminalLoggers", "Test", "Thermodynamics", "YAML"]
path = ".."
uuid = "b2c96348-7fb7-4fe0-8da9-78d88439e717"
version = "0.16.0"
Expand Down Expand Up @@ -2431,9 +2431,9 @@ version = "1.9.0"

[[deps.StatsAPI]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7"
git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed"
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.6.0"
version = "1.7.0"

[[deps.StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
Expand Down Expand Up @@ -2623,9 +2623,9 @@ version = "1.3.0"

[[deps.TypedSyntax]]
deps = ["CodeTracking", "JuliaSyntax"]
git-tree-sha1 = "34f0ab1aa1b869840cfc4e1e33074030e90ece7e"
git-tree-sha1 = "79ea8a4993ed5d341580c4044433e0259fceb4c6"
uuid = "d265eb64-f81a-44ad-a842-4247ee1503de"
version = "1.2.2"
version = "1.2.3"

[[deps.URIs]]
git-tree-sha1 = "b7a5e99f24892b6824a954199a45e9ffcc1c70f0"
Expand Down
1 change: 0 additions & 1 deletion perf/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ CloudMicrophysics = "6a9e3e04-43cd-43ba-94b9-e8782df3c71b"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Dierckx = "39dd38d3-220a-591b-8e3c-4c3a8c710a94"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
6 changes: 3 additions & 3 deletions perf/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ integrator = CA.get_integrator(config)

(; parsed_args) = config

import OrdinaryDiffEq as ODE
import SciMLBase
import ClimaTimeSteppers as CTS
ODE.step!(integrator) # compile first
SciMLBase.step!(integrator) # compile first

(; sol, u, p, dt, t) = integrator

Expand Down Expand Up @@ -41,7 +41,7 @@ trials["implicit_tendency!"] = get_trial(implicit_fun(integrator), implicit_args
trials["remaining_tendency!"] = get_trial(remaining_fun(integrator), remaining_args(integrator), "remaining_tendency!");
trials["additional_tendency!"] = get_trial(CA.additional_tendency!, (X, u, p, t), "additional_tendency!");
trials["hyperdiffusion_tendency!"] = get_trial(CA.hyperdiffusion_tendency!, (X, u, p, t), "hyperdiffusion_tendency!");
trials["step!"] = get_trial(ODE.step!, (integrator, ), "step!");
trials["step!"] = get_trial(SciMLBase.step!, (integrator, ), "step!");
#! format: on

table_summary = OrderedCollections.OrderedDict()
Expand Down
12 changes: 6 additions & 6 deletions perf/flame.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ integrator = CA.get_integrator(config)
# The callbacks flame graph is very expensive, so only do 2 steps.
@info "running step"

import OrdinaryDiffEq
OrdinaryDiffEq.step!(integrator) # compile first
import SciMLBase
SciMLBase.step!(integrator) # compile first
CA.call_all_callbacks!(integrator) # compile callbacks
import Profile, ProfileCanvas
(; output_dir, job_id) = integrator.p.simulation
Expand All @@ -18,7 +18,7 @@ mkpath(output_dir)

@info "collect profile"
Profile.clear()
prof = Profile.@profile OrdinaryDiffEq.step!(integrator)
prof = Profile.@profile SciMLBase.step!(integrator)
results = Profile.fetch()
Profile.clear()

Expand All @@ -32,7 +32,7 @@ ProfileCanvas.html_file(joinpath(output_dir, "flame.html"), results)
# use new allocation profiler
@info "collecting allocations"
Profile.Allocs.clear()
Profile.Allocs.@profile sample_rate = 0.01 OrdinaryDiffEq.step!(integrator)
Profile.Allocs.@profile sample_rate = 0.01 SciMLBase.step!(integrator)
results = Profile.Allocs.fetch()
Profile.Allocs.clear()
profile = ProfileCanvas.view_allocs(results)
Expand All @@ -49,8 +49,8 @@ buffer = occursin("threaded", job_id) ? 1.4 : 1


## old allocation profiler (TODO: remove this)
allocs = @allocated OrdinaryDiffEq.step!(integrator)
@timev OrdinaryDiffEq.step!(integrator)
allocs = @allocated SciMLBase.step!(integrator)
@timev SciMLBase.step!(integrator)
@info "`allocs ($job_id)`: $(allocs)"

allocs_limit = Dict()
Expand Down
6 changes: 3 additions & 3 deletions perf/jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ integrator = CA.get_integrator(config)

import JET

import OrdinaryDiffEq
OrdinaryDiffEq.step!(integrator) # Make sure no errors
JET.@test_opt OrdinaryDiffEq.step!(integrator)
import SciMLBase
SciMLBase.step!(integrator) # Make sure no errors
JET.@test_opt SciMLBase.step!(integrator)
4 changes: 2 additions & 2 deletions perf/jet_report_nfailures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ macro n_failures(ex)
)
end

import OrdinaryDiffEq
import SciMLBase
import ClimaAtmos as CA
n = Dict()
Y = integrator.u;
Expand All @@ -20,7 +20,7 @@ t = integrator.t;
Yₜ = similar(Y);
ref_Y = similar(Y);
#! format: off
n["step!"] = @n_failures OrdinaryDiffEq.step!(integrator);
n["step!"] = @n_failures SciMLBase.step!(integrator);
n["limited_tendency!"] = @n_failures CA.limited_tendency!(Yₜ, Y, p, t);
n["horizontal_advection_tendency!"] = @n_failures CA.horizontal_advection_tendency!(Yₜ, Y, p, t);
n["horizontal_tracer_advection_tendency!"] = @n_failures CA.horizontal_tracer_advection_tendency!(Yₜ, Y, p, t);
Expand Down
8 changes: 4 additions & 4 deletions perf/jet_test_nfailures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ integrator = CA.get_integrator(config)

import JET

import OrdinaryDiffEq
OrdinaryDiffEq.step!(integrator) # Make sure no errors
import SciMLBase
SciMLBase.step!(integrator) # Make sure no errors

# Suggested in: https://github.com/aviatesk/JET.jl/issues/455
macro n_failures(ex)
Expand All @@ -20,13 +20,13 @@ end

using Test
@testset "Test N-jet failures" begin
n = @n_failures OrdinaryDiffEq.step!(integrator)
n = @n_failures SciMLBase.step!(integrator)
# This test is intended to provide some friction when we
# add code to our tendency function that results in degraded
# inference. By increasing this counter, we acknowledge that
# we have introduced an inference failure. We hope to drive
# this number down to 0.
n_allowed_failures = 256
n_allowed_failures = 680
@test n ≤ n_allowed_failures
if n < n_allowed_failures
@info "Please update the n-failures to $n"
Expand Down
20 changes: 10 additions & 10 deletions src/callbacks/callback_helpers.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import DiffEqCallbacks
import SciMLBase
#####
##### Callback helpers
#####
Expand All @@ -7,7 +7,7 @@ function call_every_n_steps(f!, n = 1; skip_first = false, call_at_end = false)
previous_step = Ref(0)
@assert n ≠ Inf "Adding callback that never gets called!"
cb! = AtmosCallback(f!, EveryNSteps(n))
return ODE.DiscreteCallback(
return SciMLBase.DiscreteCallback(
(u, t, integrator) ->
(previous_step[] += 1) % n == 0 ||
(call_at_end && t == integrator.sol.prob.tspan[2]),
Expand All @@ -31,7 +31,7 @@ function call_every_dt(f!, dt; skip_first = false, call_at_end = false)
next_t[] = min(next_t[], t_end)
end
end
return ODE.DiscreteCallback(
return SciMLBase.DiscreteCallback(
(u, t, integrator) -> t >= next_t[],
affect!;
initialize = (cb, u, t, integrator) -> begin
Expand All @@ -50,42 +50,42 @@ function callback_from_affect(affect!)
x = getproperty(affect!, p)
if x isa AtmosCallback
return x
elseif x isa DiffEqCallbacks.SavedValues
elseif x isa DECB.SavedValues
return x
end
end
error("Callback not found in $(affect!)")
end
function atmos_callbacks(cbs::ODE.CallbackSet)
function atmos_callbacks(cbs::SciMLBase.CallbackSet)
all_cbs = [cbs.continuous_callbacks..., cbs.discrete_callbacks...]
callback_objs = map(cb -> callback_from_affect(cb.affect!), all_cbs)
filter!(x -> !(x isa DiffEqCallbacks.SavedValues), callback_objs)
filter!(x -> !(x isa DECB.SavedValues), callback_objs)
return callback_objs
end

n_measured_calls(integrator) = n_measured_calls(integrator.callback)
n_measured_calls(cbs::ODE.CallbackSet) =
n_measured_calls(cbs::SciMLBase.CallbackSet) =
map(x -> x.n_measured_calls, atmos_callbacks(cbs))

n_expected_calls(integrator) = n_expected_calls(
integrator.callback,
integrator.dt,
integrator.sol.prob.tspan,
)
n_expected_calls(cbs::ODE.CallbackSet, dt, tspan) =
n_expected_calls(cbs::SciMLBase.CallbackSet, dt, tspan) =
map(x -> n_expected_calls(x, dt, tspan), atmos_callbacks(cbs))

n_steps_per_cycle(integrator) =
n_steps_per_cycle(integrator.callback, integrator.dt)
function n_steps_per_cycle(cbs::ODE.CallbackSet, dt)
function n_steps_per_cycle(cbs::SciMLBase.CallbackSet, dt)
nspc = n_steps_per_cycle_per_cb(cbs, dt)
return isempty(nspc) ? 1 : lcm(nspc)
end

n_steps_per_cycle_per_cb(integrator) =
n_steps_per_cycle_per_cb(integrator.callback, integrator.dt)

function n_steps_per_cycle_per_cb(cbs::ODE.CallbackSet, dt)
function n_steps_per_cycle_per_cb(cbs::SciMLBase.CallbackSet, dt)
return map(atmos_callbacks(cbs)) do cb
cbf = callback_frequency(cb)
if cbf isa EveryΔt
Expand Down
6 changes: 3 additions & 3 deletions src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import ClimaCore.Fields
import ClimaComms
import ClimaCore as CC
import ClimaCore.Spaces
import OrdinaryDiffEq as ODE
import SciMLBase
import ClimaAtmos.Parameters as CAP
import DiffEqCallbacks as DEQ
import DiffEqCallbacks as DECB
import ClimaCore: InputOutput
import Dates
using Insolation: instantaneous_zenith_angle
Expand Down Expand Up @@ -70,7 +70,7 @@ function turb_conv_affect_filter!(integrator)
# paying for an additional `∑tendencies!` call, which is required
# to support supplying a continuous representation of the
# solution.
ODE.u_modified!(integrator, false)
SciMLBase.u_modified!(integrator, false)
end

function rrtmgp_model_callback!(integrator)
Expand Down
Loading
Loading