Skip to content

Commit

Permalink
Add function barrier to orchestrate diagnostics
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Mar 22, 2024
1 parent ba14896 commit 1ce7f1c
Showing 1 changed file with 67 additions and 43 deletions.
110 changes: 67 additions & 43 deletions src/diagnostics/diagnostic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,62 @@ function accumulate!(diag_accumulator, diag_storage, reduction_time_func)
return nothing
end

import NVTX
NVTX.@annotate function compute_callback!(
integrator,
accumulators,
storage,
diag,
counters,
compute!,
)
compute!(storage, integrator.u, integrator.p, integrator.t)

# accumulator[diag] is not defined for non-reductions
diag_accumulator = get(accumulators, diag, nothing)

accumulate!(diag_accumulator, storage, diag.reduction_time_func)
counters[diag] += 1
return nothing
end

NVTX.@annotate function output_callback!(
integrator,
accumulators,
storage,
diag,
counters,
output_dir,
)
# Move accumulated value to storage so that we can output it (for
# reductions). This provides a unified interface to pre_output_hook! and
# output, at the cost of an additional copy. If this copy turns out to be
# too expensive, we can move the if statement below.
isnothing(diag.reduction_time_func) || (storage .= accumulators[diag])

# Any operations we have to perform before writing to output?
# Here is where we would divide by N to obtain an arithmetic average
diag.pre_output_hook!(storage, counters[diag])

# Write to disk
write_field!(
diag.output_writer,
storage,
diag,
integrator.u,
integrator.p,
integrator.t,
output_dir,
)

# accumulator[diag] is not defined for non-reductions
diag_accumulator = get(accumulators, diag, nothing)

reset_accumulator!(diag_accumulator, diag.reduction_time_func)
counters[diag] = 0
return nothing
end

"""
get_callbacks_from_diagnostics(diagnostics, storage, counters)
Expand Down Expand Up @@ -699,61 +755,29 @@ function get_callbacks_from_diagnostics(
# diagnostics that perform reductions.

callback_arrays = map(diagnostics) do diag
variable = diag.variable
compute_callback =
integrator -> begin
variable.compute!(
compute_callback!(
integrator,
accumulators,
storage[diag],
integrator.u,
integrator.p,
integrator.t,
)

# accumulator[diag] is not defined for non-reductions
diag_accumulator = get(accumulators, diag, nothing)

accumulate!(
diag_accumulator,
storage[diag],
diag.reduction_time_func,
diag,
counters,
diag.variable.compute!,
)
counters[diag] += 1
return nothing
end

output_callback =
integrator -> begin
# Move accumulated value to storage so that we can output it (for
# reductions). This provides a unified interface to pre_output_hook! and
# output, at the cost of an additional copy. If this copy turns out to be
# too expensive, we can move the if statement below.
isnothing(diag.reduction_time_func) ||
(storage[diag] .= accumulators[diag])

# Any operations we have to perform before writing to output?
# Here is where we would divide by N to obtain an arithmetic average
diag.pre_output_hook!(storage[diag], counters[diag])

# Write to disk
write_field!(
diag.output_writer,
output_callback!(
integrator,
accumulators,
storage[diag],
diag,
integrator.u,
integrator.p,
integrator.t,
counters,
output_dir,
)

# accumulator[diag] is not defined for non-reductions
diag_accumulator = get(accumulators, diag, nothing)

reset_accumulator!(diag_accumulator, diag.reduction_time_func)
counters[diag] = 0
return nothing
end

return [
[
AtmosCallback(compute_callback, EveryNSteps(diag.compute_every)),
AtmosCallback(output_callback, EveryNSteps(diag.output_every)),
]
Expand Down

0 comments on commit 1ce7f1c

Please sign in to comment.