diff --git a/src/diagnostics/diagnostic.jl b/src/diagnostics/diagnostic.jl index cf589e37363..020b14e84ef 100644 --- a/src/diagnostics/diagnostic.jl +++ b/src/diagnostics/diagnostic.jl @@ -661,6 +661,61 @@ function accumulate!(diag_accumulator, diag_storage, reduction_time_func) return nothing end +NVTX.@annotate function compute_callback!( + integrator, + accumulators, + storage, + diag, + counters, + compute!, +) + compute!(storage, integrator.u, integrator.p, integrator.t) + + # accumulator[diag] is not defined for non-reductions + diag_accumulator = get(accumulators, diag, nothing) + + accumulate!(diag_accumulator, storage, diag.reduction_time_func) + counters[diag] += 1 + return nothing +end + +NVTX.@annotate function output_callback!( + integrator, + accumulators, + storage, + diag, + counters, + output_dir, +) + # Move accumulated value to storage so that we can output it (for + # reductions). This provides a unified interface to pre_output_hook! and + # output, at the cost of an additional copy. If this copy turns out to be + # too expensive, we can move the if statement below. + isnothing(diag.reduction_time_func) || (storage .= accumulators[diag]) + + # Any operations we have to perform before writing to output? + # Here is where we would divide by N to obtain an arithmetic average + diag.pre_output_hook!(storage, counters[diag]) + + # Write to disk + write_field!( + diag.output_writer, + storage, + diag, + integrator.u, + integrator.p, + integrator.t, + output_dir, + ) + + # accumulator[diag] is not defined for non-reductions + diag_accumulator = get(accumulators, diag, nothing) + + reset_accumulator!(diag_accumulator, diag.reduction_time_func) + counters[diag] = 0 + return nothing +end + """ get_callbacks_from_diagnostics(diagnostics, storage, counters) @@ -699,61 +754,29 @@ function get_callbacks_from_diagnostics( # diagnostics that perform reductions. callback_arrays = map(diagnostics) do diag - variable = diag.variable compute_callback = integrator -> begin - variable.compute!( + compute_callback!( + integrator, + accumulators, storage[diag], - integrator.u, - integrator.p, - integrator.t, - ) - - # accumulator[diag] is not defined for non-reductions - diag_accumulator = get(accumulators, diag, nothing) - - accumulate!( - diag_accumulator, - storage[diag], - diag.reduction_time_func, + diag, + counters, + diag.variable.compute!, ) - counters[diag] += 1 - return nothing end - output_callback = integrator -> begin - # Move accumulated value to storage so that we can output it (for - # reductions). This provides a unified interface to pre_output_hook! and - # output, at the cost of an additional copy. If this copy turns out to be - # too expensive, we can move the if statement below. - isnothing(diag.reduction_time_func) || - (storage[diag] .= accumulators[diag]) - - # Any operations we have to perform before writing to output? - # Here is where we would divide by N to obtain an arithmetic average - diag.pre_output_hook!(storage[diag], counters[diag]) - - # Write to disk - write_field!( - diag.output_writer, + output_callback!( + integrator, + accumulators, storage[diag], diag, - integrator.u, - integrator.p, - integrator.t, + counters, output_dir, ) - - # accumulator[diag] is not defined for non-reductions - diag_accumulator = get(accumulators, diag, nothing) - - reset_accumulator!(diag_accumulator, diag.reduction_time_func) - counters[diag] = 0 - return nothing end - - return [ + [ AtmosCallback(compute_callback, EveryNSteps(diag.compute_every)), AtmosCallback(output_callback, EveryNSteps(diag.output_every)), ]