diff --git a/src/ekp_interface.jl b/src/ekp_interface.jl index 7025183b..c6fc9ef8 100644 --- a/src/ekp_interface.jl +++ b/src/ekp_interface.jl @@ -190,12 +190,17 @@ end """ initialize(ensemble_size, observations, noise, prior, output_dir; kwargs...) initialize(ensemble_size, observations, prior, output_dir; kwargs...) + initialize(eki::EnsembleKalmanProcess, prior, output_dir) initialize(config::ExperimentConfig; kwargs...) initialize(filepath::AbstractString; kwargs...) Initialize the EnsembleKalmanProcess object and parameter files. +Can take in an existing EnsembleKalmanProcess which will be used to generate the + initial parameter ensemble. + Noise is optional when the observation is an EKP.ObservationSeries. + Additional kwargs may be passed through to the EnsembleKalmanProcess constructor. """ initialize(filepath::AbstractString; kwargs...) = @@ -244,6 +249,11 @@ initialize( ekp_kwargs..., ) +function initialize(eki::EKP.EnsembleKalmanProcess, prior, output_dir) + save_eki_state(eki, output_dir, 0, prior) + return eki +end + function _initialize( ensemble_size, observations, diff --git a/test/ekp_interface.jl b/test/ekp_interface.jl index b90a1b93..02beb55f 100644 --- a/test/ekp_interface.jl +++ b/test/ekp_interface.jl @@ -5,6 +5,11 @@ import ClimaCalibrate as CAL import ClimaParams as CP import LinearAlgebra: I using Test +import Random + +rng_seed = 1234 +Random.seed!(rng_seed) +rng_ekp = Random.MersenneTwister(rng_seed) FT = Float64 output_dir = "test_init" @@ -56,6 +61,30 @@ params = CP.get_parameter_values(td, param_names) @test params.two == 5.408386812503563 end +@testset "Test passing an EKP struct into `initialize`" begin + LHF_target = 4.0 + ensemble_size = 5 + N_iterations = 5 + Γ = 20.0 * EKP.I + output_dir = joinpath("test", "custom_ekp") + initial_ensemble = + EKP.construct_initial_ensemble(rng_ekp, prior, ensemble_size) + ensemble_kalman_process = EKP.EnsembleKalmanProcess( + initial_ensemble, + LHF_target, + Γ, + EKP.Inversion(), + ) + CAL.initialize(ensemble_kalman_process, prior, output_dir) + override_file = + joinpath(output_dir, "iteration_000", "member_001", "parameters.toml") + td = CP.create_toml_dict(FT; override_file) + params = CP.get_parameter_values(td, param_names) + @test params.one == 4.506555276137722 + @test params.two == 5.408386812503563 +end + + @testset "Environment variables" begin @test_throws ErrorException( "Experiment dir not found in environment. Ensure that env variable \"CALIBRATION_EXPERIMENT_DIR\" is set.",