Skip to content

Commit

Permalink
Merge pull request #2814 from CliMA/ck/fix_prog_edmf_gpu
Browse files Browse the repository at this point in the history
Fix an inference failure in gpu prognostic edmf
  • Loading branch information
charleskawczynski authored Mar 21, 2024
2 parents e7074c0 + 515b598 commit 39aac5f
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 9 deletions.
2 changes: 1 addition & 1 deletion perf/flame.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ allocs_limit["flame_perf_target_threaded"] = 1_276_864
allocs_limit["flame_perf_target_callbacks"] = 398_984
allocs_limit["flame_perf_gw"] = 3_268_961_856
allocs_limit["flame_perf_target_prognostic_edmfx_aquaplanet"] = 299_616
allocs_limit["flame_gpu_implicit_barowave_moist"] = 381_968
allocs_limit["flame_gpu_implicit_barowave_moist"] = 658_664
# Ideally, we would like to track all the allocations, but this becomes too
# expensive there is too many of them. Here, we set the default sample rate to
# 1, but lower it to a smaller value when we expect the job to produce lots of
Expand Down
3 changes: 2 additions & 1 deletion src/cache/diagnostic_edmf_precomputed_quantities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ function set_diagnostic_edmfx_env_quantities_level!(
turbconv_model,
)
@. u³⁰_halflevel = divide_by_ρa(
ρ_level * u³_halflevel - mapreduce(*, +, ρaʲs_level, u³ʲs_halflevel),
ρ_level * u³_halflevel -
mapreduce_with_init(*, +, ρaʲs_level, u³ʲs_halflevel),
ρ_level,
ρ_level * u³_halflevel,
ρ_level,
Expand Down
36 changes: 29 additions & 7 deletions src/utils/variable_manipulations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,39 +185,39 @@ end
Computes the total mass-flux subdomain area-weighted density, assuming that the
mass-flux subdomain states are stored in `gs.sgsʲs`.
"""
ρa⁺(gs) = mapreduce(sgsʲ -> sgsʲ.ρa, +, gs.sgsʲs)
ρa⁺(gs) = mapreduce_with_init(sgsʲ -> sgsʲ.ρa, +, gs.sgsʲs)

"""
ρah_tot⁺(sgsʲs)
Computes the total mass-flux subdomain area-weighted ρh_tot, assuming that the
mass-flux subdomain states are stored in `sgsʲs`.
"""
ρah_tot⁺(sgsʲs) = mapreduce(sgsʲ -> sgsʲ.ρa * sgsʲ.h_tot, +, sgsʲs)
ρah_tot⁺(sgsʲs) = mapreduce_with_init(sgsʲ -> sgsʲ.ρa * sgsʲ.h_tot, +, sgsʲs)

"""
ρamse⁺(sgsʲs)
Computes the total mass-flux subdomain area-weighted ρmse, assuming that the
mass-flux subdomain states are stored in `sgsʲs`.
"""
ρamse⁺(sgsʲs) = mapreduce(sgsʲ -> sgsʲ.ρa * sgsʲ.mse, +, sgsʲs)
ρamse⁺(sgsʲs) = mapreduce_with_init(sgsʲ -> sgsʲ.ρa * sgsʲ.mse, +, sgsʲs)

"""
ρaq_tot⁺(sgsʲs)
Computes the total mass-flux subdomain area-weighted ρq_tot, assuming that the
mass-flux subdomain states are stored in `sgsʲs`.
"""
ρaq_tot⁺(sgsʲs) = mapreduce(sgsʲ -> sgsʲ.ρa * sgsʲ.q_tot, +, sgsʲs)
ρaq_tot⁺(sgsʲs) = mapreduce_with_init(sgsʲ -> sgsʲ.ρa * sgsʲ.q_tot, +, sgsʲs)

"""
ρa⁰(gs)
Computes the environment area-weighted density, assuming that the mass-flux
subdomain states are stored in `gs.sgsʲs`.
"""
ρa⁰(gs) = gs.ρ - mapreduce(sgsʲ -> sgsʲ.ρa, +, gs.sgsʲs)
ρa⁰(gs) = gs.ρ - mapreduce_with_init(sgsʲ -> sgsʲ.ρa, +, gs.sgsʲs)

"""
u₃⁺(ρaʲs, u₃ʲs, ρ, u₃, turbconv_model)
Expand All @@ -229,7 +229,7 @@ are computed from the tuples of subdomain densities and velocities `ρaʲs` and
is small.
"""
u₃⁺(ρaʲs, u₃ʲs, ρ, u₃, turbconv_model) = divide_by_ρa(
mapreduce(*, +, ρaʲs, u₃ʲs),
unrolled_dotproduct(ρaʲs, u₃ʲs),
reduce(+, ρaʲs),
ρ * u₃,
ρ,
Expand All @@ -247,7 +247,7 @@ environment quantities. The division is computed using `divide_by_ρa` to avoid
issues when `a⁰` is small.
"""
u₃⁰(ρaʲs, u₃ʲs, ρ, u₃, turbconv_model) = divide_by_ρa(
ρ * u₃ - mapreduce(*, +, ρaʲs, u₃ʲs),
ρ * u₃ - unrolled_dotproduct(ρaʲs, u₃ʲs),
ρ - reduce(+, ρaʲs),
ρ * u₃,
ρ,
Expand All @@ -265,3 +265,25 @@ remove_energy_var(specific_state::NamedTuple) =
Base.structdiff(specific_state, NamedTuple{(:e_tot,)})
remove_energy_var(specific_state::Tuple) =
map(remove_energy_var, specific_state)


import ClimaCore.RecursiveApply: , , rzero, rpromote_type
function mapreduce_with_init(f, op, iter...)
r₀ = rzero(rpromote_type(typeof(f(map(first, iter)...))))
mapreduce(f, op, iter...; init = r₀)
end

# Inference fails for certain mapreduce calls inside cuda
# kernels, so let's define a recursive unrolled dot product:
promote_type_mul(n::Number, x::Geometry.AxisTensor) = typeof(x)
promote_type_mul(x::Geometry.AxisTensor, n::Number) = typeof(x)
@inline function unrolled_dotproduct(a::Tuple, b::Tuple)
r = rzero(promote_type_mul(first(a), first(b)))
unrolled_dotproduct(r, a, b)
end
@inline unrolled_dotproduct(s, ::Tuple{}, ::Tuple{}) = s
@inline unrolled_dotproduct(s, a::Tuple, b::Tuple) =
s (first(a) first(b))
unrolled_dotproduct(s, Base.tail(a), Base.tail(b))
@inline unrolled_dotproduct(s, a::Tuple{<:Any}, b::Tuple{<:Any}) =
s (first(a) first(b))

0 comments on commit 39aac5f

Please sign in to comment.