Skip to content

Commit

Permalink
nicer experiment config
Browse files Browse the repository at this point in the history
  • Loading branch information
nefrathenrici committed Apr 2, 2024
1 parent eb3e144 commit 155be7c
Showing 1 changed file with 34 additions and 40 deletions.
74 changes: 34 additions & 40 deletions src/ekp_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,20 @@ import ClimaComms
export ExperimentConfig

struct ExperimentConfig
id
n_iterations
ensemble_size
observations
noise
prior
output_dir
id::AbstractString
n_iterations::Integer
ensemble_size::Integer
observations::Any
noise::Any
prior::Any
output_dir::Any
end

"""
ExperimentConfig(experiment_id;
ExperimentConfig(experiment_id)
ExperimentConfig(
experiment_id,
n_iterations,
ensemble_size,
observations,
Expand All @@ -29,38 +32,38 @@ end
output_dir,
)
ExperimentConfig constructor. If an individual keyword argument is not given,
default is obtained from the YAML at `get_ekp_yaml(experiment_id)`.
ExperimentConfig stores the configuration for a specific experiment.
If only the `experiment_id` is passed, the config will be loaded from
`experiments/experiment_id/ekp_config.yml`
"""
function ExperimentConfig(experiment_id; kwargs...)
# config_dict is a Dict read from a YAML file
config_dict = get_ekp_yaml(experiment_id)
function ExperimentConfig(experiment_id)
config_yaml = joinpath("experiments", experiment_id, "ekp_config.yml")
config_dict = isfile(config_yaml) ? Dict() : YAML.load_file(config_yaml)

required_fields = ["n_iterations",
"ensemble_size",
"prior_path",
"observations",
"noise"]
@assert issubset(required_fields, keys(config_dict))

default_output =
haskey(ENV, "CI") ? experiment_id : joinpath("output", experiment_id)
config_dict["output_dir"] = get(config_dict, "output_dir", default_output)
output_dir = get(config_dict, "output_dir", default_output)

for key in ["observations", "noise"]
if haskey(config_dict, key)
config_dict[key] = JLD2.load_object(config_dict[key])
end
end

if haskey(config_dict, "prior_path")
config_dict["prior"] = get_prior(config_dict["prior_path"])
end
observations = JLD2.load_object(config_dict["observations"])
noise = JLD2.load_object(config_dict["noise"])

config_kwargs = Dict(Symbol(key) => config_dict[key] for key in keys(config_dict))
merged_kwargs = merge(config_kwargs, kwargs)
prior = get_prior(config_dict["prior_path"])

return ExperimentConfig(
experiment_id,
merged_kwargs[:n_iterations],
merged_kwargs[:ensemble_size],
merged_kwargs[:observations],
merged_kwargs[:noise],
merged_kwargs[:prior],
merged_kwargs[:output_dir],
config_dict["n_iterations"],
config_dict["ensemble_size"],
observations,
noise,
prior,
output_dir,
)
end

Expand Down Expand Up @@ -90,15 +93,6 @@ function get_prior(param_dict::AbstractDict; names = nothing)
return prior
end

"""
get_ekp_yaml(experiment_id)
Load the EKP configuration for a given `experiment_id`. If no file is found, return an empty Dict()
"""
function get_ekp_yaml(experiment_id)
config_yaml = joinpath("experiments", experiment_id, "ekp_config.yml")
return isfile(config_yaml) ? Dict() : YAML.load_file(config_yaml)
end
"""
save_G_ensemble(experiment_id, iteration, G_ensemble)
Expand Down

0 comments on commit 155be7c

Please sign in to comment.