Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexisRenchon committed Sep 3, 2024
1 parent 2222299 commit 2f41961
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 0 deletions.
15 changes: 15 additions & 0 deletions experiments/calibration/experiment_config.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# prior= prior_prognostic.toml
ensemble_size= 50
n_iterations= 5
batch_size= 1
# model_config = model_config_prognostic.yml
# output_dir = output/gcm_driven_scm
y_var_names= [g1] # calibration variables
# z_max = 4000 # [m]
# dims = 90 # num vertical levels below z_max x num variables x batch size
eki_timestep= 0.001 # timestep of eki
const_noise= 0.05 # constant noise (diagonal elements of noise cov matrix Γ)
t_start_sec = 0 # start time
t_end_sec = 60 * 60 * 24 * 7 # end time, 7 days
# g_t_start_sec = 216000.0 # start time of SCM averaging window [s] = 2.5 days
# g_t_end_sec = 259200.0 # end time of SCM averaging window [s] = 3 days (SCM length = 3 days)
2 changes: 2 additions & 0 deletions experiments/calibration/model_interface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# This is where we need run_fluxnet ?
# as a function that takes some config to run
3 changes: 3 additions & 0 deletions experiments/calibration/prior_prognostic.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[g1]
prior = "constrained_gaussian(g1, 0.1, 0.03, 0, Inf)"
type = "float"
114 changes: 114 additions & 0 deletions experiments/calibration/run_calibration.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# note: use the .buildkite environment
import ClimaCalibrate as CAL
import ClimaLand as CL
import EnsembleKalmanProcesses as EKP
import JLD2 # JLD2 saves and loads Julia data structures in a format comprising a subset of HDF5
using LinearAlgebra

# include("helper_funcs.jl") # seems ClimaAtmos specific, don't think we need it
include("observation_map.jl") # this we probably need, but we is it not internal to ClimaCalibrate?
# include("get_les_metadata.jl") # seems ClimaAtmos specific as well, don't need

experiment_dir = dirname(Base.active_project())

include("experiment_config.jl") # why use YAML? Let's just include a Julia file

# Get prior
# ??? not sure what prior diagnostic and prior prognostic are ???
# prior = prior_prognostic.toml

# parameter we want to calibrate prior?
norm_factors_dict = Dict(
"g1" => [306.172, 8.07383], # need to update number
)

if !isdir(output_dir)
mkpath(output_dir)
end

JLD2.jldsave(
joinpath(output_dir, "norm_factors.jld2");
norm_factors_dict = norm_factors_dict,
)

ref_paths, _ = get_les_calibration_library()
obs_vec = []

# ?? Distribution of priors ??
for ref_path in ref_paths

y_obs, Σ_obs, norm_vec_obs = get_obs(
ref_path,
experiment_config["y_var_names"],
zc_model;
ti = experiment_config["y_t_start_sec"],
tf = experiment_config["y_t_end_sec"],
norm_factors_dict = norm_factors_dict,
Σ_const = const_noise,
)

push!(
obs_vec,
EKP.Observation(
Dict(
"samples" => y_obs,
"covariances" => Σ_obs,
"names" => split(ref_path, "/")[end],
),
),
)
end

series_names = [ref_paths[i] for i in 1:length(ref_paths)]

# minibatcher = sampling of observations
rfs_minibatcher =
EKP.RandomFixedSizeMinibatcher(experiment_config["batch_size"])
observations = EKP.ObservationSeries(obs_vec, rfs_minibatcher, series_names)

# What is initialize exactly?
CAL.initialize(
ensemble_size, # ? not sure what this is
observations, # The "truth data" we calibrate on
prior, # Prior distribution of the parameters we want to calibrate
output_dir;
scheduler = EKP.DefaultScheduler(eki_timestep), # ?
)

eki = nothing
hpc_kwargs = CAL.kwargs(time = 60, mem = "16G")
module_load_str = CAL.module_load_string(CAL.CaltechHPCBackend)
for iter in 0:(n_iterations - 1)
@info "Iteration $iter"
jobids = map(1:ensemble_size) do member
@info "Running ensemble member $member"
CAL.slurm_model_run(
iter,
member,
output_dir,
experiment_dir,
model_interface,
module_load_str;
hpc_kwargs,
)
end

statuses = CAL.wait_for_jobs(
jobids,
output_dir,
iter,
experiment_dir,
model_interface,
module_load_str;
hpc_kwargs,
verbose = false,
reruns = 0,
)
CAL.report_iteration_status(statuses, output_dir, iter)
@info "Completed iteration $iter, updating ensemble"
G_ensemble = CAL.observation_map(iter)
CAL.save_G_ensemble(output_dir, iter, G_ensemble)
eki = CAL.update_ensemble(output_dir, iter, prior)
end


0 comments on commit 2f41961

Please sign in to comment.