Skip to content

Commit

Permalink
Add ekp kwarg to pass EKP struct into calibrate
Browse files Browse the repository at this point in the history
  • Loading branch information
nefrathenrici committed Oct 25, 2024
1 parent 5dd782a commit 05e0c2f
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ClimaCalibrate"
uuid = "4347a170-ebd6-470c-89d3-5c705c0cacc2"
authors = ["Climate Modeling Alliance"]
version = "0.0.4"
version = "0.0.5"

[deps]
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand Down
27 changes: 20 additions & 7 deletions src/backends.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using Distributed

import EnsembleKalmanProcesses as EKP

export get_backend, calibrate, model_run

abstract type AbstractBackend end
Expand Down Expand Up @@ -82,10 +84,15 @@ function calibrate(
::Type{JuliaBackend},
config::ExperimentConfig;
reruns = 0,
ekp = nothing,
ekp_kwargs...,
)
(; n_iterations, output_dir, ensemble_size) = config
eki = initialize(config; ekp_kwargs...)
ekp = if ekp isa EKP.EnsembleKalmanProcess
initialize(ekp, prior, output_dir)
else
initialize(config; ekp_kwargs...)
end
on_error(e::InterruptException) = rethrow(e)
on_error(e) =
@error "Single ensemble member has errored. See stacktrace" exception =
Expand All @@ -101,14 +108,15 @@ function calibrate(
terminate = update_ensemble(config, i)
!isnothing(terminate) && break
iter_path = path_to_iteration(output_dir, i + 1)
eki = JLD2.load_object(joinpath(iter_path, "eki_file.jld2"))
ekp = JLD2.load_object(joinpath(iter_path, "eki_file.jld2"))
end
return eki
return ekp
end

"""
calibrate(::Type{AbstractBackend}, config::ExperimentConfig; kwargs...)
calibrate(::Type{AbstractBackend}, experiment_dir; kwargs...)
calibrate(::Type{AbstractBackend}, ekp::EnsembleKalmanProcess, experiment_dir; kwargs...)
Run a full calibration, scheduling the forward model runs on Caltech's HPC cluster.
Expand Down Expand Up @@ -160,13 +168,18 @@ function calibrate(
),
verbose = false,
reruns = 1,
ekp = nothing,
hpc_kwargs,
ekp_kwargs...,
)
(; n_iterations, output_dir, ensemble_size) = config
(; n_iterations, output_dir, prior, ensemble_size) = config
@info "Initializing calibration" n_iterations ensemble_size output_dir

eki = initialize(config; ekp_kwargs...)
ekp = if ekp isa EKP.EnsembleKalmanProcess
initialize(ekp, prior, output_dir)
else
initialize(config; ekp_kwargs...)
end
module_load_str = module_load_string(b)
for i in 0:(n_iterations - 1)
@info "Iteration $i"
Expand Down Expand Up @@ -201,9 +214,9 @@ function calibrate(
terminate = update_ensemble(config, i)
!isnothing(terminate) && break
iter_path = path_to_iteration(output_dir, i + 1)
eki = JLD2.load_object(joinpath(iter_path, "eki_file.jld2"))
ekp = JLD2.load_object(joinpath(iter_path, "eki_file.jld2"))
end
return eki
return ekp
end

# Dispatch on backend type to unify `calibrate` for all HPCBackends
Expand Down
5 changes: 4 additions & 1 deletion src/ekp_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,18 @@ function _initialize(
rng_seed,
ekp_kwargs...,
)
@show Dict(ekp_kwargs)
@show
Random.seed!(rng_seed)
rng_ekp = Random.MersenneTwister(rng_seed)
initial_ensemble =
EKP.construct_initial_ensemble(rng_ekp, prior, ensemble_size)

ekp_str_kwargs = Dict([string(k) => v for (k, v) in ekp_kwargs])
eki_constructor =
(args...) -> EKP.EnsembleKalmanProcess(
args...,
Dict(EKP.default_options_dict(EKP.Inversion())..., ekp_kwargs...);
merge(EKP.default_options_dict(EKP.Inversion()), ekp_str_kwargs);
rng = rng_ekp,
)

Expand Down
4 changes: 3 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3"
ClimaCalibrate = "4347a170-ebd6-470c-89d3-5c705c0cacc2"
ClimaParams = "5c42b081-d73a-476f-9059-fd94b934656c"
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
Expand All @@ -12,3 +11,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
EnsembleKalmanProcesses = "2"
30 changes: 26 additions & 4 deletions test/ekp_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,19 @@ config = CAL.ExperimentConfig(
output_dir,
)

eki = CAL.initialize(config)

user_initial_ensemble =
EKP.construct_initial_ensemble(rng_ekp, prior, ensemble_size)
user_constructed_eki = EKP.EnsembleKalmanProcess(
user_initial_ensemble,
observations,
noise,
EKP.Inversion(),
EKP.default_options_dict(EKP.Inversion());
rng = rng_ekp,
)

eki = CAL.initialize(config; rng_seed)
eki_with_kwargs = CAL.initialize(
config;
scheduler = EKP.MutableScheduler(2),
Expand All @@ -46,6 +58,17 @@ eki_with_kwargs = CAL.initialize(
@test eki_with_kwargs.accelerator isa EKP.NesterovAccelerator
end

@testset "Test that a user-constructed EKP obj is same as initialized one" begin
for prop in propertynames(eki)
if prop in [:u, :accelerator, :localizer]
continue
end
@show prop
@test getproperty(eki, prop) == getproperty(user_constructed_eki, prop)
end
@test eki.u[1].stored_data == user_constructed_eki.u[1].stored_data
end

override_file = joinpath(
config.output_dir,
"iteration_000",
Expand Down Expand Up @@ -80,11 +103,10 @@ end
joinpath(output_dir, "iteration_000", "member_001", "parameters.toml")
td = CP.create_toml_dict(FT; override_file)
params = CP.get_parameter_values(td, param_names)
@test params.one == 2.513110562120818
@test params.two == 4.614950047803855
@test params.one == 3.1313341622997677
@test params.two == 5.063035177034372
end


@testset "Environment variables" begin
@test_throws ErrorException(
"Experiment dir not found in environment. Ensure that env variable \"CALIBRATION_EXPERIMENT_DIR\" is set.",
Expand Down
2 changes: 1 addition & 1 deletion test/pure_julia_e2e.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ ekp = calibrate(JuliaBackend, experiment_config)
parameter_values =
[EKP.get_ϕ_mean(prior, ekp, it) for it in 1:(n_iterations + 1)]
@test parameter_values[1][1] 8.507 rtol = 0.01
@test parameter_values[end][1] 19.0124 rtol = 0.01
@test parameter_values[end][1] 11.852161842745355 rtol = 0.01
end

rm(output_dir; recursive = true)
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Test

include("ekp_interface.jl")
include("model_interface.jl")
include("emulate_sample.jl")
# Disabled since we use EKP 2.0 in testing, CES is still incompatible with EKP 2.0
# include("emulate_sample.jl")
include("pure_julia_e2e.jl")
include("aqua.jl")

0 comments on commit 05e0c2f

Please sign in to comment.