Add emulate + sample functionality, docs, and tests
1 parent 598507d, commit a87aef9
Showing 12 changed files with 218 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
[deps] | ||
CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3" | ||
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" | ||
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" | ||
EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Emulate and Sample | ||
Once you have run a successful calibration, you can fit an emulator to the resulting input/output pairs. | ||
|
||
First, import the necessary packages: | ||
```julia | ||
using CalibrateEmulateSample.Emulators | ||
using CalibrateEmulateSample.MarkovChainMonteCarlo | ||
|
||
import EnsembleKalmanProcesses as EKP | ||
using EnsembleKalmanProcesses.ParameterDistributions | ||
using EnsembleKalmanProcesses.TOMLInterface | ||
|
||
import JLD2 | ||
import CalibrateAtmos as CAL | ||
``` | ||
|
||
Next, load in the data, EKP object, and prior distribution. These values are taken | ||
from the perfect model experiment with experiment ID `sphere_held_suarez_rhoe_equilmoist`. | ||
```julia | ||
y_obs = [261.5493] | ||
y_noise_cov = [0.02619;;] | ||
ekp = JLD2.load_object( | ||
joinpath( | ||
pkgdir(CAL), | ||
"docs", | ||
"src", | ||
"assets", | ||
"eki_file_for_emulate_example.jld2", | ||
), | ||
) | ||
init_params = [EKP.get_u_final(ekp)[1]] | ||
|
||
prior_path = joinpath( | ||
pkgdir(CAL), | ||
"experiments", | ||
"sphere_held_suarez_rhoe_equilmoist", | ||
"prior.toml", | ||
) | ||
|
||
prior = CAL.get_prior(prior_path) | ||
``` | ||
Get the input-output pairs which will be used to train the emulator. | ||
The inputs are the parameter values, and the outputs are the result of the observation map. | ||
In this case, the outputs are the average air temperature at roughly 500 meters. | ||
```julia | ||
input_output_pairs = CAL.get_input_output_pairs(ekp) | ||
``` | ||
Next, create the Gaussian Process-based emulator and Markov chain. | ||
The samples from the chain can be used in future predictive model runs with the same configuration. | ||
The posterior distribution can be saved to a JLD2 file using `save_posterior`. Samples can be extracted from the posterior using ClimaParams. | ||
```julia | ||
emulator = CAL.gp_emulator(input_output_pairs, y_noise_cov) | ||
(; mcmc, chain) = CAL.sample(emulator, y_obs, prior, init_params) | ||
constrained_posterior = CAL.save_posterior(mcmc, chain; filename = "samples.jld2") | ||
``` | ||
|
||
Finally, you can plot the prior and posterior distributions to see results: | ||
```julia | ||
using Plots | ||
plot(prior) | ||
posterior = get_posterior(mcmc, chain) | ||
plot!(posterior) | ||
vline!([65.0]) | ||
``` |
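Once the constrained posterior is available, a few summary statistics are often easier to read than a plot. The following is a minimal sketch, assuming the object returned by `CAL.save_posterior` can be indexed by parameter name and yields an array of samples (as the test in this commit does). The samples here are synthetic stand-ins, generated only so the snippet runs without the calibration output; `stats` is a name introduced for this example.

```julia
using Random
using Statistics: mean, std, quantile

# Synthetic stand-in for the Dict returned by `CAL.save_posterior`.
# These values are NOT results from the experiment; they only make the
# snippet self-contained.
rng = MersenneTwister(1234)
name = "equator_pole_temperature_gradient_wet"
constrained_posterior = Dict(name => exp.(4.19 .+ 0.01 .* randn(rng, 1, 10_000)))

# Flatten the samples for this parameter and summarize them.
samples = vec(constrained_posterior[name])
stats = (
    mean = mean(samples),
    std = std(samples),
    interval_95 = quantile(samples, [0.025, 0.975]),
)
println(stats)
```

With real output, only the construction of `constrained_posterior` would change; the summary lines stay the same.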
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import CalibrateEmulateSample as CES | ||
using CalibrateEmulateSample.Emulators | ||
using CalibrateEmulateSample.MarkovChainMonteCarlo | ||
|
||
import EnsembleKalmanProcesses as EKP | ||
using EnsembleKalmanProcesses.ParameterDistributions | ||
using EnsembleKalmanProcesses.TOMLInterface | ||
|
||
import JLD2 | ||
|
||
""" | ||
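get_input_output_pairs(ekp; N_iterations = nothing) | ||
Helper function for getting the input/output pairs from an EKP object. | ||
`N_iterations` sets how many iterations to draw training points from; | ||
by default, every iteration with stored outputs (`length(ekp.g)`) is used. | ||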
""" | ||
function get_input_output_pairs(ekp; N_iterations = nothing) | ||
N_iterations = isnothing(N_iterations) ? length(ekp.g) : N_iterations | ||
input_output_pairs = CES.Utilities.get_training_points(ekp, N_iterations) | ||
return input_output_pairs | ||
end | ||
|
||
""" | ||
gp_emulator(input_output_pairs, obs_noise_cov) | ||
Constructs a Gaussian process emulator from the given input/output pairs and observational noise covariance. | ||
""" | ||
function gp_emulator(input_output_pairs, obs_noise_cov) | ||
gppackage = GPJL() | ||
gauss_proc = GaussianProcess(gppackage, noise_learn = false) | ||
emulator = Emulator(gauss_proc, input_output_pairs; obs_noise_cov) | ||
optimize_hyperparameters!(emulator) | ||
return emulator | ||
end | ||
|
||
""" | ||
sample( | ||
emulator, | ||
y_obs, | ||
prior, | ||
init_params; | ||
n_samples = 100_000, | ||
init_stepsize = 0.1, | ||
discard_initial = 0 | ||
) | ||
Constructs a MarkovChainMonteCarlo object, optimizes its stepsize, and takes | ||
`n_samples` number of samples. | ||
The initial stepsize can be specified by `init_stepsize`, | ||
and the number of initial samples to discard can be set by `discard_initial`. | ||
Returns both the MCMC object and the samples in a NamedTuple. | ||
""" | ||
function sample( | ||
emulator, | ||
y_obs, | ||
prior, | ||
init_params; | ||
n_samples = 100_000, | ||
init_stepsize = 0.1, | ||
discard_initial = 0, | ||
) | ||
mcmc = MCMCWrapper(RWMHSampling(), y_obs, prior, emulator; init_params) | ||
new_step = optimize_stepsize(mcmc; init_stepsize, N = 2000, discard_initial) | ||
chain = MarkovChainMonteCarlo.sample( | ||
mcmc, | ||
n_samples; | ||
stepsize = new_step, | ||
discard_initial, | ||
) | ||
return (; mcmc, chain) | ||
end | ||
|
||
""" | ||
save_posterior(mcmc, chain; filename = "samples.jld2") | ||
Given an MCMC object and a Markov chain of samples, | ||
constructs the posterior distribution and saves it to `filename`. | ||
Returns the samples in constrained (physical) parameter space. | ||
""" | ||
function save_posterior(mcmc, chain; filename = "samples.jld2") | ||
posterior = MarkovChainMonteCarlo.get_posterior(mcmc, chain) | ||
constrained_posterior = transform_unconstrained_to_constrained( | ||
posterior, | ||
MarkovChainMonteCarlo.get_distribution(posterior), | ||
) | ||
JLD2.save_object(filename, posterior) | ||
return constrained_posterior | ||
end |
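To make the emulate-then-sample flow above concrete without the CES dependencies, here is a minimal sketch in plain Julia. It is not the implementation used by `gp_emulator` or `sample`: a least-squares line stands in for the Gaussian process, and a hand-written random-walk Metropolis loop stands in for `RWMHSampling`. The forward model, observation, noise variance, and prior values are all hypothetical and chosen only for illustration.

```julia
using LinearAlgebra
using Random
using Statistics: mean

# Hypothetical forward model standing in for the expensive simulation.
forward(u) = 250.0 + 2.0 * u

# Training pairs, playing the role of the calibration's input/output pairs.
inputs = collect(range(3.0, 6.0; length = 20))
outputs = forward.(inputs)

# "Emulator": a least-squares line fitted to the pairs (stand-in for the GP).
A = hcat(ones(length(inputs)), inputs)
coeffs = A \ outputs
emulate(u) = coeffs[1] + coeffs[2] * u

# Unnormalized log posterior: Gaussian likelihood plus Gaussian prior.
y_obs, noise_var = 258.0, 0.25
prior_mean, prior_std = 4.5, 1.0
log_likelihood(u) = -0.5 * (y_obs - emulate(u))^2 / noise_var
log_prior(u) = -0.5 * ((u - prior_mean) / prior_std)^2
log_post(u) = log_likelihood(u) + log_prior(u)

# Random-walk Metropolis: propose a Gaussian step, accept with
# probability min(1, posterior ratio).
function rw_metropolis(log_post, u0, stepsize, n_samples, rng)
    chain = Vector{Float64}(undef, n_samples)
    u, lp = u0, log_post(u0)
    accepted = 0
    for i in 1:n_samples
        proposal = u + stepsize * randn(rng)
        lp_prop = log_post(proposal)
        if log(rand(rng)) < lp_prop - lp
            u, lp = proposal, lp_prop
            accepted += 1
        end
        chain[i] = u
    end
    return chain, accepted / n_samples
end

rng = MersenneTwister(1234)
chain, acceptance = rw_metropolis(log_post, 4.5, 0.5, 50_000, rng)
println("posterior mean ≈ ", mean(chain), ", acceptance rate ≈ ", acceptance)
```

Because the toy problem is linear and Gaussian, the posterior mean is known in closed form (68.5 / 17 ≈ 4.03), which gives a quick sanity check on the chain. The step size matters in the same way it does for `optimize_stepsize`: too small and the chain moves slowly, too large and most proposals are rejected.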
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,13 @@ | ||
[deps] | ||
CLIMAParameters = "6eacf6c3-8458-43b9-ae03-caf5306d3d53" | ||
CalibrateAtmos = "4347a170-ebd6-470c-89d3-5c705c0cacc2" | ||
CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3" | ||
ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884" | ||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d" | ||
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,8 @@ | ||
using Test | ||
|
||
import Random | ||
Random.seed!(1234) | ||
|
||
include("test_init.jl") | ||
include("test_atmos_config.jl") | ||
include("test_emulate_sample.jl") |
Binary file not shown.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
["equator_pole_temperature_gradient_wet"] | ||
prior = "Parameterized(Normal(4.779568,0.31223328))" | ||
constraint = "[bounded_below(0)]" | ||
type = "float" |
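The prior above is specified in unconstrained space, and the `bounded_below(0)` constraint maps it to positive physical values. The sketch below assumes the usual form of that transform pair (exponential and logarithm, shifted by the lower bound); `constrain` and `unconstrain` are names introduced here for illustration and are not part of the EnsembleKalmanProcesses API.

```julia
# Assumed form of the `bounded_below(lower)` transform pair.
constrain(x; lower = 0.0) = exp(x) + lower        # unconstrained -> physical
unconstrain(θ; lower = 0.0) = log(θ - lower)      # physical -> unconstrained

# Prior mean from the TOML entry, in unconstrained space.
prior_mean_unconstrained = 4.779568
println("prior mean mapped to physical space: ", constrain(prior_mean_unconstrained))

# Unconstrained chain mean checked by the test in this commit.
chain_mean_unconstrained = 4.19035299
println("chain mean mapped to physical space: ", constrain(chain_mean_unconstrained))

# The round trip recovers the original value.
println(unconstrain(constrain(2.0)))
```

Under this assumption, the unconstrained chain mean maps to roughly 66.05, which is consistent with the constrained posterior mean that the test checks. The two need not match exactly, since the mean of transformed samples differs from the transform of the mean.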
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import JLD2 | ||
import Statistics: mean | ||
|
||
using CalibrateEmulateSample.Emulators | ||
using CalibrateEmulateSample.MarkovChainMonteCarlo | ||
import EnsembleKalmanProcesses as EKP | ||
using EnsembleKalmanProcesses.ParameterDistributions | ||
using EnsembleKalmanProcesses.TOMLInterface | ||
|
||
import CalibrateAtmos as CAL | ||
|
||
y_obs = [261.5493] | ||
y_noise_cov = [0.02619;;] | ||
ekp = JLD2.load_object(joinpath("test_case_inputs", "eki_test.jld2")) | ||
init_params = [EKP.get_u_final(ekp)[1]] | ||
|
||
prior_path = joinpath("test_case_inputs", "sphere_hs_rhoe.toml") | ||
|
||
prior = CAL.get_prior(prior_path) | ||
|
||
input_output_pairs = CAL.get_input_output_pairs(ekp) | ||
|
||
@test input_output_pairs.inputs.stored_data == | ||
hcat([ekp.u[i].stored_data for i in 1:(length(ekp.u) - 1)]...) | ||
@test input_output_pairs.outputs.stored_data == | ||
hcat([ekp.g[i].stored_data for i in 1:length(ekp.g)]...) | ||
|
||
emulator = CAL.gp_emulator(input_output_pairs, y_noise_cov) | ||
|
||
|
||
(; mcmc, chain) = CAL.sample(emulator, y_obs, prior, init_params) | ||
@test mean(chain.value[1:100000]) ≈ 4.19035299 rtol = 0.0001 | ||
|
||
constrained_posterior = CAL.save_posterior(mcmc, chain) | ||
@test mean(constrained_posterior["equator_pole_temperature_gradient_wet"]) ≈ | ||
66.046965013381 rtol = 0.0001 |
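The tests above compare sample means against reference values with a relative tolerance. As a small illustration of how tight `rtol = 0.0001` is, the sketch below uses the reference value from the test and two made-up candidate values; `reference`, `rtol`, and `tolerance` are names introduced for this example.

```julia
using Test

reference = 4.19035299
rtol = 0.0001

# With the default atol = 0, `isapprox(x, y; rtol)` checks
# abs(x - y) <= rtol * max(abs(x), abs(y)).
tolerance = rtol * reference
println("allowed absolute deviation ≈ ", tolerance)

# A candidate within the tolerance passes; one outside it does not.
@test 4.1905 ≈ reference rtol = rtol
@test !isapprox(4.1910, reference; rtol = rtol)
```

The allowed deviation is about 4e-4 in unconstrained space, so the test depends on the fixed random seed set in the test runner to be reproducible.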