diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index bb423fcef8..55c4079c43 100755 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -13,6 +13,8 @@ env: agents: config: cpu queue: central + slurm_mem: 8G + modules: julia/1.8.5 ucx/1.13.1_cuda-11.2 cuda/11.2 openmpi/4.1.5_cuda-11.2 hdf5/1.12.2-ompi415 steps: - label: "init :computer:" @@ -28,6 +30,10 @@ steps: - "julia --project -e 'using Pkg; Pkg.precompile()'" - "julia --project -e 'using Pkg; Pkg.status()'" + - echo "--- Configure CUDA" + # force the initialization of the CUDA runtime as it is lazily loaded by default + - "julia --project -e 'using CUDA; CUDA.precompile_runtime()'" + - echo "--- Instantiate test" - "julia --project=test -e 'using Pkg; Pkg.develop(path = \".\")'" - "julia --project=test -e 'using Pkg; Pkg.instantiate(;verbose=true)'" @@ -61,6 +67,18 @@ steps: - label: ":computer: Ensure mse tables are reset when necessary" command: "julia --color=yes --project=examples regression_tests/test_reset.jl" + - group: "GPU" + steps: + - label: "Simple GPU test" + key: "gpu_example" + command: + - "julia --project -e 'using CUDA; CUDA.versioninfo()'" + - "julia --color=yes --project=examples examples/hybrid/driver.jl --enable_threading false --ode_algo SSP33ShuOsher --t_end 1mins --dt 1secs --dt_save_to_sol Inf --dt_save_to_disk Inf --job_id gpu_example --device CUDADevice" + artifact_paths: "gpu_example/*" + soft_fail: true + agents: + slurm_gres: "gpu:1" + - group: "Gravity wave" steps: diff --git a/examples/hybrid/driver.jl b/examples/hybrid/driver.jl index 840038ab49..e333cd8a6b 100644 --- a/examples/hybrid/driver.jl +++ b/examples/hybrid/driver.jl @@ -7,7 +7,7 @@ if !(@isdefined parsed_args) end if !(@isdefined comms_ctx) # Coupler compatibility - const comms_ctx = CA.get_comms_context(ClimaComms.CPUDevice()) + const comms_ctx = CA.get_comms_context(parsed_args) end const FT = parsed_args["FLOAT_TYPE"] == "Float64" ? Float64 : Float32 diff --git a/src/staggered_nonhydrostatic_model.jl b/src/staggered_nonhydrostatic_model.jl index 3ff3867d6d..af4c6f2b49 100644 --- a/src/staggered_nonhydrostatic_model.jl +++ b/src/staggered_nonhydrostatic_model.jl @@ -90,10 +90,8 @@ function default_cache( apply_limiter ? Limiters.QuasiMonotoneLimiter(similar(Y.c, FT)) : nothing - net_energy_flux_toa = [sum(similar(Y.f, Geometry.WVector{FT})) * 0] - net_energy_flux_toa[] = Geometry.WVector(FT(0)) - net_energy_flux_sfc = [sum(similar(Y.f, Geometry.WVector{FT})) * 0] - net_energy_flux_sfc[] = Geometry.WVector(FT(0)) + net_energy_flux_toa = [Geometry.WVector(FT(0))] + net_energy_flux_sfc = [Geometry.WVector(FT(0))] default_cache = (; simulation, diff --git a/src/utils/cli_options.jl b/src/utils/cli_options.jl index 32a27cbc93..de41989d66 100644 --- a/src/utils/cli_options.jl +++ b/src/utils/cli_options.jl @@ -339,6 +339,10 @@ function argparse_settings() "--orographic_gravity_wave" help = "Orographic drag on horizontal mean flow [`nothing` (default), `gfdl_restart`, `raw_topo`]" arg_type = String + "--device" + help = "Device type to use [`CPUDevice` (default), `CUDADevice`]" + arg_type = String + default = "CPUDevice" "--perf_summary" help = "Flag for collecting performance summary information" arg_type = Bool diff --git a/src/utils/common_spaces.jl b/src/utils/common_spaces.jl index e4fab35beb..80ff66692f 100644 --- a/src/utils/common_spaces.jl +++ b/src/utils/common_spaces.jl @@ -88,7 +88,9 @@ function make_hybrid_spaces( z_mesh = Meshes.IntervalMesh(z_domain, z_stretch; nelems = z_elem) @info "z heights" z_mesh.faces if surface_warp == nothing - z_topology = Topologies.IntervalTopology(z_mesh) + device = ClimaComms.device(h_space) + comms_ctx = ClimaComms.SingletonCommsContext(device) + z_topology = Topologies.IntervalTopology(comms_ctx, z_mesh) z_space = Spaces.CenterFiniteDifferenceSpace(z_topology) center_space = Spaces.ExtrudedFiniteDifferenceSpace(h_space, z_space) face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(center_space) diff --git a/src/utils/type_getters.jl b/src/utils/type_getters.jl index 3da59bb881..cf5dc6e67b 100644 --- a/src/utils/type_getters.jl +++ b/src/utils/type_getters.jl @@ -626,7 +626,12 @@ function args_integrator(parsed_args, Y, p, tspan, ode_algo, callback) end import ClimaComms, Logging, NVTX -function get_comms_context(device = ClimaComms.CPUDevice()) +function get_comms_context(parsed_args) + device = if parsed_args["device"] == "CUDADevice" + ClimaComms.CUDADevice() + else + ClimaComms.CPUDevice() + end comms_ctx = ClimaComms.context(device) ClimaComms.init(comms_ctx) if ClimaComms.iamroot(comms_ctx)