Skip to content

Commit

Permalink
Ensure that client used in GZ21 is the same client as elsewhere
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew E. Shao committed Nov 29, 2023
1 parent ff8e73b commit 5c2bc72
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 30 deletions.
7 changes: 4 additions & 3 deletions src/core/MOM.F90
Original file line number Diff line number Diff line change
Expand Up @@ -2872,7 +2872,8 @@ subroutine initialize_MOM(Time, Time_init, param_file, dirs, CS, restart_CSp, &
CS%dt, CS%ADp, CS%CDp, MOM_internal_state, CS%VarMix, CS%MEKE, &
CS%thickness_diffuse_CSp, &
CS%OBC, CS%update_OBC_CSp, CS%ALE_CSp, CS%set_visc_CSp, &
CS%visc, dirs, CS%ntrunc, CS%pbv, calc_dtbt=calc_dtbt, cont_stencil=CS%cont_stencil)
CS%visc, dirs, CS%ntrunc, CS%dbcomms_CS, CS%pbv, &
calc_dtbt=calc_dtbt, cont_stencil=CS%cont_stencil)
if (CS%dtbt_reset_period > 0.0) then
CS%dtbt_reset_interval = real_to_time(CS%dtbt_reset_period)
! Set dtbt_reset_time to be the next even multiple of dtbt_reset_interval.
Expand All @@ -2890,13 +2891,13 @@ subroutine initialize_MOM(Time, Time_init, param_file, dirs, CS, restart_CSp, &
param_file, diag, CS%dyn_unsplit_RK2_CSp, &
CS%ADp, CS%CDp, MOM_internal_state, CS%OBC, &
CS%update_OBC_CSp, CS%ALE_CSp, CS%set_visc_CSp, CS%visc, dirs, &
CS%ntrunc, cont_stencil=CS%cont_stencil)
CS%ntrunc, CS%dbcomms_CS, cont_stencil=CS%cont_stencil)
else
call initialize_dyn_unsplit(CS%u, CS%v, CS%h, Time, G, GV, US, &
param_file, diag, CS%dyn_unsplit_CSp, &
CS%ADp, CS%CDp, MOM_internal_state, CS%OBC, &
CS%update_OBC_CSp, CS%ALE_CSp, CS%set_visc_CSp, CS%visc, dirs, &
CS%ntrunc, cont_stencil=CS%cont_stencil)
CS%ntrunc, CS%dbcomms_CS, cont_stencil=CS%cont_stencil)
endif

call callTree_waypoint("dynamics initialized (initialize_MOM)")
Expand Down
6 changes: 4 additions & 2 deletions src/core/MOM_dynamics_split_RK2.F90
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ module MOM_dynamics_split_RK2
use MOM_cpu_clock, only : cpu_clock_id, cpu_clock_begin, cpu_clock_end
use MOM_cpu_clock, only : CLOCK_COMPONENT, CLOCK_SUBCOMPONENT
use MOM_cpu_clock, only : CLOCK_MODULE_DRIVER, CLOCK_MODULE, CLOCK_ROUTINE
use MOM_database_comms, only : dbcomms_CS_type
use MOM_diag_mediator, only : diag_mediator_init, enable_averages
use MOM_diag_mediator, only : disable_averaging, post_data, safe_alloc_ptr
use MOM_diag_mediator, only : post_product_u, post_product_sum_u
Expand Down Expand Up @@ -1112,7 +1113,7 @@ subroutine initialize_dyn_split_RK2(u, v, h, uh, vh, eta, Time, G, GV, US, param
diag, CS, restart_CS, dt, Accel_diag, Cont_diag, MIS, &
VarMix, MEKE, thickness_diffuse_CSp, &
OBC, update_OBC_CSp, ALE_CSp, set_visc, &
visc, dirs, ntrunc, pbv, calc_dtbt, cont_stencil)
visc, dirs, ntrunc, dbcomms_CS, pbv, calc_dtbt, cont_stencil)
type(ocean_grid_type), intent(inout) :: G !< ocean grid structure
type(verticalGrid_type), intent(in) :: GV !< ocean vertical grid structure
type(unit_scale_type), intent(in) :: US !< A dimensional unit scaling type
Expand Down Expand Up @@ -1151,6 +1152,7 @@ subroutine initialize_dyn_split_RK2(u, v, h, uh, vh, eta, Time, G, GV, US, param
integer, target, intent(inout) :: ntrunc !< A target for the variable that records
!! the number of times the velocity is
!! truncated (this should be 0).
type(dbcomms_CS_type), target, intent(in) :: dbcomms_CS !< Control structure for database communication client
logical, intent(out) :: calc_dtbt !< If true, recalculate the barotropic time step
type(porous_barrier_type), intent(in) :: pbv !< porous barrier fractional cell metrics
integer, intent(out) :: cont_stencil !< The stencil for thickness
Expand Down Expand Up @@ -1276,7 +1278,7 @@ subroutine initialize_dyn_split_RK2(u, v, h, uh, vh, eta, Time, G, GV, US, param
if (use_tides) call tidal_forcing_init(Time, G, US, param_file, CS%tides_CSp)
call PressureForce_init(Time, G, GV, US, param_file, diag, CS%PressureForce_CSp, &
CS%tides_CSp)
call hor_visc_init(Time, G, GV, US, param_file, diag, CS%hor_visc, ADp=CS%ADp)
call hor_visc_init(Time, G, GV, US, param_file, diag, CS%hor_visc, dbcomms_CS, ADp=CS%ADp)
call vertvisc_init(MIS, Time, G, GV, US, param_file, diag, CS%ADp, dirs, &
ntrunc, CS%vertvisc_CSp)
CS%set_visc_CSp => set_visc
Expand Down
6 changes: 4 additions & 2 deletions src/core/MOM_dynamics_unsplit.F90
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ module MOM_dynamics_unsplit
use MOM_cpu_clock, only : cpu_clock_id, cpu_clock_begin, cpu_clock_end
use MOM_cpu_clock, only : CLOCK_COMPONENT, CLOCK_SUBCOMPONENT
use MOM_cpu_clock, only : CLOCK_MODULE_DRIVER, CLOCK_MODULE, CLOCK_ROUTINE
use MOM_database_comms, only : dbcomms_CS_type
use MOM_diag_mediator, only : diag_mediator_init, enable_averages
use MOM_diag_mediator, only : disable_averaging, post_data, safe_alloc_ptr
use MOM_diag_mediator, only : register_diag_field, register_static_field
Expand Down Expand Up @@ -567,7 +568,7 @@ end subroutine register_restarts_dyn_unsplit
subroutine initialize_dyn_unsplit(u, v, h, Time, G, GV, US, param_file, diag, CS, &
Accel_diag, Cont_diag, MIS, &
OBC, update_OBC_CSp, ALE_CSp, set_visc, &
visc, dirs, ntrunc, cont_stencil)
visc, dirs, ntrunc, dbcomms_CS, cont_stencil)
type(ocean_grid_type), intent(inout) :: G !< The ocean's grid structure.
type(verticalGrid_type), intent(in) :: GV !< The ocean's vertical grid structure.
type(unit_scale_type), intent(in) :: US !< A dimensional unit scaling type
Expand Down Expand Up @@ -610,6 +611,7 @@ subroutine initialize_dyn_unsplit(u, v, h, Time, G, GV, US, param_file, diag, CS
integer, target, intent(inout) :: ntrunc !< A target for the variable that
!! records the number of times the velocity
!! is truncated (this should be 0).
type(dbcomms_CS_type), target, intent(in) :: dbcomms_CS !< Control stracture for database communication client
integer, intent(out) :: cont_stencil !< The stencil for thickness
!! from the continuity solver.

Expand Down Expand Up @@ -667,7 +669,7 @@ subroutine initialize_dyn_unsplit(u, v, h, Time, G, GV, US, param_file, diag, CS
if (use_tides) call tidal_forcing_init(Time, G, US, param_file, CS%tides_CSp)
call PressureForce_init(Time, G, GV, US, param_file, diag, CS%PressureForce_CSp, &
CS%tides_CSp)
call hor_visc_init(Time, G, GV, US, param_file, diag, CS%hor_visc)
call hor_visc_init(Time, G, GV, US, param_file, diag, CS%hor_visc, dbcomms_CS)
call vertvisc_init(MIS, Time, G, GV, US, param_file, diag, CS%ADp, dirs, &
ntrunc, CS%vertvisc_CSp)
CS%set_visc_CSp => set_visc
Expand Down
7 changes: 5 additions & 2 deletions src/core/MOM_dynamics_unsplit_RK2.F90
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ module MOM_dynamics_unsplit_RK2
use MOM_cpu_clock, only : cpu_clock_id, cpu_clock_begin, cpu_clock_end
use MOM_cpu_clock, only : CLOCK_COMPONENT, CLOCK_SUBCOMPONENT
use MOM_cpu_clock, only : CLOCK_MODULE_DRIVER, CLOCK_MODULE, CLOCK_ROUTINE
use MOM_database_comms, only : dbcomms_CS_type
use MOM_diag_mediator, only : diag_mediator_init, enable_averages
use MOM_diag_mediator, only : disable_averaging, post_data, safe_alloc_ptr
use MOM_diag_mediator, only : register_diag_field, register_static_field
Expand Down Expand Up @@ -518,7 +519,7 @@ end subroutine register_restarts_dyn_unsplit_RK2
subroutine initialize_dyn_unsplit_RK2(u, v, h, Time, G, GV, US, param_file, diag, CS, &
Accel_diag, Cont_diag, MIS, &
OBC, update_OBC_CSp, ALE_CSp, set_visc, &
visc, dirs, ntrunc, cont_stencil)
visc, dirs, ntrunc, dbcomms_CS, cont_stencil)
type(ocean_grid_type), intent(inout) :: G !< The ocean's grid structure.
type(verticalGrid_type), intent(in) :: GV !< The ocean's vertical grid structure.
type(unit_scale_type), intent(in) :: US !< A dimensional unit scaling type
Expand Down Expand Up @@ -558,6 +559,8 @@ subroutine initialize_dyn_unsplit_RK2(u, v, h, Time, G, GV, US, param_file, diag
integer, target, intent(inout) :: ntrunc !< A target for the variable
!! that records the number of times the
!! velocity is truncated (this should be 0).
type(dbcomms_CS_type), target, intent(in) :: dbcomms_CS !< Control structure for database
!! communication client
integer, intent(out) :: cont_stencil !< The stencil for
!! thickness from the continuity solver.

Expand Down Expand Up @@ -631,7 +634,7 @@ subroutine initialize_dyn_unsplit_RK2(u, v, h, Time, G, GV, US, param_file, diag
if (use_tides) call tidal_forcing_init(Time, G, US, param_file, CS%tides_CSp)
call PressureForce_init(Time, G, GV, US, param_file, diag, CS%PressureForce_CSp, &
CS%tides_CSp)
call hor_visc_init(Time, G, GV, US, param_file, diag, CS%hor_visc)
call hor_visc_init(Time, G, GV, US, param_file, diag, CS%hor_visc, dbcomms_CS)
call vertvisc_init(MIS, Time, G, GV, US, param_file, diag, CS%ADp, dirs, &
ntrunc, CS%vertvisc_CSp)
CS%set_visc_CSp => set_visc
Expand Down
32 changes: 18 additions & 14 deletions src/parameterizations/lateral/MOM_CNN_GZ21.F90
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module MOM_CNN_GZ21

use MOM_database_comms, only : dbcomms_CS_type, dbclient_type
use MOM_grid, only : ocean_grid_type
use MOM_verticalGrid, only : verticalGrid_type
use MOM_domains, only : clone_MOM_domain,MOM_domain_type
Expand Down Expand Up @@ -50,6 +51,7 @@ module MOM_CNN_GZ21
!> Control structure for CNN
type, public :: CNN_CS ; private
type(MOM_domain_type), pointer :: CNN_Domain => NULL() !< Domain for inputs/outputs for CNN
type(dbclient_type), pointer :: client => NULL() !< The database communication client
integer :: isdw !< The lower i-memory limit for the wide halo arrays.
integer :: iedw !< The upper i-memory limit for the wide halo arrays.
integer :: jsdw !< The lower j-memory limit for the wide halo arrays.
Expand Down Expand Up @@ -78,20 +80,22 @@ module MOM_CNN_GZ21
contains

!> Prepare CNN input variables with wide halos
subroutine CNN_init(Time,G,GV,US,param_file,diag,CS)
subroutine CNN_init(Time,G,GV,US,param_file,diag, dbcomms_CS, CS)
type(time_type), intent(in) :: Time !< The current model time.
type(ocean_grid_type), intent(in) :: G !< The ocean's grid structure.
type(verticalGrid_type), intent(in) :: GV !< The ocean's vertical grid structure
type(unit_scale_type), intent(in) :: US !< A dimensional unit scaling type
type(param_file_type), intent(in) :: param_file !< Parameter file parser structure.
type(diag_ctrl), target, intent(inout) :: diag !< Diagnostics structure.
type(dbcomms_CS_type), target, intent(in ) :: dbcomms_CS !< Control structure for database communications
type(CNN_CS), intent(inout) :: CS !< Control structure for CNN
! Local Variables
integer :: wd_halos(2) ! Varies with CNN
character(len=40) :: mdl = "MOM_CNN" ! module name

! Register fields for output from this module.
CS%diag => diag
CS%client=> dbcomms_CS%client

CS%id_CNNu = register_diag_field('ocean_model', 'CNNu', diag%axesCuL, Time, &
'Zonal Acceleration from CNN model', 'm s-2', conversion=US%L_T2_to_m_s2)
Expand All @@ -110,12 +114,12 @@ subroutine CNN_init(Time,G,GV,US,param_file,diag,CS)
CS%id_Systd = register_diag_field('ocean_model', 'Systd', diag%axesTL, Time, &
'Meridional Acceleration from CNN model standard deviation part', &
'm s-2', conversion=US%L_T2_to_m_s2)

call get_param(param_file, mdl, "CNN_BT", CS%CNN_BT, &
"If true, momentum forcing from CNN is barotropic, otherwise baroclinic (default).", &
default=.false.)
call get_param(param_file, mdl, "CNN_HALO_SIZE", CS%CNN_halo_size, &
"Halo size at each side of subdomains, depends on CNN architecture.", &
"Halo size at each side of subdomains, depends on CNN architecture.", &
units="nondim", default=10)

wd_halos(1) = CS%CNN_halo_size
Expand Down Expand Up @@ -207,10 +211,10 @@ subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, FP_CS, SS_CS, CNN, python
! Combine arrays for CNN input
WH_uv = 0.0
do k=1,nztemp
do j=jsdw,jedw ; do i=isdw,iedw
do j=jsdw,jedw ; do i=isdw,iedw
WH_uv(1,i,j,k) = WH_u(i,j,k)
WH_uv(2,i,j,k) = WH_v(i,j,k)
enddo ; enddo
enddo ; enddo
enddo

! Run Python script for CNN inference
Expand All @@ -230,14 +234,14 @@ subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, FP_CS, SS_CS, CNN, python
call cpu_clock_begin(CNN%id_cnn_post1)
Sx=0.0; Sy=0.0; Sxmean=0.0; Symean=0.0; Sxstd=0.0; Systd=0.0;
do k=1,nztemp
do j=js,je ; do i=is,ie
do j=js,je ; do i=is,ie
Sx(i,j,k) = Sxy(1,i,j,k)
Sy(i,j,k) = Sxy(2,i,j,k)
Sxmean(i,j,k) = Sxy(3,i,j,k)
Symean(i,j,k) = Sxy(4,i,j,k)
Sxstd(i,j,k) = Sxy(5,i,j,k)
Systd(i,j,k) = Sxy(6,i,j,k)
enddo ; enddo
enddo ; enddo
enddo
call cpu_clock_end(CNN%id_cnn_post1)

Expand All @@ -248,7 +252,7 @@ subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, FP_CS, SS_CS, CNN, python
! call pass_var(Sxmean, G%Domain)
! call pass_var(Symean, G%Domain)
! call pass_var(Sxstd, G%Domain)
! call pass_var(Systd, G%Domain)
! call pass_var(Systd, G%Domain)
call create_group_pass(pass_CNN,Sx,G%Domain)
call create_group_pass(pass_CNN,Sy,G%Domain)
call create_group_pass(pass_CNN,Sxmean,G%Domain)
Expand All @@ -257,9 +261,9 @@ subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, FP_CS, SS_CS, CNN, python
call create_group_pass(pass_CNN,Systd,G%Domain)
call do_group_pass(pass_CNN,G%Domain)
call cpu_clock_end(CNN%id_cnn_post2)

call cpu_clock_begin(CNN%id_cnn_post3)
fx = 0.0; fy = 0.0;
fx = 0.0; fy = 0.0;
do k=1,nz
do j=js,je ; do I=is-1,ie
if (CNN%CNN_BT) then
Expand Down Expand Up @@ -312,7 +316,7 @@ subroutine compute_energy_source(u, v, h, fx, fy, G, GV, CS)
real, dimension(SZI_(G),SZJB_(G),SZK_(GV)), &
intent(in) :: fy !< Meridional acceleration due to convergence
!! of along-coordinate stress tensor [L T-2 ~> m s-2]

real :: KE_term(SZI_(G),SZJ_(G),SZK_(GV)) ! A term in the kinetic energy budget
! [H L2 T-3 ~> m3 s-3 or W m-2]
real :: tmp(SZI_(G),SZJ_(G),SZK_(GV)) ! temporary array for integration
Expand All @@ -334,9 +338,9 @@ subroutine compute_energy_source(u, v, h, fx, fy, G, GV, CS)

is = G%isc ; ie = G%iec ; js = G%jsc ; je = G%jec ; nz = GV%ke
Isq = G%IscB ; Ieq = G%IecB ; Jsq = G%JscB ; Jeq = G%JecB

call create_group_pass(pass_KE_uv, KE_u, KE_v, G%Domain, To_North+To_East)

KE_term(:,:,:) = 0.
tmp(:,:,:) = 0.
! Calculate the KE source from Zanna-Bolton2020 [H L2 T-3 ~> m3 s-3].
Expand Down Expand Up @@ -367,5 +371,5 @@ subroutine compute_energy_source(u, v, h, fx, fy, G, GV, CS)
call post_data(CS%id_KE_CNN, KE_term, CS%diag)

end subroutine compute_energy_source

end module MOM_CNN_GZ21
16 changes: 9 additions & 7 deletions src/parameterizations/lateral/MOM_hor_visc.F90
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ module MOM_hor_visc
! This file is part of MOM6. See LICENSE.md for the license.
use MOM_checksums, only : hchksum, Bchksum, uvchksum
use MOM_coms, only : min_across_PEs
use MOM_database_comms, only : dbcomms_CS_type
use MOM_diag_mediator, only : post_data, register_diag_field, safe_alloc_ptr
use MOM_diag_mediator, only : post_product_u, post_product_sum_u
use MOM_diag_mediator, only : post_product_v, post_product_sum_v
Expand Down Expand Up @@ -188,7 +189,7 @@ module MOM_hor_visc
type(smartsim_python_interface) :: python !< Python interface object !Cheng
type(smartsim_python_interface) :: smartsim_python !< Python interface object !Cheng
type(CNN_CS) :: CNN !< Control structure for CNN !Cheng
logical :: use_hor_visc_python !< If true, use a python script to update
logical :: use_hor_visc_python !< If true, use a python script to update
!! the lateral viscous accelerations.
character(len=200) :: &
python_dir, & !< default = ".". The directory in which Python scripts are found.
Expand Down Expand Up @@ -1685,7 +1686,7 @@ subroutine horizontal_viscosity(u, v, h, diffu, diffv, MEKE, VarMix, G, GV, US,
if (CS%id_diffu_visc_rem > 0) call post_product_u(CS%id_diffu_visc_rem, diffu, ADp%visc_rem_u, G, nz, CS%diag)
if (CS%id_diffv_visc_rem > 0) call post_product_v(CS%id_diffv_visc_rem, diffv, ADp%visc_rem_v, G, nz, CS%diag)
endif

if (CS%use_hor_visc_python) call CNN_inference(u, v, h, diffu, diffv, G, GV, CS%python, CS%smartsim_python, &
CS%CNN, CS%python_bridge_lib) !Cheng

Expand All @@ -1694,7 +1695,7 @@ end subroutine horizontal_viscosity
!> Allocates space for and calculates static variables used by horizontal_viscosity().
!! hor_visc_init calculates and stores the values of a number of metric functions that
!! are used in horizontal_viscosity().
subroutine hor_visc_init(Time, G, GV, US, param_file, diag, CS, ADp)
subroutine hor_visc_init(Time, G, GV, US, param_file, diag, CS, dbcomms_CS, ADp)
type(time_type), intent(in) :: Time !< Current model time.
type(ocean_grid_type), intent(inout) :: G !< The ocean's grid structure.
type(verticalGrid_type), intent(in) :: GV !< The ocean's vertical grid structure
Expand All @@ -1703,6 +1704,7 @@ subroutine hor_visc_init(Time, G, GV, US, param_file, diag, CS, ADp)
!! parameters.
type(diag_ctrl), target, intent(inout) :: diag !< Structure to regulate diagnostic output.
type(hor_visc_CS), intent(inout) :: CS !< Horizontal viscosity control struct
type(dbcomms_CS_type), target, intent(in) :: dbcomms_CS !< Control structure of database communication client
type(accel_diag_ptrs), intent(in), optional :: ADp !< Acceleration diagnostics

real, dimension(SZIB_(G),SZJ_(G)) :: u0u, u0v
Expand Down Expand Up @@ -2401,19 +2403,19 @@ subroutine hor_visc_init(Time, G, GV, US, param_file, diag, CS, ADp)
" 'forpy': Forpy library\n"// &
" 'smartsim': smartsim library", default='smartsim')
CS%python_bridge_lib = trim(CS%python_bridge_lib)

if (CS%use_hor_visc_python) then !Cheng
select case (lowercase(CS%python_bridge_lib))
! case("forpy")
! call forpy_run_python_init(CS%python,trim(CS%python_dir),trim(CS%python_file))
case("smartsim")
call smartsim_run_python_init(CS%smartsim_python,trim(CS%python_dir),trim(CS%python_file),param_file)
call smartsim_run_python_init(CS%smartsim_python,trim(CS%python_dir),trim(CS%python_file),param_file, dbcomms_CS)
case default
call MOM_error(FATAL, "Invalid library selected for language bridging")
end select
call CNN_init(Time, G, GV, US, param_file, diag, CS%CNN)
call CNN_init(Time, G, GV, US, param_file, diag, dbcomms_CS, CS%CNN)
endif


! Register fields for output from this module.
CS%id_normstress = register_diag_field('ocean_model', 'NoSt', diag%axesTL, Time, &
Expand Down

0 comments on commit 5c2bc72

Please sign in to comment.