Skip to content

Commit

Permalink
Add simple gpu example
Browse files Browse the repository at this point in the history
Fixes
  • Loading branch information
charleskawczynski committed May 25, 2023
1 parent e38923e commit 4ddf0c6
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 7 deletions.
18 changes: 18 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ env:
agents:
config: cpu
queue: central
slurm_mem: 8G
modules: julia/1.8.5 ucx/1.13.1_cuda-11.2 cuda/11.2 openmpi/4.1.5_cuda-11.2 hdf5/1.12.2-ompi415

steps:
- label: "init :computer:"
Expand All @@ -28,6 +30,10 @@ steps:
- "julia --project -e 'using Pkg; Pkg.precompile()'"
- "julia --project -e 'using Pkg; Pkg.status()'"

- echo "--- Configure CUDA"
# force the initialization of the CUDA runtime as it is lazily loaded by default
- "julia --project -e 'using CUDA; CUDA.precompile_runtime()'"

- echo "--- Instantiate test"
- "julia --project=test -e 'using Pkg; Pkg.develop(path = \".\")'"
- "julia --project=test -e 'using Pkg; Pkg.instantiate(;verbose=true)'"
Expand Down Expand Up @@ -61,6 +67,18 @@ steps:
- label: ":computer: Ensure mse tables are reset when necessary"
command: "julia --color=yes --project=examples regression_tests/test_reset.jl"

- group: "GPU"
steps:
- label: "Simple GPU test"
key: "gpu_example"
command:
- "julia --project -e 'using CUDA; CUDA.versioninfo()'"
- "julia --color=yes --project=examples examples/hybrid/driver.jl --enable_threading false --ode_algo SSP33ShuOsher --t_end 1mins --dt 1secs --dt_save_to_sol Inf --dt_save_to_disk Inf --job_id gpu_example --device CUDADevice"
artifact_paths: "gpu_example/*"
soft_fail: true
agents:
slurm_gres: "gpu:1"

- group: "Gravity wave"
steps:

Expand Down
2 changes: 1 addition & 1 deletion examples/hybrid/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ if !(@isdefined parsed_args)
end

if !(@isdefined comms_ctx) # Coupler compatibility
const comms_ctx = CA.get_comms_context(ClimaComms.CPUDevice())
const comms_ctx = CA.get_comms_context(parsed_args)
end

const FT = parsed_args["FLOAT_TYPE"] == "Float64" ? Float64 : Float32
Expand Down
6 changes: 2 additions & 4 deletions src/staggered_nonhydrostatic_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,8 @@ function default_cache(
apply_limiter ? Limiters.QuasiMonotoneLimiter(similar(Y.c, FT)) :
nothing

net_energy_flux_toa = [sum(similar(Y.f, Geometry.WVector{FT})) * 0]
net_energy_flux_toa[] = Geometry.WVector(FT(0))
net_energy_flux_sfc = [sum(similar(Y.f, Geometry.WVector{FT})) * 0]
net_energy_flux_sfc[] = Geometry.WVector(FT(0))
net_energy_flux_toa = [Geometry.WVector(FT(0))]
net_energy_flux_sfc = [Geometry.WVector(FT(0))]

default_cache = (;
simulation,
Expand Down
4 changes: 4 additions & 0 deletions src/utils/cli_options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,10 @@ function argparse_settings()
"--orographic_gravity_wave"
help = "Orographic drag on horizontal mean flow [`nothing` (default), `gfdl_restart`, `raw_topo`]"
arg_type = String
"--device"
help = "Device type to use [`CPUDevice` (default), `CUDADevice`]"
arg_type = String
default = "CPUDevice"
"--perf_summary"
help = "Flag for collecting performance summary information"
arg_type = Bool
Expand Down
4 changes: 3 additions & 1 deletion src/utils/common_spaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ function make_hybrid_spaces(
z_mesh = Meshes.IntervalMesh(z_domain, z_stretch; nelems = z_elem)
@info "z heights" z_mesh.faces
if surface_warp == nothing
z_topology = Topologies.IntervalTopology(z_mesh)
device = ClimaComms.device(h_space)
comms_ctx = ClimaComms.SingletonCommsContext(device)
z_topology = Topologies.IntervalTopology(comms_ctx, z_mesh)
z_space = Spaces.CenterFiniteDifferenceSpace(z_topology)
center_space = Spaces.ExtrudedFiniteDifferenceSpace(h_space, z_space)
face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(center_space)
Expand Down
7 changes: 6 additions & 1 deletion src/utils/type_getters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,12 @@ function args_integrator(parsed_args, Y, p, tspan, ode_algo, callback)
end

import ClimaComms, Logging, NVTX
function get_comms_context(device = ClimaComms.CPUDevice())
function get_comms_context(parsed_args)
device = if parsed_args["device"] == "CUDADevice"
ClimaComms.CUDADevice()
else
ClimaComms.CPUDevice()
end
comms_ctx = ClimaComms.context(device)
ClimaComms.init(comms_ctx)
if ClimaComms.iamroot(comms_ctx)
Expand Down

0 comments on commit 4ddf0c6

Please sign in to comment.