Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function barrier to orchestrate diagnostics #2834

Merged
merged 3 commits into from
Mar 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions perf/flame.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,19 @@ ProfileCanvas.html_file(joinpath(output_dir, "flame.html"), results)
#####

allocs_limit = Dict()
allocs_limit["flame_perf_target"] = 278_360
allocs_limit["flame_perf_target_tracers"] = 308_336
allocs_limit["flame_perf_target_edmfx"] = 7_005_552
allocs_limit["flame_perf_target"] = 51_600
allocs_limit["flame_perf_target_tracers"] = 81_576
allocs_limit["flame_perf_target_edmfx"] = 86_608
allocs_limit["flame_perf_diagnostics"] = 10_876_900
allocs_limit["flame_perf_target_diagnostic_edmfx"] = 412_056
allocs_limit["flame_perf_target_diagnostic_edmfx"] = 86_608
allocs_limit["flame_sphere_baroclinic_wave_rhoe_equilmoist_expvdiff"] =
4_018_252_656
allocs_limit["flame_perf_target_frierson"] = 4_015_547_056
allocs_limit["flame_perf_target_threaded"] = 1_276_864
allocs_limit["flame_perf_target_callbacks"] = 398_984
allocs_limit["flame_perf_target_callbacks"] = 172_032
allocs_limit["flame_perf_gw"] = 3_268_961_856
allocs_limit["flame_perf_target_prognostic_edmfx_aquaplanet"] = 299_616
allocs_limit["flame_gpu_implicit_barowave_moist"] = 252_300
allocs_limit["flame_perf_target_prognostic_edmfx_aquaplanet"] = 73_608
allocs_limit["flame_gpu_implicit_barowave_moist"] = 199_936
# Ideally, we would like to track all the allocations, but this becomes too
# expensive there is too many of them. Here, we set the default sample rate to
# 1, but lower it to a smaller value when we expect the job to produce lots of
Expand Down
119 changes: 76 additions & 43 deletions src/diagnostics/diagnostic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,71 @@ function accumulate!(diag_accumulator, diag_storage, reduction_time_func)
return nothing
end

import NVTX
NVTX.@annotate function compute_callback!(
integrator,
accumulators,
storage,
diag,
counters,
compute!,
)
compute!(storage, integrator.u, integrator.p, integrator.t)

# accumulator[diag] is not defined for non-reductions
diag_accumulator = get(accumulators, diag, nothing)

accumulate!(diag_accumulator, storage, diag.reduction_time_func)
counters[diag] += 1
return nothing
end

NVTX.@annotate function output_callback!(
integrator,
accumulators,
storage,
diag,
counters,
output_dir,
)
# Move accumulated value to storage so that we can output it (for
# reductions). This provides a unified interface to pre_output_hook! and
# output, at the cost of an additional copy. If this copy turns out to be
# too expensive, we can move the if statement below.
isnothing(diag.reduction_time_func) || (storage .= accumulators[diag])

# Any operations we have to perform before writing to output?
# Here is where we would divide by N to obtain an arithmetic average
diag.pre_output_hook!(storage, counters[diag])

# Write to disk
write_field!(
diag.output_writer,
storage,
diag,
integrator.u,
integrator.p,
integrator.t,
output_dir,
)

# accumulator[diag] is not defined for non-reductions
diag_accumulator = get(accumulators, diag, nothing)

reset_accumulator!(diag_accumulator, diag.reduction_time_func)
counters[diag] = 0
return nothing
end

function orchestrate_diagnostics(integrator, diagnostics_functions)
for d in diagnostics_functions
if d.cbf.n > 0 && integrator.step % d.cbf.n == 0
d.f!(integrator)
end
end
end


"""
get_callbacks_from_diagnostics(diagnostics, storage, counters)

Expand Down Expand Up @@ -699,61 +764,29 @@ function get_callbacks_from_diagnostics(
# diagnostics that perform reductions.

callback_arrays = map(diagnostics) do diag
variable = diag.variable
compute_callback =
integrator -> begin
variable.compute!(
compute_callback!(
integrator,
accumulators,
storage[diag],
integrator.u,
integrator.p,
integrator.t,
)

# accumulator[diag] is not defined for non-reductions
diag_accumulator = get(accumulators, diag, nothing)

accumulate!(
diag_accumulator,
storage[diag],
diag.reduction_time_func,
diag,
counters,
diag.variable.compute!,
)
counters[diag] += 1
return nothing
end

output_callback =
integrator -> begin
# Move accumulated value to storage so that we can output it (for
# reductions). This provides a unified interface to pre_output_hook! and
# output, at the cost of an additional copy. If this copy turns out to be
# too expensive, we can move the if statement below.
isnothing(diag.reduction_time_func) ||
(storage[diag] .= accumulators[diag])

# Any operations we have to perform before writing to output?
# Here is where we would divide by N to obtain an arithmetic average
diag.pre_output_hook!(storage[diag], counters[diag])

# Write to disk
write_field!(
diag.output_writer,
output_callback!(
integrator,
accumulators,
storage[diag],
diag,
integrator.u,
integrator.p,
integrator.t,
counters,
output_dir,
)

# accumulator[diag] is not defined for non-reductions
diag_accumulator = get(accumulators, diag, nothing)

reset_accumulator!(diag_accumulator, diag.reduction_time_func)
counters[diag] = 0
return nothing
end

return [
[
AtmosCallback(compute_callback, EveryNSteps(diag.compute_every)),
AtmosCallback(output_callback, EveryNSteps(diag.output_every)),
]
Expand Down
9 changes: 2 additions & 7 deletions src/solver/type_getters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -814,13 +814,8 @@ function get_simulation(config::AtmosConfig)
# reason, we only add one callback to the integrator, and this function takes care of
# executing the other callbacks. This single function is orchestrate_diagnostics

function orchestrate_diagnostics(integrator)
for d in diagnostics_functions
if d.cbf.n > 0 && integrator.step % d.cbf.n == 0
d.f!(integrator)
end
end
end
orchestrate_diagnostics(integrator) =
CAD.orchestrate_diagnostics(integrator, diagnostics_functions)

diagnostic_callbacks =
call_every_n_steps(orchestrate_diagnostics, skip_first = true)
Expand Down
Loading