Skip to content

Commit

Permalink
Merge branch 'main' into gb/concurrent_group
Browse files Browse the repository at this point in the history
  • Loading branch information
Sbozzolo authored Mar 19, 2024
2 parents e807825 + a554cd0 commit 9340574
Show file tree
Hide file tree
Showing 11 changed files with 91 additions and 39 deletions.
92 changes: 53 additions & 39 deletions experiments/AMIP/coupler_driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ then `ClimaComms` automatically selects the device from which this code is calle

using ClimaComms
comms_ctx = get_comms_context(parsed_args)
const pid, nprocs = ClimaComms.init(comms_ctx)
ClimaComms.init(comms_ctx)

#=
### I/O Directory Setup
Expand Down Expand Up @@ -524,6 +524,20 @@ update_firstdayofmonth!_cb =
MonthlyCallback(dt = FT(1), func = update_firstdayofmonth!, ref_date = [dates.date1[1]], active = true)
callbacks = (; checkpoint = checkpoint_cb, update_firstdayofmonth! = update_firstdayofmonth!_cb)

#=
## Initialize turbulent fluxes
Decide on the type of turbulent flux partition (see `FluxCalculator` documentation for more details).
=#
turbulent_fluxes = nothing
if config_dict["turb_flux_partition"] == "PartitionedStateFluxes"
turbulent_fluxes = PartitionedStateFluxes()
elseif config_dict["turb_flux_partition"] == "CombinedStateFluxes"
turbulent_fluxes = CombinedStateFluxes()
else
error("turb_flux_partition must be either PartitionedStateFluxes or CombinedStateFluxes")
end

#=
## Initialize Coupled Simulation
Expand All @@ -549,6 +563,8 @@ cs = CoupledSimulation{FT}(
diagnostics,
callbacks,
dir_paths,
turbulent_fluxes,
thermo_params,
);

#=
Expand All @@ -573,41 +589,31 @@ depend on initial conditions of other component models than those in which the v
The concrete steps for proper initialization are:
=#

# 1.decide on the type of turbulent flux partition (see `FluxCalculator` documentation for more details)
turbulent_fluxes = nothing
if config_dict["turb_flux_partition"] == "PartitionedStateFluxes"
turbulent_fluxes = PartitionedStateFluxes()
elseif config_dict["turb_flux_partition"] == "CombinedStateFluxes"
turbulent_fluxes = CombinedStateFluxes()
else
error("turb_flux_partition must be either PartitionedStateFluxes or CombinedStateFluxes")
end

# 2.coupler updates surface model area fractions
# 1.coupler updates surface model area fractions
update_surface_fractions!(cs)

# 3.surface density (`ρ_sfc`): calculated by the coupler by adiabatically extrapolating atmospheric thermal state to the surface.
# 2.surface density (`ρ_sfc`): calculated by the coupler by adiabatically extrapolating atmospheric thermal state to the surface.
# For this, we need to import surface and atmospheric fields. The model sims are then updated with the new surface density.
import_combined_surface_fields!(cs.fields, cs.model_sims, cs.boundary_space, turbulent_fluxes)
import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, turbulent_fluxes)
update_model_sims!(cs.model_sims, cs.fields, turbulent_fluxes)
import_combined_surface_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes)
import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes)
update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes)

# 4.surface vapor specific humidity (`q_sfc`): step surface models with the new surface density to calculate their respective `q_sfc` internally
# 3.surface vapor specific humidity (`q_sfc`): step surface models with the new surface density to calculate their respective `q_sfc` internally
## TODO: the q_sfc calculation follows the design of the bucket q_sfc, but it would be neater to abstract this from step! (#331)
step!(land_sim, Δt_cpl)
step!(ocean_sim, Δt_cpl)
step!(ice_sim, Δt_cpl)

# 5.turbulent fluxes: now we have all information needed for calculating the initial turbulent surface fluxes using the combined state
# 4.turbulent fluxes: now we have all information needed for calculating the initial turbulent surface fluxes using the combined state
# or the partitioned state method
if turbulent_fluxes isa CombinedStateFluxes
if cs.turbulent_fluxes isa CombinedStateFluxes
## import the new surface properties into the coupler (note the atmos state was also imported in step 3.)
import_combined_surface_fields!(cs.fields, cs.model_sims, cs.boundary_space, turbulent_fluxes) # i.e. T_sfc, albedo, z0, beta, q_sfc
import_combined_surface_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes) # i.e. T_sfc, albedo, z0, beta, q_sfc
## calculate turbulent fluxes inside the atmos cache based on the combined surface state in each grid box
combined_turbulent_fluxes!(cs.model_sims, cs.fields, turbulent_fluxes) # this updates the atmos thermo state, sfc_ts
elseif turbulent_fluxes isa PartitionedStateFluxes
combined_turbulent_fluxes!(cs.model_sims, cs.fields, cs.turbulent_fluxes) # this updates the atmos thermo state, sfc_ts
elseif cs.turbulent_fluxes isa PartitionedStateFluxes
## calculate turbulent fluxes in surface models and save the weighted average in coupler fields
partitioned_turbulent_fluxes!(cs.model_sims, cs.fields, cs.boundary_space, MoninObukhovScheme(), thermo_params)
partitioned_turbulent_fluxes!(cs.model_sims, cs.fields, cs.boundary_space, MoninObukhovScheme(), cs.thermo_params)

## update atmos sfc_conditions for surface temperature
## TODO: this is hard coded and needs to be simplified (req. CA modification) (#479)
Expand All @@ -616,15 +622,15 @@ elseif turbulent_fluxes isa PartitionedStateFluxes
atmos_sim.integrator.p.precomputed.sfc_conditions .= new_p.precomputed.sfc_conditions
end

# 6.reinitialize models + radiative flux: prognostic states and time are set to their initial conditions. For atmos, this also triggers the callbacks and sets a nonzero radiation flux (given the new sfc_conditions)
# 5.reinitialize models + radiative flux: prognostic states and time are set to their initial conditions. For atmos, this also triggers the callbacks and sets a nonzero radiation flux (given the new sfc_conditions)
reinit_model_sims!(cs.model_sims)

# 7.update all fluxes: coupler re-imports updated atmos fluxes (radiative fluxes for both `turbulent_fluxes` types
# 6.update all fluxes: coupler re-imports updated atmos fluxes (radiative fluxes for both `turbulent_fluxes` types
# and also turbulent fluxes if `turbulent_fluxes isa CombinedStateFluxes`,
# and sends them to the surface component models. If `turbulent_fluxes isa PartitionedStateFluxes`
# atmos receives the turbulent fluxes from the coupler.
import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, turbulent_fluxes)
update_model_sims!(cs.model_sims, cs.fields, turbulent_fluxes)
import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes)
update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes)

#=
## Coupling Loop
Expand All @@ -634,13 +640,12 @@ Note that we want to implement this in a dispatchable function to allow for othe
=#

function solve_coupler!(cs)
ClimaComms.iamroot(comms_ctx) ? @info("Starting coupling loop") : nothing

(; model_sims, Δt_cpl, tspan) = cs
(; model_sims, Δt_cpl, tspan, comms_ctx) = cs
(; atmos_sim, land_sim, ocean_sim, ice_sim) = model_sims

ClimaComms.iamroot(comms_ctx) ? @info("Starting coupling loop") : nothing
## step in time
walltime = @elapsed for t in ((tspan[1] + Δt_cpl):Δt_cpl:tspan[end])
walltime = @elapsed for t in ((tspan[begin] + Δt_cpl):Δt_cpl:tspan[end])

cs.dates.date[1] = current_date(cs, t)

Expand All @@ -662,7 +667,10 @@ function solve_coupler!(cs)
update_midmonth_data!(cs.dates.date[1], cs.mode.SIC_info)
end
SIC_current =
get_ice_fraction.(interpolate_midmonth_to_daily(cs.dates.date[1], cs.mode.SIC_info), mono_surface)
get_ice_fraction.(
interpolate_midmonth_to_daily(cs.dates.date[1], cs.mode.SIC_info),
cs.mode.SIC_info.mono,
)
update_field!(ice_sim, Val(:area_fraction), SIC_current)

if cs.dates.date[1] >= next_date_in_file(cs.mode.CO2_info)
Expand All @@ -686,26 +694,32 @@ function solve_coupler!(cs)
## run component models sequentially for one coupling timestep (Δt_cpl)
ClimaComms.barrier(comms_ctx)
update_surface_fractions!(cs)
update_model_sims!(cs.model_sims, cs.fields, turbulent_fluxes)
update_model_sims!(cs.model_sims, cs.fields, cs.turbulent_fluxes)

## step sims
step_model_sims!(cs.model_sims, t)

## exchange combined fields and (if specified) calculate fluxes using combined states
import_combined_surface_fields!(cs.fields, cs.model_sims, cs.boundary_space, turbulent_fluxes) # i.e. T_sfc, surface_albedo, z0, beta
if turbulent_fluxes isa CombinedStateFluxes
combined_turbulent_fluxes!(cs.model_sims, cs.fields, turbulent_fluxes) # this updates the surface thermo state, sfc_ts, in ClimaAtmos (but also unnecessarily calculates fluxes)
elseif turbulent_fluxes isa PartitionedStateFluxes
import_combined_surface_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes) # i.e. T_sfc, surface_albedo, z0, beta
if cs.turbulent_fluxes isa CombinedStateFluxes
combined_turbulent_fluxes!(cs.model_sims, cs.fields, cs.turbulent_fluxes) # this updates the surface thermo state, sfc_ts, in ClimaAtmos (but also unnecessarily calculates fluxes)
elseif cs.turbulent_fluxes isa PartitionedStateFluxes
## calculate turbulent fluxes in surfaces and save the weighted average in coupler fields
partitioned_turbulent_fluxes!(cs.model_sims, cs.fields, cs.boundary_space, MoninObukhovScheme(), thermo_params)
partitioned_turbulent_fluxes!(
cs.model_sims,
cs.fields,
cs.boundary_space,
MoninObukhovScheme(),
cs.thermo_params,
)

## update atmos sfc_conditions for surface temperature - TODO: this needs to be simplified (need CA modification)
new_p = get_new_cache(atmos_sim, cs.fields)
CA.SurfaceConditions.update_surface_conditions!(atmos_sim.integrator.u, new_p, atmos_sim.integrator.t) # to set T_sfc (but SF calculation not necessary - CA modification)
atmos_sim.integrator.p.precomputed.sfc_conditions .= new_p.precomputed.sfc_conditions
end

import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, turbulent_fluxes) # radiative and/or turbulent
import_atmos_fields!(cs.fields, cs.model_sims, cs.boundary_space, cs.turbulent_fluxes) # radiative and/or turbulent

## callback to update the fist day of month if needed (for BCReader)
trigger_callback!(cs, cs.callbacks.update_firstdayofmonth!)
Expand Down
3 changes: 3 additions & 0 deletions src/BCReader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Stores information specific to each boundary condition from a file and each vari
- segment_idx0::Vector{Int} # `segment_idx` of the file data that is closest to date0
- segment_length::Vector{Int} # length of each month segment (used in the daily interpolation)
- interpolate_daily::Bool # switch to trigger daily interpolation
- mono::Bool # flag for monotone remapping of input data
"""
struct BCFileInfo{FT <: Real, B, X, S, V, D, C, O, M, VI}
bcfile_dir::B
Expand All @@ -49,6 +50,7 @@ struct BCFileInfo{FT <: Real, B, X, S, V, D, C, O, M, VI}
segment_idx0::VI
segment_length::VI
interpolate_daily::Bool
mono::Bool
end

BCFileInfo{FT}(args...) where {FT} = BCFileInfo{FT, typeof.(args[1:9])...}(args...)
Expand Down Expand Up @@ -164,6 +166,7 @@ function bcfile_info_init(
segment_idx0,
segment_length,
interpolate_daily,
mono,
)
end

Expand Down
4 changes: 4 additions & 0 deletions src/Interfacer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ struct CoupledSimulation{
TD <: Tuple,
NTC <: NamedTuple,
NTP <: NamedTuple,
TF,
TP,
}
comms_ctx::X
dates::D
Expand All @@ -65,6 +67,8 @@ struct CoupledSimulation{
diagnostics::TD
callbacks::NTC
dirs::NTP
turbulent_fluxes::TF
thermo_params::TP
end

CoupledSimulation{FT}(args...) where {FT} = CoupledSimulation{FT, typeof.(args[1:end])...}(args...)
Expand Down
5 changes: 5 additions & 0 deletions test/bcreader_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ for FT in (Float32, Float64)
segment_idx0, # segment_idx0
Int[], # segment_length
false, # interpolate_daily
false, # mono
)

idx = segment_idx0[1]
Expand Down Expand Up @@ -113,6 +114,7 @@ for FT in (Float32, Float64)
segment_idx0, # segment_idx0
segment_length, # segment_length
interpolate_daily, # interpolate_daily
false, # mono
)
@test BCReader.interpolate_midmonth_to_daily(date0, bcf_info_interp) == ones(boundary_space_t) .* FT(0.5)

Expand All @@ -132,6 +134,7 @@ for FT in (Float32, Float64)
segment_idx0, # segment_idx0
segment_length, # segment_length
interpolate_daily, # interpolate_daily
false, # mono
)
@test BCReader.interpolate_midmonth_to_daily(date0, bcf_info_no_interp) == monthly_fields[1]
end
Expand Down Expand Up @@ -196,6 +199,8 @@ for FT in (Float32, Float64)
(), # diagnostics
(;), # callbacks
(;), # dirs
nothing, # turbulent_fluxes
nothing, # thermo_params
)

# step in time
Expand Down
4 changes: 4 additions & 0 deletions test/conservation_checker_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ for FT in (Float32, Float64)
(), # diagnostics
(;), # callbacks
(;), # dirs
nothing, # turbulent_fluxes
nothing, # thermo_params
)

# set non-zero radiation and precipitation
Expand Down Expand Up @@ -178,6 +180,8 @@ for FT in (Float32, Float64)
(), # diagnostics
(;), # callbacks
(;), # dirs
nothing, # turbulent_fluxes
nothing, # thermo_params
)

tot_energy, tot_water = check_conservation!(cs)
Expand Down
2 changes: 2 additions & 0 deletions test/debug/debug_amip_plots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ plot_field_names(sim::SurfaceStub) = (:stub_field,)
(), # diagnostics
(;), # callbacks
(;), # dirs
nothing, # turbulent_fluxes
nothing, # thermo_params
)

output_plots = "test_debug"
Expand Down
6 changes: 6 additions & 0 deletions test/diagnostics_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ for FT in (Float32, Float64)
(dg_2d,),
(;), # callbacks
(;), # dirs
nothing, # turbulent_fluxes
nothing, # thermo_params
)
accumulate_diagnostics!(cs)
@test cs.diagnostics[1].field_vector[1] == expected_results[c_i]
Expand Down Expand Up @@ -89,6 +91,8 @@ for FT in (Float32, Float64)
(dg_2d,), # diagnostics
(;), # callbacks
(;), # dirs
nothing, # turbulent_fluxes
nothing, # thermo_params
)
save_diagnostics(cs, cs.diagnostics[1])
file = filter(x -> endswith(x, ".hdf5"), readdir(test_dir))
Expand Down Expand Up @@ -129,6 +133,8 @@ for FT in (Float32, Float64)
(dg_2d,),
(;), # callbacks
(;), # dirs
nothing, # turbulent_fluxes
nothing, # thermo_params
)
accumulate_diagnostics!(cs)
@test cs.diagnostics[1].field_vector[1] == expected_results[c_i][1]
Expand Down
2 changes: 2 additions & 0 deletions test/interfacer_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ for FT in (Float32, Float64)
(), # diagnostics
(;), # callbacks
(;), # dirs
nothing, # turbulent_fluxes
nothing, # thermo_params
)

@test float_type(cs) == FT
Expand Down
2 changes: 2 additions & 0 deletions test/mpi_tests/bcreader_mpi_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ end
(), # diagnostics
(;), # callbacks
(;), # dirs
nothing, # turbulent_fluxes
nothing, # thermo_params
)

ClimaComms.barrier(comms_ctx)
Expand Down
2 changes: 2 additions & 0 deletions test/regridder_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ for FT in (Float32, Float64)
(), # diagnostics
(;), # callbacks
(;), # dirs
nothing, # turbulent_fluxes
nothing, # thermo_params
)

Regridder.update_surface_fractions!(cs)
Expand Down
8 changes: 8 additions & 0 deletions test/time_manager_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ for FT in (Float32, Float64)
(), # diagnostics
(;), # callbacks
(;), # dirs
nothing, # turbulent_fluxes
nothing, # thermo_params
)

for t in ((tspan[1] + Δt_cpl):Δt_cpl:tspan[end])
Expand Down Expand Up @@ -70,6 +72,8 @@ end
(), # diagnostics
(;), # callbacks
(;), # dirs
nothing, # turbulent_fluxes
nothing, # thermo_params
)
@test TimeManager.trigger_callback(cs, TimeManager.Monthly()) == true
end
Expand Down Expand Up @@ -110,6 +114,8 @@ end
monthly_counter = monthly_counter,
), # callbacks
(;), # dirs
nothing, # turbulent_fluxes
nothing, # thermo_params
)

TimeManager.trigger_callback!(cs, cs.callbacks.twhohourly_inactive)
Expand Down Expand Up @@ -172,6 +178,8 @@ end
(), # diagnostics
(;), # callbacks
(;), # dirs
nothing, # turbulent_fluxes
nothing, # thermo_params
)

TimeManager.update_firstdayofmonth!(cs, nothing)
Expand Down

0 comments on commit 9340574

Please sign in to comment.