Skip to content

Commit

Permalink
Cache: move temporary quantities to scratch
Browse files Browse the repository at this point in the history
  • Loading branch information
Sbozzolo committed Oct 25, 2023
1 parent 768aeb0 commit 5b5739b
Show file tree
Hide file tree
Showing 12 changed files with 38 additions and 34 deletions.
6 changes: 5 additions & 1 deletion src/cache/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ function default_cache(
net_energy_flux_sfc,
env_thermo_quad = SGSQuadrature(FT),
precomputed_quantities(Y, atmos)...,
temporary_quantities(atmos, spaces.center_space, spaces.face_space)...,
scratch = temporary_quantities(
atmos,
spaces.center_space,
spaces.face_space,
),
hyperdiffusion_cache(Y, atmos, do_dss)...,
)
set_precomputed_quantities!(Y, default_cache, FT(0))
Expand Down
8 changes: 4 additions & 4 deletions src/cache/diagnostic_edmf_precomputed_quantities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ function set_diagnostic_edmf_precomputed_quantities!(Y, p, t)

@. ᶜtke⁰ = Y.c.sgs⁰.ρatke / Y.c.ρ

ᶜ∇Φ³ = p.ᶜtemp_CT3
ᶜ∇Φ³ = p.scratch.ᶜtemp_CT3
@. ᶜ∇Φ³ = CT3(ᶜgradᵥ(ᶠinterp(ᶜΦ)))
@. ᶜ∇Φ³ += CT3(gradₕ(ᶜΦ))

Expand Down Expand Up @@ -680,18 +680,18 @@ function set_diagnostic_edmf_precomputed_quantities!(Y, p, t)
# TODO: Currently the shear production only includes vertical gradients
ᶠu⁰ = p.ᶠtemp_C123
@. ᶠu⁰ = C123(ᶠinterp(Y.c.uₕ)) + C123(ᶠu³⁰)
ᶜstrain_rate = p.ᶜtemp_UVWxUVW
ᶜstrain_rate = p.scratch.ᶜtemp_UVWxUVW
compute_strain_rate_center!(ᶜstrain_rate, ᶠu⁰)
@. ᶜstrain_rate_norm = norm_sqr(ᶜstrain_rate)

ᶜprandtl_nvec = p.ᶜtemp_scalar
ᶜprandtl_nvec = p.scratch.ᶜtemp_scalar
@. ᶜprandtl_nvec = turbulent_prandtl_number(
params,
obukhov_length,
ᶜlinear_buoygrad,
ᶜstrain_rate_norm,
)
ᶜtke_exch = p.ᶜtemp_scalar_2
ᶜtke_exch = p.scratch.ᶜtemp_scalar_2
@. ᶜtke_exch = 0
# using ᶜu⁰ would be more correct, but this is more consistent with the
# TKE equation, where using ᶜu⁰ results in allocation
Expand Down
6 changes: 3 additions & 3 deletions src/cache/precomputed_quantities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ NVTX.@annotate function set_precomputed_quantities!(Y, p, t)
n = n_mass_flux_subdomains(turbconv_model)
thermo_args = (thermo_params, energy_form, moisture_model)
(; ᶜspecific, ᶜu, ᶠu³, ᶜK, ᶜts, ᶜp, ᶜΦ) = p
ᶠuₕ³ = p.ᶠtemp_CT3
ᶠuₕ³ = p.scratch.ᶠtemp_CT3

@. ᶜspecific = specific_gs(Y.c)
set_ᶠuₕ³!(ᶠuₕ³, Y)
Expand Down Expand Up @@ -361,8 +361,8 @@ values of the first updraft.
function output_prognostic_sgs_quantities(Y, p, t)
(; turbconv_model) = p.atmos
thermo_params = CAP.thermodynamics_params(p.params)
(; ᶜp, ᶜρa⁰, ᶜρ⁰, ᶜΦ, ᶜtsʲs) = p
ᶠuₕ³ = p.ᶠtemp_CT3
(; ᶜρa⁰, ᶜρ⁰, ᶜtsʲs) = p
ᶠuₕ³ = p.scratch.ᶠtemp_CT3
set_ᶠuₕ³!(ᶠuₕ³, Y)
(ᶠu₃⁺, ᶜu⁺, ᶠu³⁺, ᶜK⁺) = similar.((p.ᶠu₃⁰, p.ᶜu⁰, p.ᶠu³⁰, p.ᶜK⁰))
set_sgs_ᶠu₃!(u₃⁺, ᶠu₃⁺, Y, turbconv_model)
Expand Down
6 changes: 3 additions & 3 deletions src/cache/prognostic_edmf_precomputed_quantities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,18 +246,18 @@ function set_prognostic_edmf_precomputed_quantities_closures!(Y, p, t)
# TODO: Currently the shear production only includes vertical gradients
ᶠu⁰ = p.ᶠtemp_C123
@. ᶠu⁰ = C123(ᶠinterp(Y.c.uₕ)) + C123(ᶠu³⁰)
ᶜstrain_rate = p.ᶜtemp_UVWxUVW
ᶜstrain_rate = p.scratch.ᶜtemp_UVWxUVW
compute_strain_rate_center!(ᶜstrain_rate, ᶠu⁰)
@. ᶜstrain_rate_norm = norm_sqr(ᶜstrain_rate)

ᶜprandtl_nvec = p.ᶜtemp_scalar
ᶜprandtl_nvec = p.scratch.ᶜtemp_scalar
@. ᶜprandtl_nvec = turbulent_prandtl_number(
params,
obukhov_length,
ᶜlinear_buoygrad,
ᶜstrain_rate_norm,
)
ᶜtke_exch = p.ᶜtemp_scalar_2
ᶜtke_exch = p.scratch.ᶜtemp_scalar_2
@. ᶜtke_exch = 0
for j in 1:n
@. ᶜtke_exch +=
Expand Down
8 changes: 4 additions & 4 deletions src/prognostic_equations/advection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ NVTX.@annotate function explicit_vertical_advection_tendency!(Yₜ, Y, p, t)
ᶜρa⁰ = advect_tke ? (n > 0 ? p.ᶜρa⁰ : Y.c.ρ) : nothing
ᶜρ⁰ = advect_tke ? (n > 0 ? p.ᶜρ⁰ : Y.c.ρ) : nothing
ᶜtke⁰ = advect_tke ? p.ᶜtke⁰ : nothing
ᶜa_scalar = p.ᶜtemp_scalar
ᶜω³ = p.ᶜtemp_CT3
ᶠω¹² = p.ᶠtemp_CT12
ᶠω¹²ʲs = p.ᶠtemp_CT12ʲs
ᶜa_scalar = p.scratch.ᶜtemp_scalar
ᶜω³ = p.scratch.ᶜtemp_CT3
ᶠω¹² = p.scratch.ᶠtemp_CT12
ᶠω¹²ʲs = p.scratch.ᶠtemp_CT12ʲs

if point_type <: Geometry.Abstract3DPoint
@. ᶜω³ = curlₕ(Y.c.uₕ)
Expand Down
20 changes: 10 additions & 10 deletions src/prognostic_equations/edmfx_sgs_flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ function edmfx_sgs_mass_flux_tendency!(

if p.atmos.edmfx_sgs_mass_flux
# energy
ᶠu³_diff_colidx = p.ᶠtemp_CT3[colidx]
ᶜh_tot_diff_colidx = ᶜq_tot_diff_colidx = p.ᶜtemp_scalar[colidx]
ᶠu³_diff_colidx = p.scratch.ᶠtemp_CT3[colidx]
ᶜh_tot_diff_colidx = ᶜq_tot_diff_colidx = p.scratch.ᶜtemp_scalar[colidx]
for j in 1:n
@. ᶠu³_diff_colidx = ᶠu³ʲs.:($$j)[colidx] - ᶠu³[colidx]
@. ᶜh_tot_diff_colidx =
Expand Down Expand Up @@ -109,8 +109,8 @@ function edmfx_sgs_mass_flux_tendency!(

if p.atmos.edmfx_sgs_mass_flux
# energy
ᶠu³_diff_colidx = p.ᶠtemp_CT3[colidx]
ᶜh_tot_diff_colidx = ᶜq_tot_diff_colidx = p.ᶜtemp_scalar[colidx]
ᶠu³_diff_colidx = p.scratch.ᶠtemp_CT3[colidx]
ᶜh_tot_diff_colidx = ᶜq_tot_diff_colidx = p.scratch.ᶜtemp_scalar[colidx]
for j in 1:n
@. ᶠu³_diff_colidx = ᶠu³ʲs.:($$j)[colidx] - ᶠu³[colidx]
@. ᶜh_tot_diff_colidx = ᶜh_totʲs.:($$j)[colidx] - ᶜh_tot[colidx]
Expand Down Expand Up @@ -169,7 +169,7 @@ function edmfx_sgs_diffusive_flux_tendency!(

if p.atmos.edmfx_sgs_diffusive_flux
# energy
ᶠρaK_h = p.ᶠtemp_scalar
ᶠρaK_h = p.scratch.ᶠtemp_scalar
@. ᶠρaK_h[colidx] = ᶠinterp(ᶜρa⁰[colidx]) * ᶠinterp(ᶜK_h[colidx])

ᶜdivᵥ_ρe_tot = Operators.DivergenceF2C(
Expand All @@ -192,9 +192,9 @@ function edmfx_sgs_diffusive_flux_tendency!(
end

# momentum
ᶠρaK_u = p.ᶠtemp_scalar
ᶠρaK_u = p.scratch.ᶠtemp_scalar
@. ᶠρaK_u[colidx] = ᶠinterp(ᶜρa⁰[colidx]) * ᶠinterp(ᶜK_u[colidx])
ᶠstrain_rate = p.ᶠtemp_UVWxUVW
ᶠstrain_rate = p.scratch.ᶠtemp_UVWxUVW
compute_strain_rate_face!(ᶠstrain_rate[colidx], ᶜu⁰[colidx])
@. Yₜ.c.uₕ[colidx] -= C12(
ᶜdivᵥ(-(2 * ᶠρaK_u[colidx] * ᶠstrain_rate[colidx])) / Y.c.ρ[colidx],
Expand Down Expand Up @@ -235,7 +235,7 @@ function edmfx_sgs_diffusive_flux_tendency!(

if p.atmos.edmfx_sgs_diffusive_flux
# energy
ᶠρaK_h = p.ᶠtemp_scalar
ᶠρaK_h = p.scratch.ᶠtemp_scalar
@. ᶠρaK_h[colidx] = ᶠinterp(Y.c.ρ[colidx]) * ᶠinterp(ᶜK_h[colidx])

ᶜdivᵥ_ρe_tot = Operators.DivergenceF2C(
Expand All @@ -259,9 +259,9 @@ function edmfx_sgs_diffusive_flux_tendency!(
end

# momentum
ᶠρaK_u = p.ᶠtemp_scalar
ᶠρaK_u = p.scratch.ᶠtemp_scalar
@. ᶠρaK_u[colidx] = ᶠinterp(Y.c.ρ[colidx]) * ᶠinterp(ᶜK_u[colidx])
ᶠstrain_rate = p.ᶠtemp_UVWxUVW
ᶠstrain_rate = p.scratch.ᶠtemp_UVWxUVW
compute_strain_rate_face!(ᶠstrain_rate[colidx], ᶜu[colidx])
@. Yₜ.c.uₕ[colidx] -= C12(
ᶜdivᵥ(-(2 * ᶠρaK_u[colidx] * ᶠstrain_rate[colidx])) / Y.c.ρ[colidx],
Expand Down
2 changes: 1 addition & 1 deletion src/prognostic_equations/edmfx_tke.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function edmfx_tke_tendency!(
(; ᶜK_u, ᶜK_h, ρatke_flux) = p
ᶠgradᵥ = Operators.GradientC2F()

ᶠρaK_u = p.ᶠtemp_scalar
ᶠρaK_u = p.scratch.ᶠtemp_scalar
if use_prognostic_tke(turbconv_model)
# turbulent transport (diffusive flux)
@. ᶠρaK_u[colidx] = ᶠinterp(Y.c.ρ[colidx]) * ᶠinterp(ᶜK_u[colidx])
Expand Down
2 changes: 1 addition & 1 deletion src/prognostic_equations/implicit/implicit_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ function Wfact!(A, Y, p, dtγ, t)
p.ᶠgradᵥ_ᶜΦ,
p.ᶜρ_ref,
p.ᶜp_ref,
p.ᶜtemp_scalar,
p.scratch.ᶜtemp_scalar,
p.params,
p.atmos,
(energy_form isa TotalEnergy ? (; p.ᶜh_tot) : (;))...,
Expand Down
4 changes: 2 additions & 2 deletions src/prognostic_equations/implicit/implicit_tendency.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ function implicit_vertical_advection_tendency!(Yₜ, Y, p, t, colidx)
(; dt) = p.simulation
n = n_mass_flux_subdomains(turbconv_model)
ᶜJ = Fields.local_geometry_field(Y.c).J
(; ᶜspecific, ᶠu³, ᶜp, ᶠgradᵥ_ᶜΦ, ᶜρ_ref, ᶜp_ref, ᶜtemp_scalar) = p
(; ᶜspecific, ᶠu³, ᶜp, ᶠgradᵥ_ᶜΦ, ᶜρ_ref, ᶜp_ref) = p

ᶜ1 = ᶜtemp_scalar
ᶜ1 = p.scratch.ᶜtemp_scalar
@. ᶜ1[colidx] = one(Y.c.ρ[colidx])
vertical_transport!(
Yₜ.c.ρ[colidx],
Expand Down
6 changes: 3 additions & 3 deletions src/prognostic_equations/vertical_diffusion_boundary_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ function vertical_diffusion_boundary_layer_tendency!(

FT = eltype(Y)
interior_uₕ = Fields.level(Y.c.uₕ, 1)
ᶠp = ᶠρK_E = p.ᶠtemp_scalar
ᶠp = ᶠρK_E = p.scratch.ᶠtemp_scalar
@. ᶠp[colidx] = ᶠinterp(ᶜp[colidx])
ᶜΔz_surface = Fields.Δz_field(interior_uₕ)
@. ᶠρK_E[colidx] =
Expand Down Expand Up @@ -114,8 +114,8 @@ function vertical_diffusion_boundary_layer_tendency!(
@. Yₜ.c.ρe_tot[colidx] -=
ᶜdivᵥ_ρe_tot(-(ᶠρK_E[colidx] * ᶠgradᵥ(ᶜh_tot[colidx])))
end
ᶜρχₜ_diffusion = p.ᶜtemp_scalar
ρ_flux_χ = p.sfc_temp_C3
ᶜρχₜ_diffusion = p.scratch.ᶜtemp_scalar
ρ_flux_χ = p.scratch.sfc_temp_C3
for (ᶜρχₜ, ᶜχ, χ_name) in matching_subfields(Yₜ.c, ᶜspecific)
χ_name == :e_tot && continue
if χ_name == :q_tot
Expand Down
1 change: 0 additions & 1 deletion src/solver/type_getters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ function get_numerics(parsed_args)
limiter =
parsed_args["apply_limiter"] ? Limiters.QuasiMonotoneLimiter : nothing


# wrap each upwinding mode in a Val for dispatch
numerics = AtmosNumerics(;
energy_upwinding,
Expand Down
3 changes: 2 additions & 1 deletion src/surface_conditions/surface_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ This functions needs to be called by the coupler whenever either field changes
to ensure that the simulation is properly updated.
"""
function set_surface_conditions!(p, surface_conditions, surface_ts)
(; sfc_conditions, params, atmos, ᶠtemp_scalar) = p
(; sfc_conditions, params, atmos) = p
(; ᶠtemp_scalar) = p.scratch

FT = eltype(params)
FT′ = eltype(parent(surface_conditions))
Expand Down

0 comments on commit 5b5739b

Please sign in to comment.