diff --git a/src/slurm.jl b/src/slurm.jl index fb394052..dd057c60 100644 --- a/src/slurm.jl +++ b/src/slurm.jl @@ -2,9 +2,18 @@ export kwargs, sbatch_model_run, wait_for_jobs kwargs(; kwargs...) = Dict{Symbol, Any}(kwargs...) +function generate_sbatch_directives(slurm_kwargs) + @assert haskey(slurm_kwargs, :time) "Slurm kwargs must include key :time" + + slurm_kwargs[:time] = format_slurm_time(slurm_kwargs[:time]) + slurm_directives = map(collect(slurm_kwargs)) do (k, v) + "#SBATCH --$(replace(string(k), "_" => "-"))=$(replace(string(v), "_" => "-"))" + end + return join(slurm_directives, "\n") +end + """ - generate_sbatch_script( - iter, member, + generate_sbatch_script(iter, member, output_dir, experiment_dir, model_interface; module_load_str, slurm_kwargs, ) @@ -14,28 +23,23 @@ Generate a string containing an sbatch script to run the forward model. Helper function for `sbatch_model_run`. """ function generate_sbatch_script( - iter, - member, - output_dir, - experiment_dir, - model_interface, - module_load_str; + iter::Int, + member::Int, + output_dir::AbstractString, + experiment_dir::AbstractString, + model_interface::AbstractString, + module_load_str::AbstractString; slurm_kwargs, ) member_log = path_to_model_log(output_dir, iter, member) - - # Format time in minutes to string for slurm - slurm_kwargs[:time] = format_slurm_time(slurm_kwargs[:time]) - - slurm_directives = map(collect(slurm_kwargs)) do (k, v) - "#SBATCH --$(replace(string(k), "_" => "-"))=$(replace(string(v), "_" => "-"))" - end + slurm_directives = generate_sbatch_directives(slurm_kwargs) sbatch_contents = """ #!/bin/bash #SBATCH --job-name=run_$(iter)_$(member) #SBATCH --output=$member_log - $(join(slurm_directives, "\n")) + $slurm_directives + set -euo pipefail $module_load_str srun --output=$member_log --open-mode=append julia --project=$experiment_dir -e ' @@ -45,6 +49,7 @@ function generate_sbatch_script( experiment_dir = "$experiment_dir" CAL.run_forward_model(CAL.set_up_forward_model(member, iteration, experiment_dir))' + exit 0 """ return sbatch_contents end @@ -83,8 +88,16 @@ function sbatch_model_run( :ntasks => 1, :cpus_per_task => 1, ), - kwargs..., ) + # Type and existence checks + @assert isdir(output_dir) "Output directory does not exist: $output_dir" + @assert isdir(experiment_dir) "Experiment directory does not exist: $experiment_dir" + @assert isfile(model_interface) "Model interface file does not exist: $model_interface" + + # Range checks + @assert iter >= 0 "Iteration number must be non-negative" + @assert member > 0 "Member number must be positive" + sbatch_contents = generate_sbatch_script( iter, member, @@ -93,18 +106,18 @@ function sbatch_model_run( model_interface, module_load_str; slurm_kwargs, - kwargs..., ) - sbatch_filepath, io = mktemp(output_dir) - write(io, sbatch_contents) - close(io) - - return submit_sbatch_job(sbatch_filepath) + jobid = mktemp(output_dir) do sbatch_filepath, io + write(io, sbatch_contents) + close(io) + submit_sbatch_job(sbatch_filepath) + end + return jobid end function wait_for_jobs( - jobids, + jobids::Vector{Int}, output_dir, iter, experiment_dir, @@ -112,9 +125,10 @@ function wait_for_jobs( module_load_str; verbose, slurm_kwargs, + reruns = 1, ) statuses = map(job_status, jobids) - rerun_jobs = Set{Int}() + rerun_job_count = zeros(length(jobids)) completed_jobs = Set{Int}() try @@ -124,7 +138,7 @@ function wait_for_jobs( if job_failed(status) log_member_error(output_dir, iter, m, verbose) - if !(m in rerun_jobs) + if rerun_job_count[m] < reruns @info "Rerunning ensemble member $m" jobids[m] = sbatch_model_run( @@ -136,7 +150,7 @@ function wait_for_jobs( module_load_str; slurm_kwargs, ) - push!(rerun_jobs, m) + rerun_job_count[m] += 1 else push!(completed_jobs, m) end @@ -162,7 +176,7 @@ end """ log_member_error(output_dir, iteration, member, verbose = false) -Log a warning message when an error occurs in a specific ensemble member during a model run in a Slurm environment. +Log a warning message when an error occurs. If verbose, includes the ensemble member's output. """ function log_member_error(output_dir, iteration, member, verbose = false) @@ -189,14 +203,14 @@ function report_iteration_status(statuses, output_dir, iter) end end -function submit_sbatch_job(sbatch_filepath; debug = false, env = deepcopy(ENV)) +function submit_sbatch_job(sbatch_filepath; env = deepcopy(ENV)) + # Ensure that we don't inherit unwanted environment variables unset_env_vars = ("SLURM_MEM_PER_CPU", "SLURM_MEM_PER_GPU", "SLURM_MEM_PER_NODE") for k in unset_env_vars haskey(env, k) && delete!(env, k) end jobid = readchomp(setenv(`sbatch --parsable $sbatch_filepath`, env)) - debug || rm(sbatch_filepath) return parse(Int, jobid) end diff --git a/test/slurm_unit_tests.jl b/test/slurm_unit_tests.jl index 7e45b649..34888949 100644 --- a/test/slurm_unit_tests.jl +++ b/test/slurm_unit_tests.jl @@ -42,6 +42,7 @@ expected_sbatch_contents = """ #SBATCH --gpus-per-task=1 #SBATCH --cpus-per-task=16 #SBATCH --time=01:30:00 +set -euo pipefail export MODULEPATH=/groups/esm/modules:\$MODULEPATH module purge module load climacommon/2024_05_27 @@ -53,6 +54,7 @@ model_interface = "model_interface.jl"; include(model_interface) experiment_dir = "exp/dir" CAL.run_forward_model(CAL.set_up_forward_model(member, iteration, experiment_dir))' +exit 0 """ for (generated_str, test_str) in @@ -96,6 +98,7 @@ sleep(1) # Test batch cancellation jobids = ntuple(x -> submit_cmd_helper(test_cmd), 5) + CAL.kill_all_jobs(jobids) for jobid in jobids @test CAL.job_completed(CAL.job_status(jobid))