diff --git a/src/ekp_interface.jl b/src/ekp_interface.jl index 7b778d84..a928cd8a 100644 --- a/src/ekp_interface.jl +++ b/src/ekp_interface.jl @@ -215,6 +215,45 @@ initialize(config::ExperimentConfig; kwargs...) = initialize( kwargs..., ) +function initialize( + ensemble_size, + observations, + prior, + output_dir; + rng_seed = 1234, + ekp_kwargs..., +) + Random.seed!(rng_seed) + rng_ekp = Random.MersenneTwister(rng_seed) + + initial_ensemble = + EKP.construct_initial_ensemble(rng_ekp, prior, ensemble_size) + eki = EKP.EnsembleKalmanProcess( + initial_ensemble, + observations, + EKP.Inversion(); + rng = rng_ekp, + failure_handler_method = EKP.SampleSuccGauss(), + ekp_kwargs..., + ) + + param_dict = get_param_dict(prior) + + save_parameter_ensemble( + EKP.get_u_final(eki), # constraints applied when saving + prior, + param_dict, + output_dir, + "parameters.toml", + 0, # Initial iteration = 0 + ) + + # Save the EKI object in the 'iteration_000' folder + eki_path = joinpath(output_dir, "iteration_000", "eki_file.jld2") + JLD2.save_object(eki_path, eki) + return eki +end + function initialize( ensemble_size, observations,