Skip to content

Commit

Permalink
Merge pull request #2440 from CliMA/ck/gpu_compat
Browse files Browse the repository at this point in the history
Assert atmos to not have UnionAll types
  • Loading branch information
charleskawczynski authored Dec 18, 2023
2 parents 84b1a0e + 50aae95 commit 496cd9d
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 5 deletions.
1 change: 1 addition & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,7 @@ steps:
artifact_paths: "target_gpu_implicit_baroclinic_wave/*"
agents:
slurm_gpus: 1
slurm_mem: 32G

- label: "GPU: GPU dry baroclinic wave - 4 gpus"
key: "target_gpu_implicit_baroclinic_wave_4process"
Expand Down
1 change: 1 addition & 0 deletions src/ClimaAtmos.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ include(joinpath("parameters", "Parameters.jl"))
import .Parameters as CAP

include(joinpath("utils", "abbreviations.jl"))
include(joinpath("utils", "gpu_compat.jl"))
include(joinpath("utils", "common_spaces.jl"))
include(joinpath("solver", "types.jl"))
include(joinpath("solver", "cli_options.jl"))
Expand Down
8 changes: 5 additions & 3 deletions src/cache/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,11 @@ function build_cache(Y, atmos, params, surface_setup, dt, t_end, start_date)
net_energy_flux_toa = [Geometry.WVector(FT(0))]
net_energy_flux_sfc = [Geometry.WVector(FT(0))]

limiter =
isnothing(atmos.numerics.limiter) ? nothing :
atmos.numerics.limiter(similar(Y.c, FT))
limiter = if isnothing(atmos.numerics.limiter)
nothing
elseif atmos.numerics.limiter isa QuasiMonotoneLimiter
Limiters.QuasiMonotoneLimiter(similar(Y.c, FT))
end

numerics = (; limiter)

Expand Down
4 changes: 2 additions & 2 deletions src/solver/type_getters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ function get_atmos(config::AtmosConfig, params)
surface_model = get_surface_model(parsed_args),
numerics = get_numerics(parsed_args),
)
@assert !@any_reltype(atmos, (UnionAll, DataType))

@info "AtmosModel: \n$(summary(atmos))"
return atmos
Expand All @@ -96,8 +97,7 @@ function get_numerics(parsed_args)
edmfx_sgsflux_upwinding =
Val(Symbol(parsed_args["edmfx_sgsflux_upwinding"]))

limiter =
parsed_args["apply_limiter"] ? Limiters.QuasiMonotoneLimiter : nothing
limiter = parsed_args["apply_limiter"] ? CA.QuasiMonotoneLimiter() : nothing

# wrap each upwinding mode in a Val for dispatch
numerics = AtmosNumerics(;
Expand Down
2 changes: 2 additions & 0 deletions src/solver/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ abstract type AbstractTimesteppingMode end
struct Explicit <: AbstractTimesteppingMode end
struct Implicit <: AbstractTimesteppingMode end

struct QuasiMonotoneLimiter end # For dispatching to use the ClimaCore QuasiMonotoneLimiter.

Base.@kwdef struct AtmosNumerics{
EN_UP,
TR_UP,
Expand Down
30 changes: 30 additions & 0 deletions src/utils/gpu_compat.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
@any_reltype(::Any, t::Tuple, warn=true)
Returns a Bool (and prints warnings) if the given
data structure has an instance of any types in `t`.
"""
function any_reltype(found, obj, name, ets, pc = (); warn = true)
for pn in propertynames(obj)
prop = getproperty(obj, pn)
pc_full = (pc..., ".", pn)
pc_string = name * string(join(pc_full))
for et in ets
if prop isa et
warn && @warn "$pc_string::$(typeof(prop)) is a DataType"
found = true
end
end
found = found || any_reltype(found, prop, name, ets, pc_full; warn)
end
return found
end
macro any_reltype(obj, ets, warn = true)
return :(any_reltype(
false,
$(esc(obj)),
$(string(obj)),
$(esc(ets));
warn = $(esc(warn)),
))
end

0 comments on commit 496cd9d

Please sign in to comment.