Add Emulate Sample functions (#51)
Add emulate + sample functionality, docs, and tests
nefrathenrici authored Feb 23, 2024
1 parent 598507d commit a87aef9
Showing 12 changed files with 218 additions and 0 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -4,6 +4,7 @@ authors = ["Climate Modeling Alliance"]
version = "0.1.0"

[deps]
CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3"
ClimaAtmos = "b2c96348-7fb7-4fe0-8da9-78d88439e717"
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884"
2 changes: 2 additions & 0 deletions docs/Project.toml
@@ -1,3 +1,5 @@
[deps]
CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d"
1 change: 1 addition & 0 deletions docs/make.jl
@@ -25,6 +25,7 @@ makedocs(
"Home" => "index.md",
"Getting Started" => "quickstart.md",
"Experiment Setup Guide" => "experiment_setup_guide.md",
"Emulate and Sample" => "emulate_sample.md",
"API" => "api.md",
],
)
64 changes: 64 additions & 0 deletions docs/src/emulate_sample.md
@@ -0,0 +1,64 @@
# Emulate and Sample
Once you have run a successful calibration, you can fit an emulator to the resulting input/output pairs.

First, import the necessary packages:
```julia
using CalibrateEmulateSample.Emulators
using CalibrateEmulateSample.MarkovChainMonteCarlo

import EnsembleKalmanProcesses as EKP
using EnsembleKalmanProcesses.ParameterDistributions
using EnsembleKalmanProcesses.TOMLInterface

import JLD2
import CalibrateAtmos as CAL
```

Next, load in the data, EKP object, and prior distribution. These values are taken
from the perfect model experiment with experiment ID `sphere_held_suarez_rhoe_equilmoist`.
```julia
y_obs = [261.5493]
y_noise_cov = [0.02619;;]
ekp = JLD2.load_object(
    joinpath(
        pkgdir(CAL),
        "docs",
        "src",
        "assets",
        "eki_file_for_emulate_example.jld2",
    ),
)
init_params = [EKP.get_u_final(ekp)[1]]

prior_path = joinpath(
    pkgdir(CAL),
    "experiments",
    "sphere_held_suarez_rhoe_equilmoist",
    "prior.toml",
)

prior = CAL.get_prior(prior_path)
```
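To see what kind of file `CAL.get_prior` reads, the sketch below parses a single prior entry with Julia's standard `TOML` library and selects parameter names the same way the `names` keyword does. The entry is copied from the test input `sphere_hs_rhoe.toml`; the variable names are illustrative, and the sketch stops before building a `ParameterDistribution`.

```julia
import TOML

# One prior entry, copied from test/test_case_inputs/sphere_hs_rhoe.toml.
toml_text = """
["equator_pole_temperature_gradient_wet"]
prior = "Parameterized(Normal(4.779568,0.31223328))"
constraint = "[bounded_below(0)]"
type = "float"
"""

param_dict = TOML.parse(toml_text)

# Mirrors the `names` keyword: with `nothing`, every key in the file is used.
names = nothing
selected = isnothing(names) ? collect(keys(param_dict)) : names

for name in selected
    entry = param_dict[name]
    println(name, ": prior = ", entry["prior"], ", constraint = ", entry["constraint"])
end
```

Passing an explicit vector of names instead of `nothing` would restrict the prior to that subset of parameters.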
Get the input-output pairs which will be used to train the emulator.
The inputs are the parameter values, and the outputs are the result of the observation map.
In this case, the outputs are the average air temperature at roughly 500 meters.
```julia
input_output_pairs = CAL.get_input_output_pairs(ekp)
```
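The package test for this function checks that the inputs are the ensembles from every iteration except the final one, concatenated column-wise, and that the outputs are the concatenated forward-model results. A toy sketch of that layout, using plain arrays with hypothetical sizes and random values rather than EKP data containers:

```julia
import Random

rng = Random.MersenneTwister(0)

# Hypothetical problem sizes, chosen only for illustration.
n_params, n_outputs, n_ensemble, n_iterations = 1, 1, 4, 3

# `u` has one more entry than `g`: the final updated ensemble
# has not been run through the forward model.
u = [randn(rng, n_params, n_ensemble) for _ in 1:(n_iterations + 1)]
g = [randn(rng, n_outputs, n_ensemble) for _ in 1:n_iterations]

# Each column of `inputs` pairs with the same column of `outputs`.
inputs = hcat(u[1:(end - 1)]...)
outputs = hcat(g...)

@show size(inputs) size(outputs)
```

Every column is one training point for the emulator, so the number of training points is the ensemble size times the number of completed iterations.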
Next, create the Gaussian Process-based emulator and Markov chain.
The samples from the chain can be used in future predictive model runs with the same configuration.
The posterior distribution can be saved to a JLD2 file using `save_posterior`. Samples can be extracted from the posterior using ClimaParams.
```julia
emulator = CAL.gp_emulator(input_output_pairs, y_noise_cov)
(; mcmc, chain) = CAL.sample(emulator, y_obs, prior, init_params)
constrained_posterior = CAL.save_posterior(mcmc, chain; filename = "samples.jld2")
```
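`CAL.sample` builds its chain with `RWMHSampling()`, a random-walk Metropolis-Hastings sampler. The sketch below illustrates the idea in plain Julia; it is not the CalibrateEmulateSample implementation. The target density, the function name `rwmh`, and all default values are assumptions made for the example, with the target standing in for the emulator-based posterior.

```julia
import Random
import Statistics: mean

# Log-density (up to a constant) of a hypothetical 1D target: Normal(2, 0.5).
log_target(x) = -0.5 * ((x - 2.0) / 0.5)^2

function rwmh(log_density, x0; n_samples = 50_000, stepsize = 0.5,
              rng = Random.default_rng())
    chain = Vector{Float64}(undef, n_samples)
    x = x0
    logp = log_density(x)
    n_accept = 0
    for i in 1:n_samples
        # Propose a Gaussian random-walk step around the current state.
        proposal = x + stepsize * randn(rng)
        logp_new = log_density(proposal)
        # Accept with probability min(1, p(proposal) / p(x)).
        if log(rand(rng)) < logp_new - logp
            x, logp = proposal, logp_new
            n_accept += 1
        end
        chain[i] = x
    end
    return (; chain, acceptance_rate = n_accept / n_samples)
end

result = rwmh(log_target, 0.0; rng = Random.MersenneTwister(1234))
println("chain mean: ", mean(result.chain))
println("acceptance rate: ", result.acceptance_rate)
```

The step size controls the acceptance rate: steps that are too large are mostly rejected, and steps that are too small explore slowly, which is why `CAL.sample` calls `optimize_stepsize` before drawing the final chain.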

Finally, you can plot the prior and posterior distributions to see results:
```julia
using Plots
plot(prior)
posterior = get_posterior(mcmc, chain)
plot!(posterior)
vline!([65.0])
```
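The chain samples in unconstrained space, while `save_posterior` returns values in constrained (physical) space. A minimal sketch of that mapping, assuming the `bounded_below(0)` constraint from the prior file corresponds to a log/exp transform pair; the function names and sample values here are hypothetical, not the EnsembleKalmanProcesses API:

```julia
# Assumed transform pair for a parameter bounded below by `lower`.
to_constrained(u; lower = 0.0) = exp(u) + lower
to_unconstrained(c; lower = 0.0) = log(c - lower)

# Hypothetical unconstrained chain values.
unconstrained_samples = [4.0, 4.19, 4.4]
constrained_samples = to_constrained.(unconstrained_samples)

# Every constrained value respects the lower bound, and the round trip is exact
# up to floating-point error.
println(constrained_samples)
println(to_unconstrained.(constrained_samples))
```

Under this assumption, any real-valued sample maps to a strictly positive physical parameter, so the sampler never has to reject a proposal for violating the bound.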
1 change: 1 addition & 0 deletions src/CalibrateAtmos.jl
@@ -2,5 +2,6 @@ module CalibrateAtmos

include("ekp_interface.jl")
include("atmos_interface.jl")
include("emulate_sample.jl")

end # module CalibrateAtmos
14 changes: 14 additions & 0 deletions src/ekp_interface.jl
@@ -13,6 +13,20 @@ Returns the path to the iteration folder within `output_dir` for the given itera
path_to_iteration(output_dir, iteration) =
    joinpath(output_dir, join(["iteration", lpad(iteration, 3, "0")], "_"))

"""
    get_prior(prior_path; names = nothing)
Constructs the combined prior distribution from the TOML file at the `prior_path`.
If no parameter names are passed in, all parameters in the TOML are used in the distribution.
"""
function get_prior(prior_path; names = nothing)
    param_dict = TOML.parsefile(prior_path)
    names = isnothing(names) ? keys(param_dict) : names
    prior_vec = [get_parameter_distribution(param_dict, n) for n in names]
    prior = combine_distributions(prior_vec)
    return prior
end

"""
initialize(
experiment_id;
87 changes: 87 additions & 0 deletions src/emulate_sample.jl
@@ -0,0 +1,87 @@
import CalibrateEmulateSample as CES
using CalibrateEmulateSample.Emulators
using CalibrateEmulateSample.MarkovChainMonteCarlo

import EnsembleKalmanProcesses as EKP
using EnsembleKalmanProcesses.ParameterDistributions
using EnsembleKalmanProcesses.TOMLInterface

import JLD2

"""
    get_input_output_pairs(ekp; N_iterations = nothing)
Helper function for getting the input/output pairs from an EKP object.
If `N_iterations` is not given, all completed iterations (`length(ekp.g)`) are used.
"""
function get_input_output_pairs(ekp; N_iterations = nothing)
    N_iterations = isnothing(N_iterations) ? length(ekp.g) : N_iterations
    input_output_pairs = CES.Utilities.get_training_points(ekp, N_iterations)
    return input_output_pairs
end

"""
    gp_emulator(input_output_pairs, obs_noise_cov)
Constructs a Gaussian process emulator from the given input/output pairs and noise.
"""
function gp_emulator(input_output_pairs, obs_noise_cov)
    gppackage = GPJL()
    gauss_proc = GaussianProcess(gppackage, noise_learn = false)
    emulator = Emulator(gauss_proc, input_output_pairs; obs_noise_cov)
    optimize_hyperparameters!(emulator)
    return emulator
end

"""
    sample(
        emulator,
        y_obs,
        prior,
        init_params;
        n_samples = 100_000,
        init_stepsize = 0.1,
        discard_initial = 0
    )
Constructs a MarkovChainMonteCarlo object, optimizes its stepsize, and takes
`n_samples` samples.
The initial stepsize can be specified by `init_stepsize`.
`discard_initial` sets the number of initial samples to discard during
stepsize optimization; the final chain always keeps all of its samples.
Returns both the MCMC object and the samples in a NamedTuple.
"""
function sample(
    emulator,
    y_obs,
    prior,
    init_params;
    n_samples = 100_000,
    init_stepsize = 0.1,
    discard_initial = 0,
)
    mcmc = MCMCWrapper(RWMHSampling(), y_obs, prior, emulator; init_params)
    # `discard_initial` is forwarded to the stepsize optimization only.
    new_step = optimize_stepsize(mcmc; init_stepsize, N = 2000, discard_initial)
    chain = MarkovChainMonteCarlo.sample(
        mcmc,
        n_samples;
        stepsize = new_step,
        discard_initial = 0,
    )
    return (; mcmc, chain)
end

"""
    save_posterior(mcmc, chain; filename = "samples.jld2")
Given an MCMC object and a Markov chain of samples,
constructs the posterior distribution and saves it to `filename`.
Returns the samples in constrained (physical) parameter space.
"""
function save_posterior(mcmc, chain; filename = "samples.jld2")
    posterior = MarkovChainMonteCarlo.get_posterior(mcmc, chain)
    constrained_posterior = transform_unconstrained_to_constrained(
        posterior,
        MarkovChainMonteCarlo.get_distribution(posterior),
    )
    JLD2.save_object(filename, posterior)
    return constrained_posterior
end
4 changes: 4 additions & 0 deletions test/Project.toml
@@ -1,9 +1,13 @@
[deps]
CLIMAParameters = "6eacf6c3-8458-43b9-ae03-caf5306d3d53"
CalibrateAtmos = "4347a170-ebd6-470c-89d3-5c705c0cacc2"
CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3"
ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -1,4 +1,8 @@
using Test

import Random
Random.seed!(1234)

include("test_init.jl")
include("test_atmos_config.jl")
include("test_emulate_sample.jl")
Binary file added test/test_case_inputs/eki_test.jld2
Binary file not shown.
4 changes: 4 additions & 0 deletions test/test_case_inputs/sphere_hs_rhoe.toml
@@ -0,0 +1,4 @@
["equator_pole_temperature_gradient_wet"]
prior = "Parameterized(Normal(4.779568,0.31223328))"
constraint = "[bounded_below(0)]"
type = "float"
36 changes: 36 additions & 0 deletions test/test_emulate_sample.jl
@@ -0,0 +1,36 @@
import JLD2
import Statistics: mean

using CalibrateEmulateSample.Emulators
using CalibrateEmulateSample.MarkovChainMonteCarlo
import EnsembleKalmanProcesses as EKP
using EnsembleKalmanProcesses.ParameterDistributions
using EnsembleKalmanProcesses.TOMLInterface

import CalibrateAtmos as CAL

y_obs = [261.5493]
y_noise_cov = [0.02619;;]
ekp = JLD2.load_object(joinpath("test_case_inputs", "eki_test.jld2"))
init_params = [EKP.get_u_final(ekp)[1]]

prior_path = joinpath("test_case_inputs", "sphere_hs_rhoe.toml")

prior = CAL.get_prior(prior_path)

input_output_pairs = CAL.get_input_output_pairs(ekp)

@test input_output_pairs.inputs.stored_data ==
      hcat([ekp.u[i].stored_data for i in 1:(length(ekp.u) - 1)]...)
@test input_output_pairs.outputs.stored_data ==
      hcat([ekp.g[i].stored_data for i in 1:length(ekp.g)]...)

emulator = CAL.gp_emulator(input_output_pairs, y_noise_cov)


(; mcmc, chain) = CAL.sample(emulator, y_obs, prior, init_params)
@test mean(chain.value[1:100000]) ≈ 4.19035299 rtol = 0.0001

constrained_posterior = CAL.save_posterior(mcmc, chain)
@test mean(constrained_posterior["equator_pole_temperature_gradient_wet"]) ≈
      66.046965013381 rtol = 0.0001
