Skip to content

Commit

Permalink
Add callback that estimates walltime
Browse files Browse the repository at this point in the history
Increase allocation limits
  • Loading branch information
charleskawczynski committed Dec 14, 2023
1 parent 06a3da7 commit e570207
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 2 deletions.
12 changes: 11 additions & 1 deletion src/cache/cache.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
struct AtmosCache{
FT <: AbstractFloat,
FTE,
WTE,
SD,
AM,
NUM,
Expand Down Expand Up @@ -28,6 +30,12 @@ struct AtmosCache{
"""Timestep of the simulation (in seconds). This is also used by callbacks and tendencies"""
dt::FT

"""Timestep of the simulation (in seconds). This is also used by callbacks and tendencies"""
t_end::FTE

"""Walltime estimate"""
walltime_estimate::WTE

"""Start date (used for insolation)."""
start_date::SD

Expand Down Expand Up @@ -93,7 +101,7 @@ end

# The model also depends on f_plane_coriolis_frequency(params)
# This is a constant Coriolis frequency that is only used if space is flat
function build_cache(Y, atmos, params, surface_setup, dt, start_date)
function build_cache(Y, atmos, params, surface_setup, dt, t_end, start_date)
FT = eltype(params)

ᶜcoord = Fields.local_geometry_field(Y.c).coordinates
Expand Down Expand Up @@ -184,6 +192,8 @@ function build_cache(Y, atmos, params, surface_setup, dt, start_date)

args = (
dt,
t_end,
WallTimeEstimate(),
start_date,
atmos,
numerics,
Expand Down
57 changes: 57 additions & 0 deletions src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,63 @@ function save_restart_func(integrator, output_dir)
return nothing
end

Base.@kwdef mutable struct WallTimeEstimate
n_calls::Int = 0
n_next::Int = 1
t_last::Float64 = -1
∑Δt_wall::Float64 = 0
end
import Dates
function print_walltime_estimate(integrator)
(; walltime_estimate, dt, t_end) = integrator.p
wte = walltime_estimate
# 1st call initializes `t_last` (during integrator init)
# 2nd call computes `Δt` for the first time (but also includes compilation), so we skip
cond = wte.n_calls > 1
if cond
# We need to account for skipping cost of `Δt` when `n_calls == 1`:
factor = wte.n_calls == 2 ? 2 : 1
Δt = factor * (time() - wte.t_last)
else
Δt = Float64(0)
wte.n_next = wte.n_calls + 1
end
wte.∑Δt_wall += Δt
wte.t_last = time()

if wte.n_calls == wte.n_next && cond
t = integrator.t
n_steps_total = ceil(Int, t_end / dt)
n_steps = ceil(Int, t / dt)
t_wall_ave_per_step = wte.∑Δt_wall / n_steps
twps = t_wall_ave_per_step
_twps = time_and_units(t_wall_ave_per_step)
percent_complete = round(t / t_end * 100; digits = 1)
n_steps_remaining = n_steps_total - n_steps
erwt = twps * n_steps_remaining
_erwt = time_and_units(erwt)
etwt = time_and_units(twps * n_steps_total)
spent = time_and_units(wte.∑Δt_wall)
t_sim = time_and_units(Float64(t))
sypd = round(
simulated_years_per_day(
EfficiencyStats((zero(t), t), wte.∑Δt_wall),
);
digits = 3,
)
estimated_finish_date =
Dates.now() + compound_period(erwt, Dates.Second)
@info "Progress" simulation_time = t_sim n_steps_completed = n_steps wall_time_per_step =
_twps wall_time_total = etwt wall_time_remaining = erwt wall_time_spent =
spent percent_complete = "$percent_complete%" sypd = sypd date_now =
Dates.now() estimated_finish_date = estimated_finish_date
wte.n_next *= 2 # doubling factor (to reduce log noise)
end
wte.n_calls += 1

return nothing
end

function gc_func(integrator)
num_pre = Base.gc_num()
alloc_since_last = (num_pre.allocd + num_pre.deferred_alloc) / 2^20
Expand Down
8 changes: 7 additions & 1 deletion src/solver/type_getters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,12 @@ function get_callbacks(parsed_args, sim_info, atmos, params, comms_ctx)
FT = eltype(params)
(; dt, output_dir) = sim_info

callbacks = ()
callbacks = (
call_every_n_steps(
(integrator) -> print_walltime_estimate(integrator);
skip_first = true,
),
)
dt_save_to_disk = time_to_seconds(parsed_args["dt_save_to_disk"])
if !(dt_save_to_disk == Inf)
callbacks = (
Expand Down Expand Up @@ -779,6 +784,7 @@ function get_simulation(config::AtmosConfig)
params,
surface_setup,
sim_info.dt,
sim_info.t_end,
sim_info.start_date,
)
end
Expand Down
42 changes: 42 additions & 0 deletions src/utils/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,48 @@ function prettytime(t)
return @sprintf("%.3f %s", value, units)
end

import Dates
using Dates

time_per_time(::Type{P}, ::Type{P}) where {P} = 1
function define_time_per_times(periods)
for i in eachindex(periods)
T, n = periods[i]
N = Int64(1)
for j in (i - 1):-1:firstindex(periods) # less-precise periods
Tc, nc = periods[j]
N *= nc
@eval time_per_time(::Type{$T}, ::Type{$Tc}) = $N
end
end
end

# # From Dates
define_time_per_times([
(:Week, 7),
(:Day, 24),
(:Hour, 60),
(:Minute, 60),
(:Second, 1000),
(:Millisecond, 1000),
(:Microsecond, 1000),
(:Nanosecond, 1),
])

function time_and_units(x)::String
return trunc_time(
string(Dates.canonicalize(compound_period(x, Dates.Second))),
)
end

function compound_period(x::Real, ::Type{T}) where {T <: Period}
nf = time_per_time(Nanosecond, T)
return Dates.canonicalize(Dates.CompoundPeriod(Nanosecond(ceil(x * nf))))
end

trunc_time(s::String) = count(',', s) > 1 ? join(split(s, ",")[1:2], ",") : s


function prettymemory(b)
if b < 1024
return string(b, " bytes")
Expand Down
1 change: 1 addition & 0 deletions test/coupler_compatibility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ const T2 = 290
@. sfc_setup = (surface_state,)
p_overwritten = CA.AtmosCache(
p.dt,
WallTimeEstimate(),
p.start_date,
p.atmos,
p.numerics,
Expand Down

0 comments on commit e570207

Please sign in to comment.