diff --git a/Project.toml b/Project.toml index bbadcc44..fc805cc5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ClimaCalibrate" uuid = "4347a170-ebd6-470c-89d3-5c705c0cacc2" authors = ["Climate Modeling Alliance"] -version = "0.0.4" +version = "0.0.5" [deps] Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" diff --git a/src/backends.jl b/src/backends.jl index 1d01d19b..fc1c90cc 100644 --- a/src/backends.jl +++ b/src/backends.jl @@ -1,5 +1,7 @@ using Distributed +import EnsembleKalmanProcesses as EKP + export get_backend, calibrate, model_run abstract type AbstractBackend end @@ -82,10 +84,15 @@ function calibrate( ::Type{JuliaBackend}, config::ExperimentConfig; reruns = 0, + ekp = nothing, ekp_kwargs..., ) (; n_iterations, output_dir, ensemble_size) = config - eki = initialize(config; ekp_kwargs...) + ekp = if ekp isa EKP.EnsembleKalmanProcess + initialize(ekp, prior, output_dir) + else + initialize(config; ekp_kwargs...) + end on_error(e::InterruptException) = rethrow(e) on_error(e) = @error "Single ensemble member has errored. See stacktrace" exception = @@ -101,14 +108,15 @@ function calibrate( terminate = update_ensemble(config, i) !isnothing(terminate) && break iter_path = path_to_iteration(output_dir, i + 1) - eki = JLD2.load_object(joinpath(iter_path, "eki_file.jld2")) + ekp = JLD2.load_object(joinpath(iter_path, "eki_file.jld2")) end - return eki + return ekp end """ calibrate(::Type{AbstractBackend}, config::ExperimentConfig; kwargs...) calibrate(::Type{AbstractBackend}, experiment_dir; kwargs...) + calibrate(::Type{AbstractBackend}, ekp::EnsembleKalmanProcess, experiment_dir; kwargs...) Run a full calibration, scheduling the forward model runs on Caltech's HPC cluster. @@ -160,13 +168,18 @@ function calibrate( ), verbose = false, reruns = 1, + ekp = nothing, hpc_kwargs, ekp_kwargs..., ) - (; n_iterations, output_dir, ensemble_size) = config + (; n_iterations, output_dir, prior, ensemble_size) = config @info "Initializing calibration" n_iterations ensemble_size output_dir - eki = initialize(config; ekp_kwargs...) + ekp = if ekp isa EKP.EnsembleKalmanProcess + initialize(ekp, prior, output_dir) + else + initialize(config; ekp_kwargs...) + end module_load_str = module_load_string(b) for i in 0:(n_iterations - 1) @info "Iteration $i" @@ -201,9 +214,9 @@ function calibrate( terminate = update_ensemble(config, i) !isnothing(terminate) && break iter_path = path_to_iteration(output_dir, i + 1) - eki = JLD2.load_object(joinpath(iter_path, "eki_file.jld2")) + ekp = JLD2.load_object(joinpath(iter_path, "eki_file.jld2")) end - return eki + return ekp end # Dispatch on backend type to unify `calibrate` for all HPCBackends diff --git a/src/ekp_interface.jl b/src/ekp_interface.jl index 889ab714..a9625893 100644 --- a/src/ekp_interface.jl +++ b/src/ekp_interface.jl @@ -268,10 +268,11 @@ function _initialize( initial_ensemble = EKP.construct_initial_ensemble(rng_ekp, prior, ensemble_size) + ekp_str_kwargs = Dict([string(k) => v for (k, v) in ekp_kwargs]) eki_constructor = (args...) -> EKP.EnsembleKalmanProcess( args..., - Dict(EKP.default_options_dict(EKP.Inversion())..., ekp_kwargs...); + merge(EKP.default_options_dict(EKP.Inversion()), ekp_str_kwargs); rng = rng_ekp, ) diff --git a/test/Project.toml b/test/Project.toml index 0b7abcc4..535d6ec8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,5 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3" ClimaCalibrate = "4347a170-ebd6-470c-89d3-5c705c0cacc2" ClimaParams = "5c42b081-d73a-476f-9059-fd94b934656c" Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d" @@ -12,3 +11,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +EnsembleKalmanProcesses = "2" diff --git a/test/ekp_interface.jl b/test/ekp_interface.jl index 9c95785d..088696b7 100644 --- a/test/ekp_interface.jl +++ b/test/ekp_interface.jl @@ -31,7 +31,19 @@ config = CAL.ExperimentConfig( output_dir, ) -eki = CAL.initialize(config) + +user_initial_ensemble = + EKP.construct_initial_ensemble(rng_ekp, prior, ensemble_size) +user_constructed_eki = EKP.EnsembleKalmanProcess( + user_initial_ensemble, + observations, + noise, + EKP.Inversion(), + EKP.default_options_dict(EKP.Inversion()); + rng = rng_ekp, +) + +eki = CAL.initialize(config; rng_seed) eki_with_kwargs = CAL.initialize( config; scheduler = EKP.MutableScheduler(2), @@ -46,6 +58,14 @@ eki_with_kwargs = CAL.initialize( @test eki_with_kwargs.accelerator isa EKP.NesterovAccelerator end +@testset "Test that a user-constructed EKP obj is same as initialized one" begin + for prop in propertynames(eki) + prop in [:u, :accelerator, :localizer] && continue + @test getproperty(eki, prop) == getproperty(user_constructed_eki, prop) + end + @test eki.u[1].stored_data == user_constructed_eki.u[1].stored_data +end + override_file = joinpath( config.output_dir, "iteration_000", @@ -80,11 +100,10 @@ end joinpath(output_dir, "iteration_000", "member_001", "parameters.toml") td = CP.create_toml_dict(FT; override_file) params = CP.get_parameter_values(td, param_names) - @test params.one == 2.513110562120818 - @test params.two == 4.614950047803855 + @test params.one == 3.1313341622997677 + @test params.two == 5.063035177034372 end - @testset "Environment variables" begin @test_throws ErrorException( "Experiment dir not found in environment. Ensure that env variable \"CALIBRATION_EXPERIMENT_DIR\" is set.", diff --git a/test/pure_julia_e2e.jl b/test/pure_julia_e2e.jl index 1e330392..f2698a10 100644 --- a/test/pure_julia_e2e.jl +++ b/test/pure_julia_e2e.jl @@ -77,7 +77,7 @@ ekp = calibrate(JuliaBackend, experiment_config) parameter_values = [EKP.get_ϕ_mean(prior, ekp, it) for it in 1:(n_iterations + 1)] @test parameter_values[1][1] ≈ 8.507 rtol = 0.01 - @test parameter_values[end][1] ≈ 19.0124 rtol = 0.01 + @test parameter_values[end][1] ≈ 11.852161842745355 rtol = 0.01 end rm(output_dir; recursive = true) diff --git a/test/runtests.jl b/test/runtests.jl index abc513c5..1be7ad68 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using Test include("ekp_interface.jl") include("model_interface.jl") -include("emulate_sample.jl") +# Disabled since we use EKP 2.0 in testing, CES is still incompatible with EKP 2.0 +# include("emulate_sample.jl") include("pure_julia_e2e.jl") include("aqua.jl")