Skip to content

Commit

Permalink
Remove WTE from cache
Browse files Browse the repository at this point in the history
This commit removes the `WallTimeEstimate` from the cache and moves it
to an isolated place. In the process, I refactored the struct to split
reporting with updating, so that reporting can be done with any
frequency/schedule desired using the same Schedule infrastructure used
by other functions/diagnostics as well.
  • Loading branch information
Sbozzolo committed Dec 4, 2024
1 parent 33a2af5 commit 6d30ca6
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 114 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ ClimaCore = "0.14.12"
ClimaDiagnostics = "0.2.4"
ClimaParams = "0.10.12"
ClimaTimeSteppers = "0.7.33"
ClimaUtilities = "0.1.14"
ClimaUtilities = "0.1.20"
CloudMicrophysics = "0.22.3"
Dates = "1"
DiffEqBase = "6.145"
Expand Down
12 changes: 1 addition & 11 deletions src/cache/cache.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
struct AtmosCache{
FT <: AbstractFloat,
FTE,
WTE,
SD,
AM,
NUM,
Expand Down Expand Up @@ -30,12 +28,6 @@ struct AtmosCache{
"""Timestep of the simulation (in seconds). This is also used by callbacks and tendencies"""
dt::FT

"""End time of the simulation (in seconds). This used by callbacks"""
t_end::FTE

"""Walltime estimate"""
walltime_estimate::WTE

"""Start date (used for insolation and for data files)."""
start_date::SD

Expand Down Expand Up @@ -106,7 +98,7 @@ end
# The model also depends on f_plane_coriolis_frequency(params)
# This is a constant Coriolis frequency that is only used if space is flat
function build_cache(Y, atmos, params, surface_setup, sim_info, aerosol_names)
(; dt, t_end, start_date, output_dir) = sim_info
(; dt, start_date, output_dir) = sim_info
FT = eltype(params)

ᶜcoord = Fields.local_geometry_field(Y.c).coordinates
Expand Down Expand Up @@ -188,8 +180,6 @@ function build_cache(Y, atmos, params, surface_setup, sim_info, aerosol_names)

args = (
dt,
t_end,
WallTimeEstimate(),
start_date,
atmos,
numerics,
Expand Down
44 changes: 44 additions & 0 deletions src/callbacks/callback_helpers.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import SciMLBase

import ClimaDiagnostics.Schedules: AbstractSchedule

#####
##### Callback helpers
#####
Expand Down Expand Up @@ -109,3 +112,44 @@ end

n_steps_per_cycle_per_cb_diagnostic(cbs) =
[callback_frequency(cb).n for cb in cbs if callback_frequency(cb).n > 0]

import ClimaDiagnostics.Schedules: AbstractSchedule

import Dates

"""
CappedGeometricSeriesSchedule(max_steps)
True every 2^N iterations or every `max_steps`.
This is useful to have an exponential ramp up of something that saturates to a constant
frequency. (For instance, reporting something more frequently at the beginning of the
simulation, and less frequency later)
"""
struct CappedGeometricSeriesSchedule <: AbstractSchedule
"""GeometricSeriesSchedule(integrator) is true every 2^N iterations or every max_steps"""
max_steps::Int
"""Last step that this returned true"""
step_last::Base.RefValue{Int}

function CappedGeometricSeriesSchedule(max_steps; step_last = Ref(0))
return new(max_steps, step_last)
end
end

"""
CappedGeometricSeriesSchedule(integrator)
Returns true if `integrator.step >= last_step + max_steps`, or when `integrator.step` is a
power of 2. `last_step` is the last step this function was true and `max_step` is maximum
allowed interval as defined in the schedule.
"""
function (schedule::CappedGeometricSeriesSchedule)(integrator)::Bool
if isinteger(log2(integrator.step)) ||
integrator.step > schedule.step_last[] + schedule.max_steps
schedule.step_last[] = integrator.step
return true
else
return false
end
end
93 changes: 1 addition & 92 deletions src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ using Insolation: instantaneous_zenith_angle
import ClimaCore.Fields: ColumnField

import ClimaUtilities.TimeVaryingInputs: evaluate!
import ClimaUtilities.OnlineLogging: WallTimeInfo, report_walltime

include("callback_helpers.jl")

Expand Down Expand Up @@ -363,98 +364,6 @@ NVTX.@annotate function save_state_to_disk_func(integrator, output_dir)
return nothing
end

Base.@kwdef mutable struct WallTimeEstimate
"""Number of calls to the callback"""
n_calls::Int = 0
"""Int indicating next time the callback will print to the log"""
n_next::Int = 1
"""Wall time of previous call to update `WallTimeEstimate`"""
t_wall_last::Float64 = -1
"""Sum of elapsed walltime over calls to `step!`"""
∑Δt_wall::Float64 = 0
"""Fixed increment to increase n_next by after 5% completion"""
n_fixed_increment::Float64 = -1
end
import Dates
function print_walltime_estimate(integrator)
(; walltime_estimate, dt, t_end) = integrator.p
t_start = integrator.sol.prob.tspan[1]
wte = walltime_estimate

# Notes on `ready_to_report`
# - The very first call (when `n_calls == 0`), there's no elapsed
# times to report (and this is called during initialization,
# before `step!` has been called).
# - The second call (`n_calls == 1`) is after `step!` is called
# for the first time, but we don't want to report this since it
# includes compilation time.
# - Calls after that (`n_calls > 1`) exclude compilation and provide
# the best wall time estimates

ready_to_report = wte.n_calls > 1
if ready_to_report
# We need to account for skipping cost of `Δt_wall` when `n_calls == 1`:
factor = wte.n_calls == 2 ? 2 : 1
Δt_wall = factor * (time() - wte.t_wall_last)
else
wte.n_calls == 1 && @info "Progress: Completed first step"
Δt_wall = Float64(0)
wte.n_next = wte.n_calls + 1
end
wte.∑Δt_wall += Δt_wall
wte.t_wall_last = time()

if wte.n_calls == wte.n_next && ready_to_report
t = integrator.t
n_steps_total = ceil(Int, (t_end - t_start) / dt)
n_steps = ceil(Int, (t - t_start) / dt)
wall_time_ave_per_step = wte.∑Δt_wall / n_steps
wall_time_ave_per_step_str = time_and_units_str(wall_time_ave_per_step)
percent_complete = round((t - t_start) / t_end * 100; digits = 1)
n_steps_remaining = n_steps_total - n_steps
wall_time_remaining = wall_time_ave_per_step * n_steps_remaining
wall_time_remaining_str = time_and_units_str(wall_time_remaining)
wall_time_total =
time_and_units_str(wall_time_ave_per_step * n_steps_total)
wall_time_spent = time_and_units_str(wte.∑Δt_wall)
simulation_time = time_and_units_str(Float64(t))
es = EfficiencyStats((t_start, t), wte.∑Δt_wall)
_sypd = simulated_years_per_day(es)
_sypd_str = string(round(_sypd; digits = 3))
sypd = _sypd_str * if _sypd < 0.01
sdpd = round(_sypd * 365, digits = 3)
" (sdpd = $sdpd)"
else
""
end
estimated_finish_date =
Dates.now() + compound_period(wall_time_remaining, Dates.Second)
@info "Progress" simulation_time = simulation_time n_steps_completed =
n_steps wall_time_per_step = wall_time_ave_per_step_str wall_time_total =
wall_time_total wall_time_remaining = wall_time_remaining_str wall_time_spent =
wall_time_spent percent_complete = "$percent_complete%" sypd = sypd date_now =
Dates.now() estimated_finish_date = estimated_finish_date

# the first fixed increment is equivalent to
# doubling (which puts us at 10%), so we check
# if we're below 5%.
if percent_complete < 5
# doubling factor (to reduce log noise)
wte.n_next *= 2
else
if wte.n_fixed_increment == -1
wte.n_fixed_increment = wte.n_next
end
# increase by fixed increment after 10%
# completion to maintain logs after 50%.
wte.n_next += wte.n_fixed_increment
end
end
wte.n_calls += 1

return nothing
end

function gc_func(integrator)
num_pre = Base.gc_num()
alloc_since_last = (num_pre.allocd + num_pre.deferred_alloc) / 2^20
Expand Down
19 changes: 11 additions & 8 deletions src/callbacks/get_callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,17 @@ function get_callbacks(config, sim_info, atmos, params, Y, p, t_start)

callbacks = ()
if parsed_args["log_progress"]
@info "Progress logging enabled."
callbacks = (
callbacks...,
call_every_n_steps(
(integrator) -> print_walltime_estimate(integrator);
skip_first = true,
),
)
@info "Progress logging enabled"
walltime_info = WallTimeInfo()
tot_steps = ceil(Int, (sim_info.t_end - t_start) / dt)
five_percent_steps = ceil(Int, 0.05 * tot_steps)
cond = let schedule = CappedGeometricSeriesSchedule(five_percent_steps)
(u, t, integrator) -> schedule(integrator)
end
affect! = let wt = walltime_info
(integrator) -> report_walltime(wt, integrator)
end
callbacks = (callbacks..., SciMLBase.DiscreteCallback(cond, affect!))
end
check_nan_every = parsed_args["check_nan_every"]
if check_nan_every > 0
Expand Down
2 changes: 0 additions & 2 deletions test/coupler_compatibility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ const T2 = 290
@. sfc_setup = (surface_state,)
p_overwritten = CA.AtmosCache(
p.dt,
simulation.t_end,
CA.WallTimeEstimate(),
p.start_date,
p.atmos,
p.numerics,
Expand Down

0 comments on commit 6d30ca6

Please sign in to comment.