-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add initial worker interface, needs a lot of cleanup
- Loading branch information
1 parent
81dbec4
commit 3f784bf
Showing
5 changed files
with
180 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
using Distributed, ClusterManagers | ||
import EnsembleKalmanProcesses as EKP | ||
|
||
#= | ||
`srun -J julia-4014777 -n 10 -D /home/nefrathe/clima/ClimaCalibrate.jl | ||
--cpus-per-task=1 | ||
-t 00:20:00 -o /home/nefrathe/clima/ClimaCalibrate.jl/./julia-4014777-17333586001-%4t.out | ||
/clima/software/julia/julia-1.11.0/bin/julia | ||
--project=/home/nefrathe/clima/ClimaCalibrate.jl/Project.toml | ||
--worker=U3knisg2TJufcrbJ` | ||
=# | ||
|
||
|
||
# srun_proc = open(srun_cmd) | ||
|
||
# worker_cookie() = begin Distributed.init_multi(); cluster_cookie() end | ||
# worker_arg() = `--worker=$(worker_cookie())` | ||
|
||
# srun_cmd = `srun -J $jobname -n $np -D $exehome $(srunargs) $exename $exeflags $(worker_arg())` | ||
|
||
default_worker_pool() = WorkerPool(workers()) | ||
|
||
function run_iteration(iter, ensemble_size, output_dir; worker_pool = default_worker_pool(), failure_rate = 0.5) | ||
# Create a channel to collect results | ||
results = Channel{Any}(ensemble_size) | ||
@sync begin | ||
for m in 1:(ensemble_size) | ||
@async begin | ||
# Get a worker from the pool | ||
worker = take!(worker_pool) | ||
try | ||
model_config = set_up_forward_model(m, iter, config) | ||
result = remotecall_fetch( | ||
run_forward_model, | ||
worker, | ||
model_config, | ||
) | ||
put!(results, (m, result)) | ||
catch e | ||
@error "Error running member $m" exception = e | ||
put!(results, (m, e)) | ||
finally | ||
# Always return worker to pool | ||
put!(worker_pool, worker) | ||
end | ||
end | ||
end | ||
end | ||
|
||
# Collect all results | ||
ensemble_results = Dict{Int, Any}() | ||
for _ in 1:(ensemble_size) | ||
m, result = take!(results) | ||
if result isa Exception | ||
@error "Member $m failed" error = result | ||
else | ||
ensemble_results[m] = result | ||
end | ||
end | ||
results = values(ensemble_results) | ||
iter_failure_rate = sum(isa.(results, Exception)) / ensemble_size | ||
if iter_failure_rate > failure_rate | ||
error("Ensemble for iter $iter had a $(iter_failure_rate * 100)% failure rate") | ||
end | ||
end | ||
|
||
function worker_calibrate(ensemble_size, n_iterations, observations, noise, prior, output_dir; failure_rate = 0.5, worker_pool = default_worker_pool(), ekp_kwargs...) | ||
initialize( | ||
ensemble_size, | ||
observations, | ||
noise, | ||
prior, | ||
output_dir; | ||
rng_seed = 1234, | ||
ekp_kwargs..., | ||
) | ||
for iter in 0:(n_iterations) | ||
(; time) = @timed run_iteration(iter, ensemble_size, output_dir; worker_pool, failure_rate) | ||
@info "Iteration $iter time: $time" | ||
# Process results | ||
G_ensemble = observation_map(iter) | ||
save_G_ensemble(output_dir, iter, G_ensemble) | ||
update_ensemble(output_dir, iter, prior) | ||
iter_path = path_to_iteration(output_dir, iter) | ||
end | ||
return JLD2.load_object(path_to_iteration(output_dir, n_iterations)) | ||
end | ||
|
||
function worker_calibrate(ekp::EKP.EnsembleKalmanProcess, ensemble_size,n_iterations, observations, noise, prior, output_dir; failure_rate = 0.5, worker_pool = default_worker_pool(), ekp_kwargs...) | ||
initialize( | ||
ekp, prior, output_dir | ||
; | ||
rng_seed = 1234, | ||
) | ||
for iter in 0:(n_iterations) | ||
(; time) = @timed run_iteration(iter, ensemble_size, output_dir; worker_pool, failure_rate) | ||
@info "Iteration $iter time: $time" | ||
# Process results | ||
G_ensemble = observation_map(iter) | ||
save_G_ensemble(output_dir, iter, G_ensemble) | ||
update_ensemble(output_dir, iter, prior) | ||
iter_path = path_to_iteration(output_dir, iter) | ||
end | ||
return JLD2.load_object(path_to_iteration(output_dir, n_iterations)) | ||
end | ||
|
||
|
||
function worker_calibrate(config; worker_pool = default_worker_pool()) | ||
(; ensemble_size, observations, noise, prior, output_dir) = config | ||
return worker_calibrate(ensemble_size, observations, noise, prior, output_dir; worker_pool) | ||
end | ||
|
||
function slurm_worker_pool(nprocs::Int; slurm_kwargs...) | ||
return WorkerPool(addprocs( | ||
SlurmManager(nprocs); | ||
t = "01:00:00", cpus_per_task = 1, | ||
exeflags = "--project=$(Base.active_project())", | ||
slurm_kwargs..., | ||
)) | ||
end | ||
|
||
# gpus_per_task=1 | ||
worker_pool = default_worker_pool() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Tests for SurfaceFluxes example calibration on HPC, used in buildkite testing | ||
# To run, open the REPL: julia --project=experiments/surface_fluxes_perfect_model test/hpc_backend_e2e.jl | ||
|
||
using Pkg | ||
Pkg.instantiate(; verbose = true) | ||
|
||
import ClimaCalibrate: | ||
get_backend, | ||
HPCBackend, | ||
JuliaBackend, | ||
calibrate, | ||
get_prior, | ||
kwargs, | ||
ExperimentConfig, | ||
DerechoBackend | ||
using Test | ||
import EnsembleKalmanProcesses: get_ϕ_mean_final, get_g_mean_final | ||
|
||
experiment_dir = dirname(Base.active_project()) | ||
model_interface = joinpath(experiment_dir, "model_interface.jl") | ||
|
||
# Generate observational data and include observational map | ||
include(joinpath(experiment_dir, "generate_data.jl")) | ||
include(joinpath(experiment_dir, "observation_map.jl")) | ||
include(model_interface) | ||
|
||
prior = get_prior(joinpath(experiment_dir, "prior.toml")) | ||
|
||
function test_sf_calibration_output(eki, prior) | ||
@testset "End to end test using file config (surface fluxes perfect model)" begin | ||
parameter_values = get_ϕ_mean_final(prior, eki) | ||
test_parameter_values = [4.778584250117946, 3.7295665619234697] | ||
@test all( | ||
isapprox.(parameter_values, test_parameter_values; rtol = 1e-3), | ||
) | ||
|
||
forward_model_output = get_g_mean_final(eki) | ||
test_model_output = [0.05228473730385304] | ||
@test all( | ||
isapprox.(forward_model_output, test_model_output; rtol = 1e-3), | ||
) | ||
end | ||
end | ||
|
||
@everywhere | ||
eki = worker_calibrate(config; model_interface, hpc_kwargs, verbose = true) | ||
test_sf_calibration_output(eki, prior) | ||
|
||
# Pure Julia calibration, this should run anywhere | ||
eki = calibrate(JuliaBackend, experiment_dir) | ||
test_sf_calibration_output(eki, prior) | ||
|
||
include(joinpath(experiment_dir, "postprocessing.jl")) |