Commit 7844f9e

more fixes

nefrathenrici committed Dec 9, 2024
1 parent 3f784bf commit 7844f9e
Showing 4 changed files with 167 additions and 81 deletions.
10 changes: 8 additions & 2 deletions experiments/surface_fluxes_perfect_model/Manifest.toml
@@ -187,10 +187,10 @@ weakdeps = ["SparseArrays"]
ChainRulesCoreSparseArraysExt = "SparseArrays"

[[deps.ClimaCalibrate]]
deps = ["Distributed", "Distributions", "EnsembleKalmanProcesses", "JLD2", "Random", "TOML", "YAML"]
deps = ["ClusterManagers", "Distributed", "Distributions", "EnsembleKalmanProcesses", "JLD2", "Random", "TOML", "YAML"]
path = "../.."
uuid = "4347a170-ebd6-470c-89d3-5c705c0cacc2"
version = "0.0.4"
version = "0.0.5"

[deps.ClimaCalibrate.extensions]
CESExt = "CalibrateEmulateSample"
@@ -204,6 +204,12 @@ git-tree-sha1 = "b43ca371c435056129295445122ea87fd843b505"
uuid = "5c42b081-d73a-476f-9059-fd94b934656c"
version = "0.10.14"

[[deps.ClusterManagers]]
deps = ["Distributed", "Logging", "Pkg", "Sockets"]
git-tree-sha1 = "6a678b98d5ea4d2773e92c7ae607cf7371043684"
uuid = "34f1f09b-3a8b-5176-ab39-66d58a4d544e"
version = "0.4.6"

[[deps.CodecBzip2]]
deps = ["Bzip2_jll", "TranscodingStreams"]
git-tree-sha1 = "e7c529cc31bb85b97631b922fa2e6baf246f5905"
1 change: 0 additions & 1 deletion src/ekp_interface.jl
@@ -267,7 +267,6 @@ function _initialize(
rng_ekp = Random.MersenneTwister(rng_seed)
initial_ensemble =
EKP.construct_initial_ensemble(rng_ekp, prior, ensemble_size)
@show typeof(initial_ensemble)
ekp_str_kwargs = Dict([string(k) => v for (k, v) in ekp_kwargs])
eki_constructor =
(args...) -> EKP.EnsembleKalmanProcess(
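For context, an editorial sketch (not part of this commit) of how an initial ensemble and an EnsembleKalmanProcess object like the ones constructed above are typically built with EnsembleKalmanProcesses.jl; the prior, observations, and noise covariance below are toy placeholders, not values from this repository.

import EnsembleKalmanProcesses as EKP
import Random

# Toy placeholders standing in for a real prior, observations, and noise covariance
prior = EKP.ParameterDistributions.constrained_gaussian("coefficient", 0.5, 0.2, 0.0, 1.0)
observations = [1.0, 2.0]
noise_cov = [0.1 0.0; 0.0 0.1]

rng_ekp = Random.MersenneTwister(1234)
initial_ensemble = EKP.construct_initial_ensemble(rng_ekp, prior, 10)
eki = EKP.EnsembleKalmanProcess(initial_ensemble, observations, noise_cov, EKP.Inversion())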
29 changes: 8 additions & 21 deletions src/model_interface.jl
@@ -1,33 +1,19 @@
import EnsembleKalmanProcesses as EKP
import YAML

export set_up_forward_model, run_forward_model, observation_map
export forward_model, observation_map

"""
set_up_forward_model(member, iteration, experiment_dir::AbstractString)
set_up_forward_model(member, iteration, experiment_config::ExperimentConfig)
Set up and configure a single member's forward model. Used in conjunction with `run_forward_model`.
This function must be overridden by a component's model interface and
should set things like the parameter path and other member-specific settings.
"""
set_up_forward_model(member, iteration, experiment_dir::AbstractString) =
set_up_forward_model(member, iteration, ExperimentConfig(experiment_dir))

set_up_forward_model(member, iteration, experiment_config::ExperimentConfig) =
error("set_up_forward_model not implemented")

"""
run_forward_model(model_config)
forward_model(member, iteration)
Execute the forward model simulation with the given configuration.
This function should be overridden with model-specific implementation details.
`config` should be obtained from `set_up_forward_model`:
`run_forward_model(set_up_forward_model(member, iter, experiment_dir))`
This function must be overridden by a component's model interface and
should set things like the parameter path and other member-specific settings.
"""
run_forward_model(model_config) = error("run_forward_model not implemented")
function forward_model(member, iteration)
error("forward_model not implemented")
end

"""
observation_map(iteration)
@@ -38,3 +24,4 @@ This function must be implemented for each calibration experiment.
function observation_map(iteration)
error("observation_map not implemented")
end
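
For reference, an editorial sketch (not part of this commit) of how a component's model interface might override the two stubs above; the bodies are trivial stand-ins for a real simulation and observation map.

import ClimaCalibrate

# Member-specific setup (parameter paths, output directories) would go here,
# followed by the actual simulation; this stub just logs the call.
function ClimaCalibrate.forward_model(member, iteration)
    @info "Running member $member for iteration $iteration"
    return nothing
end

# Return the ensemble's simulated observations for this iteration, e.g. a matrix
# of size (observation dimension) x (ensemble size); zeros here are a stand-in.
function ClimaCalibrate.observation_map(iteration)
    return zeros(2, 10)
end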

208 changes: 151 additions & 57 deletions src/slurm_workers.jl
@@ -1,69 +1,40 @@
using Distributed, ClusterManagers
using Distributed
import EnsembleKalmanProcesses as EKP

#=
`srun -J julia-4014777 -n 10 -D /home/nefrathe/clima/ClimaCalibrate.jl
--cpus-per-task=1
-t 00:20:00 -o /home/nefrathe/clima/ClimaCalibrate.jl/./julia-4014777-17333586001-%4t.out
/clima/software/julia/julia-1.11.0/bin/julia
--project=/home/nefrathe/clima/ClimaCalibrate.jl/Project.toml
--worker=U3knisg2TJufcrbJ`
=#


# srun_proc = open(srun_cmd)

# worker_cookie() = begin Distributed.init_multi(); cluster_cookie() end
# worker_arg() = `--worker=$(worker_cookie())`

# srun_cmd = `srun -J $jobname -n $np -D $exehome $(srunargs) $exename $exeflags $(worker_arg())`

export worker_calibrate, add_slurm_workers
default_worker_pool() = WorkerPool(workers())

function run_iteration(iter, ensemble_size, output_dir; worker_pool = default_worker_pool(), failure_rate = 0.5)
# Create a channel to collect results
results = Channel{Any}(ensemble_size)
nfailures = 0
@sync begin
for m in 1:(ensemble_size)
@async begin
# Get a worker from the pool
worker = take!(worker_pool)
@info "Running particle $m on worker $worker"
try
model_config = set_up_forward_model(m, iter, config)
result = remotecall_fetch(
run_forward_model,
worker,
model_config,
)
put!(results, (m, result))
remotecall(forward_model, worker, m, iter)
catch e
@error "Error running member $m" exception = e
put!(results, (m, e))
nfailures += 1
finally
# Always return worker to pool
put!(worker_pool, worker)
end
end
end
end

# Collect all results
ensemble_results = Dict{Int, Any}()
for _ in 1:(ensemble_size)
m, result = take!(results)
if result isa Exception
@error "Member $m failed" error = result
else
ensemble_results[m] = result
end
end
results = values(ensemble_results)
iter_failure_rate = sum(isa.(results, Exception)) / ensemble_size
iter_failure_rate = nfailures / ensemble_size
if iter_failure_rate > failure_rate
error("Ensemble for iter $iter had a $(iter_failure_rate * 100)% failure rate")
end
end
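
For clarity, a self-contained editorial sketch (not part of this commit) of the take!/remotecall/put! pattern `run_iteration` uses to fan ensemble members out over a shared `WorkerPool`; `println` stands in for the forward model.

using Distributed

pool = WorkerPool(workers())
@sync for m in 1:4
    @async begin
        w = take!(pool)                  # reserve a worker for this member
        try
            remotecall_fetch(println, w, "running member $m")
        finally
            put!(pool, w)                # always return the worker to the pool
        end
    end
end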

function worker_calibrate(config; failure_rate = 0.5, worker_pool = default_worker_pool(), ekp_kwargs...)
(; ensemble_size, n_iterations, observations, noise, prior, output_dir) = config
return worker_calibrate(ensemble_size, n_iterations, observations, noise, prior, output_dir; failure_rate, worker_pool, ekp_kwargs...)
end

function worker_calibrate(ensemble_size, n_iterations, observations, noise, prior, output_dir; failure_rate = 0.5, worker_pool = default_worker_pool(), ekp_kwargs...)
initialize(
ensemble_size,
@@ -75,15 +46,15 @@ function worker_calibrate(ensemble_size, n_iterations, observations, noise, prio
ekp_kwargs...,
)
for iter in 0:(n_iterations)
(; time) = @timed run_iteration(iter, ensemble_size, output_dir; worker_pool, failure_rate)
(; time) = @timed run_iteration(iter, config; worker_pool, failure_rate)
@info "Iteration $iter time: $time"
# Process results
G_ensemble = observation_map(iter)
save_G_ensemble(output_dir, iter, G_ensemble)
update_ensemble(output_dir, iter, prior)
iter_path = path_to_iteration(output_dir, iter)
end
return JLD2.load_object(path_to_iteration(output_dir, n_iterations))
return JLD2.load_object(joinpath(path_to_iteration(output_dir, n_iterations), "eki_file.jld2"))
end

function worker_calibrate(ekp::EKP.EnsembleKalmanProcess, ensemble_size,n_iterations, observations, noise, prior, output_dir; failure_rate = 0.5, worker_pool = default_worker_pool(), ekp_kwargs...)
@@ -104,20 +75,143 @@ function worker_calibrate(ekp::EKP.EnsembleKalmanProcess, ensemble_size,n_iterat
return JLD2.load_object(path_to_iteration(output_dir, n_iterations))
end
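
An editorial usage sketch (not part of this commit): driving a calibration through `worker_calibrate` with an `ExperimentConfig`. The experiment path is taken from this repository's examples; the worker pool here is the default local pool, but a Slurm-backed pool from the helpers below could be passed instead.

import ClimaCalibrate
using Distributed

config = ClimaCalibrate.ExperimentConfig("experiments/surface_fluxes_perfect_model")
pool = WorkerPool(workers())   # or a Slurm-backed pool, see the launchers below
eki = ClimaCalibrate.worker_calibrate(config; worker_pool = pool, failure_rate = 0.5)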

# function slurm_worker_pool(nprocs::Int; exeflags = "--project=$(Base.active_project())", slurm_kwargs...)
# return WorkerPool(addprocs(
# SlurmManager(nprocs);
# t = "01:00:00",
# cpus_per_task = 1,
# # TODO: Fix output
# output = "worker_%j_%4t.out",
# exeflags = "--project=$(Base.active_project())",
# slurm_kwargs...,
# ))
# end

function worker_calibrate(config; worker_pool = default_worker_pool())
(; ensemble_size, observations, noise, prior, output_dir) = config
return worker_calibrate(ensemble_size, observations, noise, prior, output_dir; worker_pool)
end

function slurm_worker_pool(nprocs::Int; slurm_kwargs...)
return WorkerPool(addprocs(
SlurmManager(nprocs);
t = "01:00:00", cpus_per_task = 1,
exeflags = "--project=$(Base.active_project())",
slurm_kwargs...,
))
end
worker_cookie() = begin Distributed.init_multi(); cluster_cookie() end
worker_arg() = `--worker=$(worker_cookie())`

launched = WorkerConfig[]

function add_slurm_workers(ntasks; launched = launched, c = Condition(), exeflags = "--project=$(Base.active_project())", kwargs...)
Distributed.init_multi()

default_params = Distributed.default_addprocs_params()
kwargs = merge(default_params, Dict{Symbol, Any}(kwargs))
dir = kwargs[:dir]
exename = kwargs[:exename]

slurm_kwargs = filter(x->!(x[1] in keys(default_params)), kwargs)
srunargs = []
for k in keys(slurm_kwargs)
if length(string(k)) == 1
push!(srunargs, "-$k")
val = slurm_kwargs[k]
if length(val) > 0
push!(srunargs, "$val")
end
else
k2 = replace(string(k), "_"=>"-")
val = slurm_kwargs[k]
if length(val) > 0
push!(srunargs, "--$(k2)=$val")
else
push!(srunargs, "--$(k2)")
end
end
end

# Generate a default job name from the current process ID
jobname = "julia-$(getpid())"

# Set output name
default_template = "$jobname-$(trunc(Int, Base.time() * 10))"
default_output(x) = "$default_template-$x.txt"
has_output_name = ("-o" in srunargs) | ("--output" in srunargs)
job_output_file = if has_output_name
# if has_output_name, ensure there is only one output arg
loc = findfirst(x-> x == "-o" || x == "--output", srunargs)
job_output = srunargs[loc+1]
# Remove output argument to reappend
filter!(x -> x != "-o" && x != "--output", srunargs)
filter!(x -> !occursin(r"^-[oe]", x), srunargs)
job_output
else
".$(default_output("%4t"))"
end
push!(srunargs, "-o", job_output_file)

srun_cmd = `srun -J $jobname -n $ntasks $(srunargs) $exename $(worker_arg())`

@info "Starting SLURM job $jobname: $srun_cmd"
srun_proc = open(srun_cmd)

slurm_spec_regex = r"([\w]+):([\d]+)#(\d{1,3}.\d{1,3}.\d{1,3}.\d{1,3})"
could_not_connect_regex = r"could not connect"
exiting_regex = r"exiting."

# Wait for workers to start
t_start = time()
t_waited = round(Int, time() - t_start)
delays = ExponentialBackOff(10, 1.0, 512.0, 2.0, 0.1)
for i in 0:ntasks - 1
slurm_spec_match::Union{RegexMatch,Nothing} = nothing
worker_errors = String[]
if !has_output_name
job_output_file = ".$(default_output(lpad(i, 4, "0")))"
end
for retry_delay in push!(collect(delays), 0)
t_waited = round(Int, time() - t_start)

# gpus_per_task=1
worker_pool = default_worker_pool()
# Wait for output log to be created and populated, then parse

if isfile(job_output_file)
if filesize(job_output_file) > 0
open(job_output_file) do f
# Due to error and warning messages, the specification
# may not appear on the file's first line
for line in eachline(f)
re_match = match(slurm_spec_regex, line)
if !isnothing(re_match)
slurm_spec_match = re_match
end
for expr in [could_not_connect_regex, exiting_regex]
if !isnothing(match(expr, line))
slurm_spec_match = nothing
push!(worker_errors, line)
end
end
end
end
end
if !isempty(worker_errors) || !isnothing(slurm_spec_match)
break # break if error or specification found
else
@info "Worker $i (after $t_waited s): Output file found, but no connection details yet"
end
else
@info "Worker $i (after $t_waited s): No output file \"$job_output_file\" yet"
end

# Sleep for some time to limit resource usage while waiting for the job to start
sleep(retry_delay)
end

if !isempty(worker_errors)
throw(SlurmException("Worker $i failed after $t_waited s: $(join(worker_errors, " "))"))
elseif isnothing(slurm_spec_match)
throw(SlurmException("Timeout after $t_waited s while waiting for worker $i to get ready."))
end

config = WorkerConfig()
config.port = parse(Int, slurm_spec_match[2])
config.host = strip(slurm_spec_match[3])
@info "Worker $i ready after $t_waited s on host $(config.host), port $(config.port)"
# Keep a reference to the proc, so it's properly closed once
# the last worker exits.
config.userdata = srun_proc
push!(launched, config)
notify(c)
end
return WorkerPool(workers())
end
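
An editorial usage sketch (not part of this commit). The Slurm keyword arguments are assumptions about what a site might forward to `srun`, and the final lines only illustrate the `host:port#ip` format that the launcher parses out of each worker's output file.

# Launch four Slurm-backed workers; kwargs not in the addprocs defaults are forwarded to srun.
pool = add_slurm_workers(4; time = "00:30:00", cpus_per_task = 1)

# The launcher waits for worker startup lines of the form "host:port#ip":
spec = match(r"([\w]+):([\d]+)#(\d{1,3}.\d{1,3}.\d{1,3}.\d{1,3})", "node017:9009#10.42.0.17")
host, port = spec[1], parse(Int, spec[2])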
