Skip to content

Commit

Permalink
fix e2e test
Browse files Browse the repository at this point in the history
  • Loading branch information
nefrathenrici committed Jun 20, 2024
1 parent b917ef7 commit 0039c9b
Showing 1 changed file with 12 additions and 19 deletions.
31 changes: 12 additions & 19 deletions calibration/test/e2e_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import EnsembleKalmanProcesses as EKP
import Statistics: var, mean
using Test

experiment_dir = joinpath(pkgdir(CA), "test", "calibration")
experiment_dir = joinpath(pkgdir(CA), "calibration", "test")
model_interface = joinpath(pkgdir(CA), "calibration", "model_interface.jl")
include(model_interface)

Expand All @@ -34,8 +34,9 @@ end

function process_member_data(simdir::SimDir)
isempty(simdir.vars) && return NaN
rsut = get(simdir; short_name = "rsut", reduction = "average", period = "30d")
return slice(Cl(rsut); time = 30).data
rsut =
get(simdir; short_name = "rsut", reduction = "average", period = "30d")
return slice(average_xy(rsut); time = 30).data
end

# Generate observations
Expand All @@ -56,7 +57,7 @@ noise = 0.1 * I
n_iterations = 3
ensemble_size = 10
prior = CAL.get_prior(joinpath(experiment_dir, "prior.toml"))
output_dir = joinpath("output", "single_column_held_suarez_rhoe_equilmoist")
output_dir = "calibration_end_to_end_test"
experiment_config = CAL.ExperimentConfig(;
n_iterations,
ensemble_size,
Expand Down Expand Up @@ -85,18 +86,19 @@ for i in 0:(n_iterations - 1)
end
G_ensemble = CAL.observation_map(i)
CAL.save_G_ensemble(experiment_config, i, G_ensemble)
eki = CAL.update_ensemble(experiment_config, i)
global eki = CAL.update_ensemble(experiment_config, i)
end

@testset "Pure Julia Calibration" begin
@testset "Pure Julia Calibration" begin
minimal_eki_test(eki)
end

backend = CAL.get_backend()
if backend == CAL.CaltechHPC
slurm_kwargs = CAL.kwargs(time = 5)
slurm_eki = CAL.calibrate(backend, experiment_config; slurm_kwargs)
@testset "Caltech HPC Calibration" begin
slurm_eki =
CAL.calibrate(backend, experiment_config; slurm_kwargs, model_interface)
@testset "Caltech HPC Calibration" begin
minimal_eki_test(slurm_eki)
end

Expand All @@ -120,11 +122,7 @@ end
# Add plots to help debug
function scatter_plot(eki::EKP.EnsembleKalmanProcess)
f = CairoMakie.Figure(resolution = (800, 600))
ax = CairoMakie.Axis(
f[1, 1],
ylabel = "Parameter Value",
xlabel = "G",
)
ax = CairoMakie.Axis(f[1, 1], ylabel = "Parameter Value", xlabel = "G")

g = vec.(EKP.get_g(eki; return_array = true))
phi = map(x -> abs.(x), vec.((EKP.get_ϕ(prior, eki))))
Expand All @@ -142,11 +140,7 @@ scatter_plot(eki)

function phi_versus_iter_plot(eki)
f = CairoMakie.Figure(resolution = (800, 600))
ax = CairoMakie.Axis(
f[1, 1],
ylabel = "Phi",
xlabel = "Iteration",
)
ax = CairoMakie.Axis(f[1, 1], ylabel = "Phi", xlabel = "Iteration")
phi = EKP.get_ϕ(prior, eki)
for (i, phi_) in enumerate(phi)
CairoMakie.scatter!(ax, fill(i, length(phi_)), vec(phi_))
Expand All @@ -160,4 +154,3 @@ function phi_versus_iter_plot(eki)
end

phi_versus_iter_plot(eki)

0 comments on commit 0039c9b

Please sign in to comment.