Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix JuliaBackend calibrate constructor #116

Merged
merged 1 commit into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ClimaCalibrate"
uuid = "4347a170-ebd6-470c-89d3-5c705c0cacc2"
authors = ["Climate Modeling Alliance"]
version = "0.0.4"
version = "0.0.5"

[deps]
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand Down
27 changes: 20 additions & 7 deletions src/backends.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using Distributed

import EnsembleKalmanProcesses as EKP

export get_backend, calibrate, model_run

abstract type AbstractBackend end
Expand Down Expand Up @@ -82,10 +84,15 @@
::Type{JuliaBackend},
config::ExperimentConfig;
reruns = 0,
ekp = nothing,
ekp_kwargs...,
)
(; n_iterations, output_dir, ensemble_size) = config
eki = initialize(config; ekp_kwargs...)
ekp = if ekp isa EKP.EnsembleKalmanProcess
initialize(ekp, prior, output_dir)

Check warning on line 92 in src/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/backends.jl#L91-L92

Added lines #L91 - L92 were not covered by tests
else
initialize(config; ekp_kwargs...)

Check warning on line 94 in src/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/backends.jl#L94

Added line #L94 was not covered by tests
end
on_error(e::InterruptException) = rethrow(e)
on_error(e) =
@error "Single ensemble member has errored. See stacktrace" exception =
Expand All @@ -101,14 +108,15 @@
terminate = update_ensemble(config, i)
!isnothing(terminate) && break
iter_path = path_to_iteration(output_dir, i + 1)
eki = JLD2.load_object(joinpath(iter_path, "eki_file.jld2"))
ekp = JLD2.load_object(joinpath(iter_path, "eki_file.jld2"))

Check warning on line 111 in src/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/backends.jl#L111

Added line #L111 was not covered by tests
end
return eki
return ekp

Check warning on line 113 in src/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/backends.jl#L113

Added line #L113 was not covered by tests
end

"""
calibrate(::Type{AbstractBackend}, config::ExperimentConfig; kwargs...)
calibrate(::Type{AbstractBackend}, experiment_dir; kwargs...)
calibrate(::Type{AbstractBackend}, ekp::EnsembleKalmanProcess, experiment_dir; kwargs...)

Run a full calibration, scheduling the forward model runs on Caltech's HPC cluster.

Expand Down Expand Up @@ -160,13 +168,18 @@
),
verbose = false,
reruns = 1,
ekp = nothing,
hpc_kwargs,
ekp_kwargs...,
)
(; n_iterations, output_dir, ensemble_size) = config
(; n_iterations, output_dir, prior, ensemble_size) = config

Check warning on line 175 in src/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/backends.jl#L175

Added line #L175 was not covered by tests
@info "Initializing calibration" n_iterations ensemble_size output_dir

eki = initialize(config; ekp_kwargs...)
ekp = if ekp isa EKP.EnsembleKalmanProcess
initialize(ekp, prior, output_dir)

Check warning on line 179 in src/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/backends.jl#L178-L179

Added lines #L178 - L179 were not covered by tests
else
initialize(config; ekp_kwargs...)

Check warning on line 181 in src/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/backends.jl#L181

Added line #L181 was not covered by tests
end
module_load_str = module_load_string(b)
for i in 0:(n_iterations - 1)
@info "Iteration $i"
Expand Down Expand Up @@ -201,9 +214,9 @@
terminate = update_ensemble(config, i)
!isnothing(terminate) && break
iter_path = path_to_iteration(output_dir, i + 1)
eki = JLD2.load_object(joinpath(iter_path, "eki_file.jld2"))
ekp = JLD2.load_object(joinpath(iter_path, "eki_file.jld2"))

Check warning on line 217 in src/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/backends.jl#L217

Added line #L217 was not covered by tests
end
return eki
return ekp

Check warning on line 219 in src/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/backends.jl#L219

Added line #L219 was not covered by tests
end

# Dispatch on backend type to unify `calibrate` for all HPCBackends
Expand Down
3 changes: 2 additions & 1 deletion src/ekp_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,11 @@
initial_ensemble =
EKP.construct_initial_ensemble(rng_ekp, prior, ensemble_size)

ekp_str_kwargs = Dict([string(k) => v for (k, v) in ekp_kwargs])

Check warning on line 271 in src/ekp_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/ekp_interface.jl#L271

Added line #L271 was not covered by tests
eki_constructor =
(args...) -> EKP.EnsembleKalmanProcess(
args...,
Dict(EKP.default_options_dict(EKP.Inversion())..., ekp_kwargs...);
merge(EKP.default_options_dict(EKP.Inversion()), ekp_str_kwargs);
rng = rng_ekp,
)

Expand Down
4 changes: 3 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3"
ClimaCalibrate = "4347a170-ebd6-470c-89d3-5c705c0cacc2"
ClimaParams = "5c42b081-d73a-476f-9059-fd94b934656c"
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
Expand All @@ -12,3 +11,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
EnsembleKalmanProcesses = "2"
27 changes: 23 additions & 4 deletions test/ekp_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,19 @@ config = CAL.ExperimentConfig(
output_dir,
)

eki = CAL.initialize(config)

user_initial_ensemble =
EKP.construct_initial_ensemble(rng_ekp, prior, ensemble_size)
user_constructed_eki = EKP.EnsembleKalmanProcess(
user_initial_ensemble,
observations,
noise,
EKP.Inversion(),
EKP.default_options_dict(EKP.Inversion());
rng = rng_ekp,
)

eki = CAL.initialize(config; rng_seed)
eki_with_kwargs = CAL.initialize(
config;
scheduler = EKP.MutableScheduler(2),
Expand All @@ -46,6 +58,14 @@ eki_with_kwargs = CAL.initialize(
@test eki_with_kwargs.accelerator isa EKP.NesterovAccelerator
end

@testset "Test that a user-constructed EKP obj is same as initialized one" begin
for prop in propertynames(eki)
prop in [:u, :accelerator, :localizer] && continue
@test getproperty(eki, prop) == getproperty(user_constructed_eki, prop)
end
@test eki.u[1].stored_data == user_constructed_eki.u[1].stored_data
end

override_file = joinpath(
config.output_dir,
"iteration_000",
Expand Down Expand Up @@ -80,11 +100,10 @@ end
joinpath(output_dir, "iteration_000", "member_001", "parameters.toml")
td = CP.create_toml_dict(FT; override_file)
params = CP.get_parameter_values(td, param_names)
@test params.one == 2.513110562120818
@test params.two == 4.614950047803855
@test params.one == 3.1313341622997677
@test params.two == 5.063035177034372
end


@testset "Environment variables" begin
@test_throws ErrorException(
"Experiment dir not found in environment. Ensure that env variable \"CALIBRATION_EXPERIMENT_DIR\" is set.",
Expand Down
2 changes: 1 addition & 1 deletion test/pure_julia_e2e.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ ekp = calibrate(JuliaBackend, experiment_config)
parameter_values =
[EKP.get_ϕ_mean(prior, ekp, it) for it in 1:(n_iterations + 1)]
@test parameter_values[1][1] ≈ 8.507 rtol = 0.01
@test parameter_values[end][1] ≈ 19.0124 rtol = 0.01
@test parameter_values[end][1] ≈ 11.852161842745355 rtol = 0.01
end

rm(output_dir; recursive = true)
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Test

include("ekp_interface.jl")
include("model_interface.jl")
include("emulate_sample.jl")
# Disabled since we use EKP 2.0 in testing, CES is still incompatible with EKP 2.0
# include("emulate_sample.jl")
include("pure_julia_e2e.jl")
include("aqua.jl")
Loading