Skip to content

Commit

Permalink
Add callback that estimates walltime
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Dec 13, 2023
1 parent b26b78b commit b1acc78
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 2 deletions.
12 changes: 11 additions & 1 deletion src/cache/cache.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
struct AtmosCache{
FT <: AbstractFloat,
FTE,
WTE,
SD,
AM,
NUM,
Expand Down Expand Up @@ -28,6 +30,12 @@ struct AtmosCache{
"""Timestep of the simulation (in seconds). This is also used by callbacks and tendencies"""
dt::FT

"""Timestep of the simulation (in seconds). This is also used by callbacks and tendencies"""
t_end::FTE

"""Walltime estimate"""
walltime_estimate::WTE

"""Start date (used for insolation)."""
start_date::SD

Expand Down Expand Up @@ -92,7 +100,7 @@ end

# The model also depends on f_plane_coriolis_frequency(params)
# This is a constant Coriolis frequency that is only used if space is flat
function build_cache(Y, atmos, params, surface_setup, dt, start_date)
function build_cache(Y, atmos, params, surface_setup, dt, t_end, start_date)
FT = eltype(params)

ᶜcoord = Fields.local_geometry_field(Y.c).coordinates
Expand Down Expand Up @@ -183,6 +191,8 @@ function build_cache(Y, atmos, params, surface_setup, dt, start_date)

args = (
dt,
t_end,
WallTimeEstimate(),
start_date,
atmos,
numerics,
Expand Down
45 changes: 45 additions & 0 deletions src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,51 @@ function save_restart_func(integrator, output_dir)
return nothing
end

Base.@kwdef mutable struct WallTimeEstimate
n_calls::Int = 0
n_next::Int = 1
t_last::Float64 = -1
∑Δt_wall::Float64 = 0
end
import Dates
function print_walltime_estimate(integrator)
(; walltime_estimate, dt, t_end) = integrator.p
t = integrator.t
wte = walltime_estimate
if wte.n_calls > 1 # 1 to avoid call during initialization
n_steps = ceil(Int, t_end / dt)
t_wall_ave_per_step = wte.∑Δt_wall / n_steps
twps = t_wall_ave_per_step
percent_complete = round(t / t_end * 100; digits = 1)
n_steps_remaining = n_steps - wte.n_calls
erwt = twps * n_steps_remaining
etwt = twps * n_steps
spent = wte.∑Δt_wall
pwu =
x -> trunc_time(
string(
Dates.canonicalize(Dates.CompoundPeriod(x, Dates.Second)),
),
)
sypd =
simulated_years_per_day(EfficiencyStats((zero(t), t), wte.∑Δt_wall))
if wte.n_calls == wte.n_next
@info "Time estimates" per_step = pwu(twps) total = pwu(etwt) remaining =
pwu(erwt) spent = pwu(spent) percent_complete = percent_complete sypd =
sypd
wte.n_next *= 2 # doubling factor (to reduce log noise)
end
Δt = time() - wte.t_last
else
Δt = Float64(0)
wte.n_next = wte.n_calls + 1
end
wte.∑Δt_wall += Δt
wte.n_calls += 1
wte.t_last = time()
return nothing
end

function gc_func(integrator)
full = true # whether to do a full GC
num_pre = Base.gc_num()
Expand Down
8 changes: 7 additions & 1 deletion src/solver/type_getters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,12 @@ function get_callbacks(parsed_args, sim_info, atmos, params, comms_ctx)
FT = eltype(params)
(; dt, output_dir) = sim_info

callbacks = ()
callbacks = (
call_every_n_steps(
(integrator) -> print_walltime_estimate(integrator);
skip_first = true,
),
)
dt_save_to_disk = time_to_seconds(parsed_args["dt_save_to_disk"])
if !(dt_save_to_disk == Inf)
callbacks = (
Expand Down Expand Up @@ -775,6 +780,7 @@ function get_simulation(config::AtmosConfig)
params,
surface_setup,
sim_info.dt,
sim_info.t_end,
sim_info.start_date,
)
end
Expand Down
36 changes: 36 additions & 0 deletions src/utils/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,42 @@ function prettytime(t)
return @sprintf("%.3f %s", value, units)
end

import Dates
using Dates

time_per_time(::Type{P}, ::Type{P}) where {P} = 1
function define_time_per_times(periods)
for i in eachindex(periods)
T, n = periods[i]
N = Int64(1)
for j in (i - 1):-1:firstindex(periods) # less-precise periods
Tc, nc = periods[j]
N *= nc
@eval time_per_time(::Type{$T}, ::Type{$Tc}) = $N
end
end
end

# # From Dates
define_time_per_times([
(:Week, 7),
(:Day, 24),
(:Hour, 60),
(:Minute, 60),
(:Second, 1000),
(:Millisecond, 1000),
(:Microsecond, 1000),
(:Nanosecond, 1),
])

function Dates.CompoundPeriod(x::Real, ::Type{T}) where {T <: Period}
nf = time_per_time(Nanosecond, T)
return Dates.canonicalize(Dates.CompoundPeriod(Nanosecond(ceil(x * nf))))
end

trunc_time(s::String) = count(',', s) > 1 ? join(split(s, ",")[1:2], ",") : s


function prettymemory(b)
if b < 1024
return string(b, " bytes")
Expand Down
1 change: 1 addition & 0 deletions test/coupler_compatibility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ const T2 = 290
@. sfc_setup = (surface_state,)
p_overwritten = CA.AtmosCache(
p.dt,
WallTimeEstimate(),
p.start_date,
p.atmos,
p.numerics,
Expand Down

0 comments on commit b1acc78

Please sign in to comment.