Skip to content

Commit

Permalink
Improve slurm controller
Browse files Browse the repository at this point in the history
  • Loading branch information
nefrathenrici committed Jul 11, 2024
1 parent fbc4a04 commit c01b0f5
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 30 deletions.
74 changes: 44 additions & 30 deletions src/slurm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,18 @@ export kwargs, sbatch_model_run, wait_for_jobs

"""
    kwargs(; kwargs...)

Collect the supplied keyword arguments into a `Dict{Symbol, Any}`, keyed by
keyword name.
"""
function kwargs(; kwargs...)
    # `kwargs` iterates as `name => value` pairs, so it can seed the Dict directly.
    return Dict{Symbol, Any}(kwargs)
end

"""
    generate_sbatch_directives(slurm_kwargs)

Render `slurm_kwargs` as a newline-separated block of `#SBATCH --key=value`
lines and return it as a `String`.

Underscores in both the key and the stringified value are replaced by hyphens
(e.g. `:cpus_per_task => 1` becomes `#SBATCH --cpus-per-task=1`). The value
stored under `:time` is passed through `format_slurm_time` before rendering.
Directives appear in the iteration order of `slurm_kwargs`.

`slurm_kwargs` must contain the key `:time`; otherwise an `AssertionError` is
raised. The input container is not modified.
"""
function generate_sbatch_directives(slurm_kwargs)
    @assert haskey(slurm_kwargs, :time) "Slurm kwargs must include key :time"
    directives = map(collect(slurm_kwargs)) do (k, v)
        # Format the time while rendering rather than writing it back into
        # `slurm_kwargs`: the same container may be reused for several
        # submissions (e.g. reruns), and overwriting the entry would feed an
        # already-formatted value into `format_slurm_time` the next time.
        value = k == :time ? format_slurm_time(v) : v
        "#SBATCH --$(replace(string(k), "_" => "-"))=$(replace(string(value), "_" => "-"))"
    end
    return join(directives, "\n")
end

"""
generate_sbatch_script(
iter, member,
generate_sbatch_script(iter, member,
output_dir, experiment_dir, model_interface;
module_load_str, slurm_kwargs,
)
Expand All @@ -14,28 +23,23 @@ Generate a string containing an sbatch script to run the forward model.
Helper function for `sbatch_model_run`.
"""
function generate_sbatch_script(
iter,
member,
output_dir,
experiment_dir,
model_interface,
module_load_str;
iter::Int,
member::Int,
output_dir::AbstractString,
experiment_dir::AbstractString,
model_interface::AbstractString,
module_load_str::AbstractString;
slurm_kwargs,
)
member_log = path_to_model_log(output_dir, iter, member)

# Format time in minutes to string for slurm
slurm_kwargs[:time] = format_slurm_time(slurm_kwargs[:time])

slurm_directives = map(collect(slurm_kwargs)) do (k, v)
"#SBATCH --$(replace(string(k), "_" => "-"))=$(replace(string(v), "_" => "-"))"
end
slurm_directives = generate_sbatch_directives(slurm_kwargs)

Check warning on line 35 in src/slurm.jl

View check run for this annotation

Codecov / codecov/patch

src/slurm.jl#L35

Added line #L35 was not covered by tests

sbatch_contents = """
#!/bin/bash
#SBATCH --job-name=run_$(iter)_$(member)
#SBATCH --output=$member_log
$(join(slurm_directives, "\n"))
$slurm_directives
set -euo pipefail
$module_load_str
srun --output=$member_log --open-mode=append julia --project=$experiment_dir -e '
Expand All @@ -45,6 +49,7 @@ function generate_sbatch_script(
experiment_dir = "$experiment_dir"
CAL.run_forward_model(CAL.set_up_forward_model(member, iteration, experiment_dir))'
exit 0
"""
return sbatch_contents
end
Expand Down Expand Up @@ -83,8 +88,16 @@ function sbatch_model_run(
:ntasks => 1,
:cpus_per_task => 1,
),
kwargs...,
)
# Type and existence checks
@assert isdir(output_dir) "Output directory does not exist: $output_dir"
@assert isdir(experiment_dir) "Experiment directory does not exist: $experiment_dir"
@assert isfile(model_interface) "Model interface file does not exist: $model_interface"

Check warning on line 95 in src/slurm.jl

View check run for this annotation

Codecov / codecov/patch

src/slurm.jl#L93-L95

Added lines #L93 - L95 were not covered by tests

# Range checks
@assert iter >= 0 "Iteration number must be non-negative"
@assert member > 0 "Member number must be positive"

Check warning on line 99 in src/slurm.jl

View check run for this annotation

Codecov / codecov/patch

src/slurm.jl#L98-L99

Added lines #L98 - L99 were not covered by tests

sbatch_contents = generate_sbatch_script(
iter,
member,
Expand All @@ -93,28 +106,29 @@ function sbatch_model_run(
model_interface,
module_load_str;
slurm_kwargs,
kwargs...,
)

sbatch_filepath, io = mktemp(output_dir)
write(io, sbatch_contents)
close(io)

return submit_sbatch_job(sbatch_filepath)
jobid = mktemp(output_dir) do sbatch_filepath, io
write(io, sbatch_contents)
close(io)
submit_sbatch_job(sbatch_filepath)

Check warning on line 114 in src/slurm.jl

View check run for this annotation

Codecov / codecov/patch

src/slurm.jl#L111-L114

Added lines #L111 - L114 were not covered by tests
end
return jobid

Check warning on line 116 in src/slurm.jl

View check run for this annotation

Codecov / codecov/patch

src/slurm.jl#L116

Added line #L116 was not covered by tests
end

function wait_for_jobs(
jobids,
jobids::Vector{Int},
output_dir,
iter,
experiment_dir,
model_interface,
module_load_str;
verbose,
slurm_kwargs,
reruns = 1,
)
statuses = map(job_status, jobids)
rerun_jobs = Set{Int}()
rerun_job_count = zeros(length(jobids))

Check warning on line 131 in src/slurm.jl

View check run for this annotation

Codecov / codecov/patch

src/slurm.jl#L131

Added line #L131 was not covered by tests
completed_jobs = Set{Int}()

try
Expand All @@ -124,7 +138,7 @@ function wait_for_jobs(

if job_failed(status)
log_member_error(output_dir, iter, m, verbose)
if !(m in rerun_jobs)
if rerun_job_count[m] < reruns

Check warning on line 141 in src/slurm.jl

View check run for this annotation

Codecov / codecov/patch

src/slurm.jl#L141

Added line #L141 was not covered by tests

@info "Rerunning ensemble member $m"
jobids[m] = sbatch_model_run(
Expand All @@ -136,7 +150,7 @@ function wait_for_jobs(
module_load_str;
slurm_kwargs,
)
push!(rerun_jobs, m)
rerun_job_count[m] += 1

Check warning on line 153 in src/slurm.jl

View check run for this annotation

Codecov / codecov/patch

src/slurm.jl#L153

Added line #L153 was not covered by tests
else
push!(completed_jobs, m)
end
Expand All @@ -162,7 +176,7 @@ end
"""
log_member_error(output_dir, iteration, member, verbose = false)
Log a warning message when an error occurs in a specific ensemble member during a model run in a Slurm environment.
Log a warning message when an error occurs.
If verbose, includes the ensemble member's output.
"""
function log_member_error(output_dir, iteration, member, verbose = false)
Expand All @@ -189,14 +203,14 @@ function report_iteration_status(statuses, output_dir, iter)
end
end

"""
    submit_sbatch_job(sbatch_filepath; env = ENV)

Submit the script at `sbatch_filepath` using `sbatch --parsable` and return the
job id printed by `sbatch`, parsed as an `Int`.

The job is submitted with a copy of `env` from which `SLURM_MEM_PER_CPU`,
`SLURM_MEM_PER_GPU` and `SLURM_MEM_PER_NODE` have been removed. Neither `env`
nor the environment of the running process is modified.
"""
function submit_sbatch_job(sbatch_filepath; env = ENV)
    # Work on an independent copy. `deepcopy(ENV)` is not one: `ENV` is a
    # field-less singleton, so `deepcopy` hands back `ENV` itself and deleting
    # keys from it would unset the variables in this process.
    submission_env = copy(env)
    # Ensure that we don't inherit unwanted environment variables
    unset_env_vars =
        ("SLURM_MEM_PER_CPU", "SLURM_MEM_PER_GPU", "SLURM_MEM_PER_NODE")
    for k in unset_env_vars
        haskey(submission_env, k) && delete!(submission_env, k)
    end
    jobid = readchomp(
        setenv(`sbatch --parsable $sbatch_filepath`, submission_env),
    )
    return parse(Int, jobid)
end

Expand Down
3 changes: 3 additions & 0 deletions test/slurm_unit_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ expected_sbatch_contents = """
#SBATCH --gpus-per-task=1
#SBATCH --cpus-per-task=16
#SBATCH --time=01:30:00
set -euo pipefail
export MODULEPATH=/groups/esm/modules:\$MODULEPATH
module purge
module load climacommon/2024_05_27
Expand All @@ -53,6 +54,7 @@ model_interface = "model_interface.jl"; include(model_interface)
experiment_dir = "exp/dir"
CAL.run_forward_model(CAL.set_up_forward_model(member, iteration, experiment_dir))'
exit 0
"""

for (generated_str, test_str) in
Expand Down Expand Up @@ -96,6 +98,7 @@ sleep(1)

# Test batch cancellation
jobids = ntuple(x -> submit_cmd_helper(test_cmd), 5)

CAL.kill_all_jobs(jobids)
for jobid in jobids
@test CAL.job_completed(CAL.job_status(jobid))
Expand Down

0 comments on commit c01b0f5

Please sign in to comment.