Skip to content

Commit

Permalink
Move diagnostics preparation into callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Feb 7, 2024
1 parent cd1ca39 commit 34e63a0
Show file tree
Hide file tree
Showing 7 changed files with 386 additions and 344 deletions.
1 change: 1 addition & 0 deletions src/ClimaAtmos.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ include(joinpath("prognostic_equations", "dss.jl"))
include(joinpath("prognostic_equations", "limited_tendencies.jl"))

include(joinpath("callbacks", "callbacks.jl"))
include(joinpath("callbacks", "get_callbacks.jl"))

include(joinpath("diagnostics", "Diagnostics.jl"))
import .Diagnostics as CAD
Expand Down
272 changes: 272 additions & 0 deletions src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,275 @@ function reset_graceful_exit(output_dir)
ispath(output_dir) || mkpath(output_dir)
open(io -> print(io, 0), file, "w")
end

function get_diagnostics(parsed_args, atmos_model, spaces)

# We either get the diagnostics section in the YAML file, or we return an empty list
# (which will result in an empty list being created by the map below)
yaml_diagnostics = get(parsed_args, "diagnostics", [])

# ALLOWED_REDUCTIONS is the collection of reductions we support. The keys are the
# strings that have to be provided in the YAML file. The values are tuples with the
# function that has to be passed to reduction_time_func and the one that has to passed
# to pre_output_hook!

# We make "nothing" a string so that we can accept also the word "nothing", in addition
# to the absence of the value
#
# NOTE: Everything has to be lowercase in ALLOWED_REDUCTIONS (so that we can match
# "max" and "Max")
ALLOWED_REDUCTIONS = Dict(
"nothing" => (nothing, nothing), # nothing is: just dump the variable
"max" => (max, nothing),
"min" => (min, nothing),
"average" => ((+), CAD.average_pre_output_hook!),
)

hdf5_writer = CAD.HDF5Writer()

if !isnothing(parsed_args["netcdf_interpolation_num_points"])
num_netcdf_points =
tuple(parsed_args["netcdf_interpolation_num_points"]...)
else
# TODO: Once https://github.com/CliMA/ClimaCore.jl/pull/1567 is merged,
# dispatch over the Grid type
num_netcdf_points = (180, 90, 50)
end

netcdf_writer = CAD.NetCDFWriter(;
spaces,
num_points = num_netcdf_points,
interpolate_z_over_msl = parsed_args["netcdf_interpolate_z_over_msl"],
disable_vertical_interpolation = parsed_args["netcdf_output_at_levels"],
)
writers = (hdf5_writer, netcdf_writer)

# The default writer is HDF5
ALLOWED_WRITERS = Dict(
"nothing" => netcdf_writer,
"h5" => hdf5_writer,
"hdf5" => hdf5_writer,
"nc" => netcdf_writer,
"netcdf" => netcdf_writer,
)

diagnostics_ragged = map(yaml_diagnostics) do yaml_diag
short_names = yaml_diag["short_name"]
output_name = get(yaml_diag, "output_name", nothing)

if short_names isa Vector
isnothing(output_name) || error(
"Diagnostics: cannot have multiple short_names while specifying output_name",
)
else
short_names = [short_names]
end

ret_value = map(short_names) do short_name
# Return "nothing" if "reduction_time" is not in the YAML block
#
# We also normalize everything to lowercase, so that can accept "max" but
# also "Max"
reduction_time_yaml =
lowercase(get(yaml_diag, "reduction_time", "nothing"))

if !haskey(ALLOWED_REDUCTIONS, reduction_time_yaml)
error("reduction $reduction_time_yaml not implemented")
else
reduction_time_func, pre_output_hook! =
ALLOWED_REDUCTIONS[reduction_time_yaml]
end

writer_ext = lowercase(get(yaml_diag, "writer", "nothing"))

if !haskey(ALLOWED_WRITERS, writer_ext)
error("writer $writer_ext not implemented")
else
writer = ALLOWED_WRITERS[writer_ext]
end

haskey(yaml_diag, "period") ||
error("period keyword required for diagnostics")

period_seconds = time_to_seconds(yaml_diag["period"])

if isnothing(output_name)
output_short_name = CAD.descriptive_short_name(
CAD.get_diagnostic_variable(short_name),
period_seconds,
reduction_time_func,
pre_output_hook!,
)
end

if isnothing(reduction_time_func)
compute_every = period_seconds
else
compute_every = :timestep
end

return CAD.ScheduledDiagnosticTime(
variable = CAD.get_diagnostic_variable(short_name),
output_every = period_seconds,
compute_every = compute_every,
reduction_time_func = reduction_time_func,
pre_output_hook! = pre_output_hook!,
output_writer = writer,
output_short_name = output_short_name,
)
end
return ret_value
end

# Flatten the array of arrays of diagnostics
diagnostics = vcat(diagnostics_ragged...)

if parsed_args["output_default_diagnostics"]
t_end = time_to_seconds(parsed_args["t_end"])
return [
CAD.default_diagnostics(
atmos_model,
t_end;
output_writer = netcdf_writer,
)...,
diagnostics...,
],
writers
else
return collect(diagnostics), writers
end
end

function get_diagnostics_cb(
Y,
p,
t_start,
config,
sim_info,
atmos,
params,
spaces,
)
(; comms_ctx, parsed_args) = config

s = @timed_str begin
diagnostics, writers = get_diagnostics(parsed_args, atmos, spaces)
end
@info "initializing diagnostics: $s"

if length(diagnostics) > 0
@info "Computing diagnostics:"
else
return (;
diagnostic_callbacks = (),
diagnostics_functions = (),
writers = (),
)
end

for writer in writers
writer_str = nameof(typeof(writer))
diags_with_writer =
filter((x) -> getproperty(x, :output_writer) == writer, diagnostics)
diags_outputs = [
getproperty(diag, :output_short_name) for diag in diags_with_writer
]
@info "$writer_str: $diags_outputs"
end

# First, we convert all the ScheduledDiagnosticTime into ScheduledDiagnosticIteration,
# ensuring that there is consistency in the timestep and the periods and translating
# those periods that depended on the timestep
diagnostics_iterations =
[CAD.ScheduledDiagnosticIterations(d, sim_info.dt) for d in diagnostics]

# For diagnostics that perform reductions, the storage is used for the values computed
# at each call. Reductions also save the accumulated value in diagnostic_accumulators.
diagnostic_storage = Dict()
diagnostic_accumulators = Dict()
diagnostic_counters = Dict()

s = @timed_str begin
diagnostics_functions = CAD.get_callbacks_from_diagnostics(
diagnostics_iterations,
diagnostic_storage,
diagnostic_accumulators,
diagnostic_counters,
sim_info.output_dir,
)
end
@info "Prepared diagnostic callbacks: $s"

function orchestrate_diagnostics(integrator)
for d in diagnostics_functions
if d.cbf.n > 0 && integrator.step % d.cbf.n == 0
d.f!(integrator)
end
end
end

# It would be nice to just pass the callbacks to the integrator. However, this leads to
# a significant increase in compile time for reasons that are not known. For this
# reason, we only add one callback to the integrator, and this function takes care of
# executing the other callbacks. This single function is orchestrate_diagnostics

diagnostic_callbacks =
(call_every_n_steps(orchestrate_diagnostics, skip_first = true),)

s = @timed_str begin
for diag in diagnostics_iterations
variable = diag.variable
try
# The first time we call compute! we use its return value. All
# the subsequent times (in the callbacks), we will write the
# result in place
diagnostic_storage[diag] =
variable.compute!(nothing, Y, p, t_start)
diagnostic_counters[diag] = 1
# If it is not a reduction, call the output writer as well
if isnothing(diag.reduction_time_func)
CAD.write_field!(
diag.output_writer,
diagnostic_storage[diag],
diag,
Y,
p,
t_start,
sim_info.output_dir,
)
else
# Add to the accumulator

# We use similar + .= instead of copy because CUDA 5.2 does
# not supported nested wrappers with view(reshape(view))
# objects. See discussion in
# https://github.com/CliMA/ClimaAtmos.jl/pull/2579 and
# https://github.com/JuliaGPU/Adapt.jl/issues/21
diagnostic_accumulators[diag] =
similar(diagnostic_storage[diag])
diagnostic_accumulators[diag] .=
diagnostic_storage[diag]
end
catch e
error("Could not compute diagnostic $(variable.long_name): $e")
end
end
end
@info "Init diagnostics: $s"

if parsed_args["warn_allocations_diagnostics"]
for diag in diagnostics_iterations
# We write over the storage space we have already prepared (and filled) before
allocs = @allocated diag.variable.compute!(
diagnostic_storage[diag],
Y,
p,
t_start,
)
if allocs > 10 * 1024
@warn "Diagnostics $(diag.output_short_name) allocates $allocs bytes"
end
end
end
return (; diagnostic_callbacks, diagnostics_functions, writers)
end
94 changes: 94 additions & 0 deletions src/callbacks/get_callbacks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
function get_callbacks(Y, p, t_start, config, sim_info, atmos, params, spaces)
(; comms_ctx, parsed_args) = config
FT = eltype(params)
(; dt, output_dir) = sim_info

callbacks = ()
callbacks = (
callbacks...,
call_every_n_steps(
terminate!;
skip_first = true,
condition = (u, t, integrator) ->
maybe_graceful_exit(integrator),
),
)

dt_save_state_to_disk =
time_to_seconds(parsed_args["dt_save_state_to_disk"])
if !(dt_save_state_to_disk == Inf)
callbacks = (
callbacks...,
call_every_dt(
(integrator) ->
save_state_to_disk_func(integrator, output_dir),
dt_save_state_to_disk;
skip_first = sim_info.restart,
),
)
end

if is_distributed(comms_ctx)
callbacks = (
callbacks...,
call_every_n_steps(
gc_func,
parse(Int, get(ENV, "CLIMAATMOS_GC_NSTEPS", "1000")),
skip_first = true,
),
)
end

if parsed_args["check_conservation"]
callbacks = (
callbacks...,
call_every_n_steps(
flux_accumulation!;
skip_first = true,
call_at_end = true,
),
)
end

# get_diagnostics_cb returns an empty tuple if diagnostics are empty
(; diagnostic_callbacks, diagnostics_functions, writers) =
get_diagnostics_cb(
Y,
p,
t_start,
config,
sim_info,
atmos,
params,
spaces,
)
callbacks = (callbacks..., diagnostic_callbacks...)

if atmos.radiation_mode isa RRTMGPI.AbstractRRTMGPMode
# TODO: better if-else criteria?
dt_rad = if parsed_args["config"] == "column"
dt
else
FT(time_to_seconds(parsed_args["dt_rad"]))
end
callbacks =
(callbacks..., call_every_dt(rrtmgp_model_callback!, dt_rad))
end

dt_cf = FT(time_to_seconds(parsed_args["dt_cloud_fraction"]))
callbacks =
(callbacks..., call_every_dt(cloud_fraction_model_callback!, dt_cf))

if parsed_args["log_progress"] && !sim_info.restart
@info "Progress logging enabled."
callbacks = (
callbacks...,
call_every_n_steps(
(integrator) -> print_walltime_estimate(integrator);
skip_first = true,
),
)
end

return (; callbacks, diagnostics_functions, writers)
end
4 changes: 3 additions & 1 deletion src/diagnostics/diagnostic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,9 @@ function get_callbacks_from_diagnostics(
diag.output_writer,
storage[diag],
diag,
integrator,
integrator.u,
integrator.p,
integrator.t,
output_dir,
)

Expand Down
Loading

0 comments on commit 34e63a0

Please sign in to comment.