From 34e63a0c9a2499698a1ad29d162b8057c6ceca6d Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Wed, 7 Feb 2024 14:54:46 -0500 Subject: [PATCH] Move diagnostics preparation into callbacks --- src/ClimaAtmos.jl | 1 + src/callbacks/callbacks.jl | 272 ++++++++++++++++++++++++ src/callbacks/get_callbacks.jl | 94 +++++++++ src/diagnostics/diagnostic.jl | 4 +- src/diagnostics/hdf5_writer.jl | 4 +- src/diagnostics/netcdf_writer.jl | 8 +- src/solver/type_getters.jl | 347 +------------------------------ 7 files changed, 386 insertions(+), 344 deletions(-) create mode 100644 src/callbacks/get_callbacks.jl diff --git a/src/ClimaAtmos.jl b/src/ClimaAtmos.jl index ff3bd638586..a98f629d8fa 100644 --- a/src/ClimaAtmos.jl +++ b/src/ClimaAtmos.jl @@ -108,6 +108,7 @@ include(joinpath("prognostic_equations", "dss.jl")) include(joinpath("prognostic_equations", "limited_tendencies.jl")) include(joinpath("callbacks", "callbacks.jl")) +include(joinpath("callbacks", "get_callbacks.jl")) include(joinpath("diagnostics", "Diagnostics.jl")) import .Diagnostics as CAD diff --git a/src/callbacks/callbacks.jl b/src/callbacks/callbacks.jl index 3963489b2ea..9b4ab25faba 100644 --- a/src/callbacks/callbacks.jl +++ b/src/callbacks/callbacks.jl @@ -347,3 +347,275 @@ function reset_graceful_exit(output_dir) ispath(output_dir) || mkpath(output_dir) open(io -> print(io, 0), file, "w") end + +function get_diagnostics(parsed_args, atmos_model, spaces) + + # We either get the diagnostics section in the YAML file, or we return an empty list + # (which will result in an empty list being created by the map below) + yaml_diagnostics = get(parsed_args, "diagnostics", []) + + # ALLOWED_REDUCTIONS is the collection of reductions we support. The keys are the + # strings that have to be provided in the YAML file. The values are tuples with the + # function that has to be passed to reduction_time_func and the one that has to passed + # to pre_output_hook! + + # We make "nothing" a string so that we can accept also the word "nothing", in addition + # to the absence of the value + # + # NOTE: Everything has to be lowercase in ALLOWED_REDUCTIONS (so that we can match + # "max" and "Max") + ALLOWED_REDUCTIONS = Dict( + "nothing" => (nothing, nothing), # nothing is: just dump the variable + "max" => (max, nothing), + "min" => (min, nothing), + "average" => ((+), CAD.average_pre_output_hook!), + ) + + hdf5_writer = CAD.HDF5Writer() + + if !isnothing(parsed_args["netcdf_interpolation_num_points"]) + num_netcdf_points = + tuple(parsed_args["netcdf_interpolation_num_points"]...) + else + # TODO: Once https://github.com/CliMA/ClimaCore.jl/pull/1567 is merged, + # dispatch over the Grid type + num_netcdf_points = (180, 90, 50) + end + + netcdf_writer = CAD.NetCDFWriter(; + spaces, + num_points = num_netcdf_points, + interpolate_z_over_msl = parsed_args["netcdf_interpolate_z_over_msl"], + disable_vertical_interpolation = parsed_args["netcdf_output_at_levels"], + ) + writers = (hdf5_writer, netcdf_writer) + + # The default writer is HDF5 + ALLOWED_WRITERS = Dict( + "nothing" => netcdf_writer, + "h5" => hdf5_writer, + "hdf5" => hdf5_writer, + "nc" => netcdf_writer, + "netcdf" => netcdf_writer, + ) + + diagnostics_ragged = map(yaml_diagnostics) do yaml_diag + short_names = yaml_diag["short_name"] + output_name = get(yaml_diag, "output_name", nothing) + + if short_names isa Vector + isnothing(output_name) || error( + "Diagnostics: cannot have multiple short_names while specifying output_name", + ) + else + short_names = [short_names] + end + + ret_value = map(short_names) do short_name + # Return "nothing" if "reduction_time" is not in the YAML block + # + # We also normalize everything to lowercase, so that can accept "max" but + # also "Max" + reduction_time_yaml = + lowercase(get(yaml_diag, "reduction_time", "nothing")) + + if !haskey(ALLOWED_REDUCTIONS, reduction_time_yaml) + error("reduction $reduction_time_yaml not implemented") + else + reduction_time_func, pre_output_hook! = + ALLOWED_REDUCTIONS[reduction_time_yaml] + end + + writer_ext = lowercase(get(yaml_diag, "writer", "nothing")) + + if !haskey(ALLOWED_WRITERS, writer_ext) + error("writer $writer_ext not implemented") + else + writer = ALLOWED_WRITERS[writer_ext] + end + + haskey(yaml_diag, "period") || + error("period keyword required for diagnostics") + + period_seconds = time_to_seconds(yaml_diag["period"]) + + if isnothing(output_name) + output_short_name = CAD.descriptive_short_name( + CAD.get_diagnostic_variable(short_name), + period_seconds, + reduction_time_func, + pre_output_hook!, + ) + end + + if isnothing(reduction_time_func) + compute_every = period_seconds + else + compute_every = :timestep + end + + return CAD.ScheduledDiagnosticTime( + variable = CAD.get_diagnostic_variable(short_name), + output_every = period_seconds, + compute_every = compute_every, + reduction_time_func = reduction_time_func, + pre_output_hook! = pre_output_hook!, + output_writer = writer, + output_short_name = output_short_name, + ) + end + return ret_value + end + + # Flatten the array of arrays of diagnostics + diagnostics = vcat(diagnostics_ragged...) + + if parsed_args["output_default_diagnostics"] + t_end = time_to_seconds(parsed_args["t_end"]) + return [ + CAD.default_diagnostics( + atmos_model, + t_end; + output_writer = netcdf_writer, + )..., + diagnostics..., + ], + writers + else + return collect(diagnostics), writers + end +end + +function get_diagnostics_cb( + Y, + p, + t_start, + config, + sim_info, + atmos, + params, + spaces, +) + (; comms_ctx, parsed_args) = config + + s = @timed_str begin + diagnostics, writers = get_diagnostics(parsed_args, atmos, spaces) + end + @info "initializing diagnostics: $s" + + if length(diagnostics) > 0 + @info "Computing diagnostics:" + else + return (; + diagnostic_callbacks = (), + diagnostics_functions = (), + writers = (), + ) + end + + for writer in writers + writer_str = nameof(typeof(writer)) + diags_with_writer = + filter((x) -> getproperty(x, :output_writer) == writer, diagnostics) + diags_outputs = [ + getproperty(diag, :output_short_name) for diag in diags_with_writer + ] + @info "$writer_str: $diags_outputs" + end + + # First, we convert all the ScheduledDiagnosticTime into ScheduledDiagnosticIteration, + # ensuring that there is consistency in the timestep and the periods and translating + # those periods that depended on the timestep + diagnostics_iterations = + [CAD.ScheduledDiagnosticIterations(d, sim_info.dt) for d in diagnostics] + + # For diagnostics that perform reductions, the storage is used for the values computed + # at each call. Reductions also save the accumulated value in diagnostic_accumulators. + diagnostic_storage = Dict() + diagnostic_accumulators = Dict() + diagnostic_counters = Dict() + + s = @timed_str begin + diagnostics_functions = CAD.get_callbacks_from_diagnostics( + diagnostics_iterations, + diagnostic_storage, + diagnostic_accumulators, + diagnostic_counters, + sim_info.output_dir, + ) + end + @info "Prepared diagnostic callbacks: $s" + + function orchestrate_diagnostics(integrator) + for d in diagnostics_functions + if d.cbf.n > 0 && integrator.step % d.cbf.n == 0 + d.f!(integrator) + end + end + end + + # It would be nice to just pass the callbacks to the integrator. However, this leads to + # a significant increase in compile time for reasons that are not known. For this + # reason, we only add one callback to the integrator, and this function takes care of + # executing the other callbacks. This single function is orchestrate_diagnostics + + diagnostic_callbacks = + (call_every_n_steps(orchestrate_diagnostics, skip_first = true),) + + s = @timed_str begin + for diag in diagnostics_iterations + variable = diag.variable + try + # The first time we call compute! we use its return value. All + # the subsequent times (in the callbacks), we will write the + # result in place + diagnostic_storage[diag] = + variable.compute!(nothing, Y, p, t_start) + diagnostic_counters[diag] = 1 + # If it is not a reduction, call the output writer as well + if isnothing(diag.reduction_time_func) + CAD.write_field!( + diag.output_writer, + diagnostic_storage[diag], + diag, + Y, + p, + t_start, + sim_info.output_dir, + ) + else + # Add to the accumulator + + # We use similar + .= instead of copy because CUDA 5.2 does + # not supported nested wrappers with view(reshape(view)) + # objects. See discussion in + # https://github.com/CliMA/ClimaAtmos.jl/pull/2579 and + # https://github.com/JuliaGPU/Adapt.jl/issues/21 + diagnostic_accumulators[diag] = + similar(diagnostic_storage[diag]) + diagnostic_accumulators[diag] .= + diagnostic_storage[diag] + end + catch e + error("Could not compute diagnostic $(variable.long_name): $e") + end + end + end + @info "Init diagnostics: $s" + + if parsed_args["warn_allocations_diagnostics"] + for diag in diagnostics_iterations + # We write over the storage space we have already prepared (and filled) before + allocs = @allocated diag.variable.compute!( + diagnostic_storage[diag], + Y, + p, + t_start, + ) + if allocs > 10 * 1024 + @warn "Diagnostics $(diag.output_short_name) allocates $allocs bytes" + end + end + end + return (; diagnostic_callbacks, diagnostics_functions, writers) +end diff --git a/src/callbacks/get_callbacks.jl b/src/callbacks/get_callbacks.jl new file mode 100644 index 00000000000..817c04f9d17 --- /dev/null +++ b/src/callbacks/get_callbacks.jl @@ -0,0 +1,94 @@ +function get_callbacks(Y, p, t_start, config, sim_info, atmos, params, spaces) + (; comms_ctx, parsed_args) = config + FT = eltype(params) + (; dt, output_dir) = sim_info + + callbacks = () + callbacks = ( + callbacks..., + call_every_n_steps( + terminate!; + skip_first = true, + condition = (u, t, integrator) -> + maybe_graceful_exit(integrator), + ), + ) + + dt_save_state_to_disk = + time_to_seconds(parsed_args["dt_save_state_to_disk"]) + if !(dt_save_state_to_disk == Inf) + callbacks = ( + callbacks..., + call_every_dt( + (integrator) -> + save_state_to_disk_func(integrator, output_dir), + dt_save_state_to_disk; + skip_first = sim_info.restart, + ), + ) + end + + if is_distributed(comms_ctx) + callbacks = ( + callbacks..., + call_every_n_steps( + gc_func, + parse(Int, get(ENV, "CLIMAATMOS_GC_NSTEPS", "1000")), + skip_first = true, + ), + ) + end + + if parsed_args["check_conservation"] + callbacks = ( + callbacks..., + call_every_n_steps( + flux_accumulation!; + skip_first = true, + call_at_end = true, + ), + ) + end + + # get_diagnostics_cb returns an empty tuple if diagnostics are empty + (; diagnostic_callbacks, diagnostics_functions, writers) = + get_diagnostics_cb( + Y, + p, + t_start, + config, + sim_info, + atmos, + params, + spaces, + ) + callbacks = (callbacks..., diagnostic_callbacks...) + + if atmos.radiation_mode isa RRTMGPI.AbstractRRTMGPMode + # TODO: better if-else criteria? + dt_rad = if parsed_args["config"] == "column" + dt + else + FT(time_to_seconds(parsed_args["dt_rad"])) + end + callbacks = + (callbacks..., call_every_dt(rrtmgp_model_callback!, dt_rad)) + end + + dt_cf = FT(time_to_seconds(parsed_args["dt_cloud_fraction"])) + callbacks = + (callbacks..., call_every_dt(cloud_fraction_model_callback!, dt_cf)) + + if parsed_args["log_progress"] && !sim_info.restart + @info "Progress logging enabled." + callbacks = ( + callbacks..., + call_every_n_steps( + (integrator) -> print_walltime_estimate(integrator); + skip_first = true, + ), + ) + end + + return (; callbacks, diagnostics_functions, writers) +end diff --git a/src/diagnostics/diagnostic.jl b/src/diagnostics/diagnostic.jl index 619241b56fd..ff3fe9601ac 100644 --- a/src/diagnostics/diagnostic.jl +++ b/src/diagnostics/diagnostic.jl @@ -733,7 +733,9 @@ function get_callbacks_from_diagnostics( diag.output_writer, storage[diag], diag, - integrator, + integrator.u, + integrator.p, + integrator.t, output_dir, ) diff --git a/src/diagnostics/hdf5_writer.jl b/src/diagnostics/hdf5_writer.jl index 2d128d8aa2b..be1136348f0 100644 --- a/src/diagnostics/hdf5_writer.jl +++ b/src/diagnostics/hdf5_writer.jl @@ -41,7 +41,9 @@ function write_field!( writer::HDF5Writer, field, diagnostic, - integrator, + u, + p, + t, output_dir, ) var = diagnostic.variable diff --git a/src/diagnostics/netcdf_writer.jl b/src/diagnostics/netcdf_writer.jl index 695ce1edd5c..ebb94f16964 100644 --- a/src/diagnostics/netcdf_writer.jl +++ b/src/diagnostics/netcdf_writer.jl @@ -559,7 +559,9 @@ function write_field!( writer::NetCDFWriter, field, diagnostic, - integrator, + u, + p, + t, output_dir, ) @@ -667,7 +669,7 @@ function write_field!( v.attrib["long_name"] = diagnostic.output_long_name v.attrib["units"] = var.units v.attrib["comments"] = var.comments - v.attrib["start_date"] = string(integrator.p.start_date) + v.attrib["start_date"] = string(p.start_date) temporal_size = 0 end @@ -676,7 +678,7 @@ function write_field!( # position ever if we are writing the file for the first time) time_index = temporal_size + 1 - nc["time"][time_index] = integrator.t + nc["time"][time_index] = t # TODO: It would be nice to find a cleaner way to do this if length(dim_names) == 3 diff --git a/src/solver/type_getters.jl b/src/solver/type_getters.jl index 3ac5e8a694f..ea9570765a5 100644 --- a/src/solver/type_getters.jl +++ b/src/solver/type_getters.jl @@ -436,85 +436,6 @@ thermo_state_type(::NonEquilMoistModel, ::Type{FT}) where {FT} = TD.PhaseNonEquil{FT} -function get_callbacks(parsed_args, sim_info, atmos, params, comms_ctx) - FT = eltype(params) - (; dt, output_dir) = sim_info - - callbacks = () - if parsed_args["log_progress"] && !sim_info.restart - @info "Progress logging enabled." - callbacks = ( - callbacks..., - call_every_n_steps( - (integrator) -> print_walltime_estimate(integrator); - skip_first = true, - ), - ) - end - callbacks = ( - callbacks..., - call_every_n_steps( - terminate!; - skip_first = true, - condition = (u, t, integrator) -> - maybe_graceful_exit(integrator), - ), - ) - - dt_save_state_to_disk = - time_to_seconds(parsed_args["dt_save_state_to_disk"]) - if !(dt_save_state_to_disk == Inf) - callbacks = ( - callbacks..., - call_every_dt( - (integrator) -> - save_state_to_disk_func(integrator, output_dir), - dt_save_state_to_disk; - skip_first = sim_info.restart, - ), - ) - end - - if is_distributed(comms_ctx) - callbacks = ( - callbacks..., - call_every_n_steps( - gc_func, - parse(Int, get(ENV, "CLIMAATMOS_GC_NSTEPS", "1000")), - skip_first = true, - ), - ) - end - - if parsed_args["check_conservation"] - callbacks = ( - callbacks..., - call_every_n_steps( - flux_accumulation!; - skip_first = true, - call_at_end = true, - ), - ) - end - - if atmos.radiation_mode isa RRTMGPI.AbstractRRTMGPMode - # TODO: better if-else criteria? - dt_rad = if parsed_args["config"] == "column" - dt - else - FT(time_to_seconds(parsed_args["dt_rad"])) - end - callbacks = - (callbacks..., call_every_dt(rrtmgp_model_callback!, dt_rad)) - end - - dt_cf = FT(time_to_seconds(parsed_args["dt_cloud_fraction"])) - callbacks = - (callbacks..., call_every_dt(cloud_fraction_model_callback!, dt_cf)) - - return callbacks -end - function get_sim_info(config::AtmosConfig) (; parsed_args) = config FT = eltype(config) @@ -548,144 +469,6 @@ function get_sim_info(config::AtmosConfig) return sim end -function get_diagnostics(parsed_args, atmos_model, spaces) - - # We either get the diagnostics section in the YAML file, or we return an empty list - # (which will result in an empty list being created by the map below) - yaml_diagnostics = get(parsed_args, "diagnostics", []) - - # ALLOWED_REDUCTIONS is the collection of reductions we support. The keys are the - # strings that have to be provided in the YAML file. The values are tuples with the - # function that has to be passed to reduction_time_func and the one that has to passed - # to pre_output_hook! - - # We make "nothing" a string so that we can accept also the word "nothing", in addition - # to the absence of the value - # - # NOTE: Everything has to be lowercase in ALLOWED_REDUCTIONS (so that we can match - # "max" and "Max") - ALLOWED_REDUCTIONS = Dict( - "nothing" => (nothing, nothing), # nothing is: just dump the variable - "max" => (max, nothing), - "min" => (min, nothing), - "average" => ((+), CAD.average_pre_output_hook!), - ) - - hdf5_writer = CAD.HDF5Writer() - - if !isnothing(parsed_args["netcdf_interpolation_num_points"]) - num_netcdf_points = - tuple(parsed_args["netcdf_interpolation_num_points"]...) - else - # TODO: Once https://github.com/CliMA/ClimaCore.jl/pull/1567 is merged, - # dispatch over the Grid type - num_netcdf_points = (180, 90, 50) - end - - netcdf_writer = CAD.NetCDFWriter(; - spaces, - num_points = num_netcdf_points, - interpolate_z_over_msl = parsed_args["netcdf_interpolate_z_over_msl"], - disable_vertical_interpolation = parsed_args["netcdf_output_at_levels"], - ) - writers = (hdf5_writer, netcdf_writer) - - # The default writer is HDF5 - ALLOWED_WRITERS = Dict( - "nothing" => netcdf_writer, - "h5" => hdf5_writer, - "hdf5" => hdf5_writer, - "nc" => netcdf_writer, - "netcdf" => netcdf_writer, - ) - - diagnostics_ragged = map(yaml_diagnostics) do yaml_diag - short_names = yaml_diag["short_name"] - output_name = get(yaml_diag, "output_name", nothing) - - if short_names isa Vector - isnothing(output_name) || error( - "Diagnostics: cannot have multiple short_names while specifying output_name", - ) - else - short_names = [short_names] - end - - ret_value = map(short_names) do short_name - # Return "nothing" if "reduction_time" is not in the YAML block - # - # We also normalize everything to lowercase, so that can accept "max" but - # also "Max" - reduction_time_yaml = - lowercase(get(yaml_diag, "reduction_time", "nothing")) - - if !haskey(ALLOWED_REDUCTIONS, reduction_time_yaml) - error("reduction $reduction_time_yaml not implemented") - else - reduction_time_func, pre_output_hook! = - ALLOWED_REDUCTIONS[reduction_time_yaml] - end - - writer_ext = lowercase(get(yaml_diag, "writer", "nothing")) - - if !haskey(ALLOWED_WRITERS, writer_ext) - error("writer $writer_ext not implemented") - else - writer = ALLOWED_WRITERS[writer_ext] - end - - haskey(yaml_diag, "period") || - error("period keyword required for diagnostics") - - period_seconds = time_to_seconds(yaml_diag["period"]) - - if isnothing(output_name) - output_short_name = CAD.descriptive_short_name( - CAD.get_diagnostic_variable(short_name), - period_seconds, - reduction_time_func, - pre_output_hook!, - ) - end - - if isnothing(reduction_time_func) - compute_every = period_seconds - else - compute_every = :timestep - end - - return CAD.ScheduledDiagnosticTime( - variable = CAD.get_diagnostic_variable(short_name), - output_every = period_seconds, - compute_every = compute_every, - reduction_time_func = reduction_time_func, - pre_output_hook! = pre_output_hook!, - output_writer = writer, - output_short_name = output_short_name, - ) - end - return ret_value - end - - # Flatten the array of arrays of diagnostics - diagnostics = vcat(diagnostics_ragged...) - - if parsed_args["output_default_diagnostics"] - t_end = time_to_seconds(parsed_args["t_end"]) - return [ - CAD.default_diagnostics( - atmos_model, - t_end; - output_writer = netcdf_writer, - )..., - diagnostics..., - ], - writers - else - return collect(diagnostics), writers - end -end - function args_integrator(parsed_args, Y, p, tspan, ode_algo, callback) (; atmos, dt) = p dt_save_to_sol = time_to_seconds(parsed_args["dt_save_to_sol"]) @@ -825,84 +608,28 @@ function get_simulation(config::AtmosConfig) @info "ode_configuration: $s" s = @timed_str begin - callback = get_callbacks( - config.parsed_args, + (; callbacks, diagnostics_functions, writers) = get_callbacks( + Y, + p, + t_start, + config, sim_info, atmos, params, - config.comms_ctx, + spaces, ) end @info "get_callbacks: $s" - # Initialize diagnostics - s = @timed_str begin - diagnostics, writers = - get_diagnostics(config.parsed_args, atmos, spaces) - end - @info "initializing diagnostics: $s" - - length(diagnostics) > 0 && @info "Computing diagnostics:" - - for writer in writers - writer_str = nameof(typeof(writer)) - diags_with_writer = - filter((x) -> getproperty(x, :output_writer) == writer, diagnostics) - diags_outputs = [ - getproperty(diag, :output_short_name) for diag in diags_with_writer - ] - @info "$writer_str: $diags_outputs" - end - - # First, we convert all the ScheduledDiagnosticTime into ScheduledDiagnosticIteration, - # ensuring that there is consistency in the timestep and the periods and translating - # those periods that depended on the timestep - diagnostics_iterations = - [CAD.ScheduledDiagnosticIterations(d, sim_info.dt) for d in diagnostics] - - # For diagnostics that perform reductions, the storage is used for the values computed - # at each call. Reductions also save the accumulated value in diagnostic_accumulators. - diagnostic_storage = Dict() - diagnostic_accumulators = Dict() - diagnostic_counters = Dict() - - s = @timed_str begin - diagnostics_functions = CAD.get_callbacks_from_diagnostics( - diagnostics_iterations, - diagnostic_storage, - diagnostic_accumulators, - diagnostic_counters, - sim_info.output_dir, - ) - end - @info "Prepared diagnostic callbacks: $s" - - # It would be nice to just pass the callbacks to the integrator. However, this leads to - # a significant increase in compile time for reasons that are not known. For this - # reason, we only add one callback to the integrator, and this function takes care of - # executing the other callbacks. This single function is orchestrate_diagnostics - - function orchestrate_diagnostics(integrator) - for d in diagnostics_functions - if d.cbf.n > 0 && integrator.step % d.cbf.n == 0 - d.f!(integrator) - end - end - end - - diagnostic_callbacks = - call_every_n_steps(orchestrate_diagnostics, skip_first = true) - # The generic constructor for SciMLBase.CallbackSet has to split callbacks into discrete # and continuous. This is not hard, but can introduce significant latency. However, all # the callbacks in ClimaAtmos are discrete_callbacks, so we directly pass this # information to the constructor continuous_callbacks = tuple() - discrete_callbacks = (callback..., diagnostic_callbacks) + # Initialize diagnostics s = @timed_str begin - all_callbacks = - SciMLBase.CallbackSet(continuous_callbacks, discrete_callbacks) + all_callbacks = SciMLBase.CallbackSet(continuous_callbacks, callbacks) end @info "Prepared SciMLBase.CallbackSet callbacks: $s" steps_cycle_non_diag = n_steps_per_cycle_per_cb(all_callbacks, sim_info.dt) @@ -931,64 +658,6 @@ function get_simulation(config::AtmosConfig) @info "init integrator: $s" reset_graceful_exit(sim_info.output_dir) - s = @timed_str begin - for diag in diagnostics_iterations - variable = diag.variable - try - # The first time we call compute! we use its return value. All - # the subsequent times (in the callbacks), we will write the - # result in place - diagnostic_storage[diag] = variable.compute!( - nothing, - integrator.u, - integrator.p, - integrator.t, - ) - diagnostic_counters[diag] = 1 - # If it is not a reduction, call the output writer as well - if isnothing(diag.reduction_time_func) - CAD.write_field!( - diag.output_writer, - diagnostic_storage[diag], - diag, - integrator, - sim_info.output_dir, - ) - else - # Add to the accumulator - - # We use similar + .= instead of copy because CUDA 5.2 does - # not supported nested wrappers with view(reshape(view)) - # objects. See discussion in - # https://github.com/CliMA/ClimaAtmos.jl/pull/2579 and - # https://github.com/JuliaGPU/Adapt.jl/issues/21 - diagnostic_accumulators[diag] = - similar(diagnostic_storage[diag]) - diagnostic_accumulators[diag] .= - diagnostic_storage[diag] - end - catch e - error("Could not compute diagnostic $(variable.long_name): $e") - end - end - end - @info "Init diagnostics: $s" - - if config.parsed_args["warn_allocations_diagnostics"] - for diag in diagnostics_iterations - # We write over the storage space we have already prepared (and filled) before - allocs = @allocated diag.variable.compute!( - diagnostic_storage[diag], - integrator.u, - integrator.p, - integrator.t, - ) - if allocs > 10 * 1024 - @warn "Diagnostics $(diag.output_short_name) allocates $allocs bytes" - end - end - end - return AtmosSimulation( sim_info.job_id, sim_info.output_dir,