diff --git a/src/slurm.jl b/src/slurm.jl index fb394052..b5800986 100644 --- a/src/slurm.jl +++ b/src/slurm.jl @@ -2,9 +2,18 @@ export kwargs, sbatch_model_run, wait_for_jobs kwargs(; kwargs...) = Dict{Symbol, Any}(kwargs...) +function generate_sbatch_directives(slurm_kwargs) + @assert haskey(slurm_kwargs, :time) "Slurm kwargs must include key :time" + + slurm_kwargs[:time] = format_slurm_time(slurm_kwargs[:time]) + slurm_directives = map(collect(slurm_kwargs)) do (k, v) + "#SBATCH --$(replace(string(k), "_" => "-"))=$(replace(string(v), "_" => "-"))" + end + return join(slurm_directives, "\n") +end + """ - generate_sbatch_script( - iter, member, + generate_sbatch_script(iter, member, output_dir, experiment_dir, model_interface; module_load_str, slurm_kwargs, ) @@ -14,28 +23,22 @@ Generate a string containing an sbatch script to run the forward model. Helper function for `sbatch_model_run`. """ function generate_sbatch_script( - iter, - member, - output_dir, - experiment_dir, - model_interface, - module_load_str; + iter::Int, + member::Int, + output_dir::AbstractString, + experiment_dir::AbstractString, + model_interface::AbstractString, + module_load_str::AbstractString; slurm_kwargs, ) member_log = path_to_model_log(output_dir, iter, member) - - # Format time in minutes to string for slurm - slurm_kwargs[:time] = format_slurm_time(slurm_kwargs[:time]) - - slurm_directives = map(collect(slurm_kwargs)) do (k, v) - "#SBATCH --$(replace(string(k), "_" => "-"))=$(replace(string(v), "_" => "-"))" - end + slurm_directives = generate_sbatch_directives(slurm_kwargs) sbatch_contents = """ #!/bin/bash #SBATCH --job-name=run_$(iter)_$(member) #SBATCH --output=$member_log - $(join(slurm_directives, "\n")) + $slurm_directives $module_load_str srun --output=$member_log --open-mode=append julia --project=$experiment_dir -e ' @@ -83,8 +86,16 @@ function sbatch_model_run( :ntasks => 1, :cpus_per_task => 1, ), - kwargs..., ) + # Type and existence checks + @assert isdir(output_dir) "Output directory does not exist: $output_dir" + @assert isdir(experiment_dir) "Experiment directory does not exist: $experiment_dir" + @assert isfile(model_interface) "Model interface file does not exist: $model_interface" + + # Range checks + @assert iter >= 0 "Iteration number must be non-negative" + @assert member > 0 "Member number must be positive" + sbatch_contents = generate_sbatch_script( iter, member, @@ -93,18 +104,18 @@ function sbatch_model_run( model_interface, module_load_str; slurm_kwargs, - kwargs..., ) - sbatch_filepath, io = mktemp(output_dir) - write(io, sbatch_contents) - close(io) - - return submit_sbatch_job(sbatch_filepath) + jobid = mktemp(output_dir) do sbatch_filepath, io + write(io, sbatch_contents) + close(io) + submit_sbatch_job(sbatch_filepath) + end + return jobid end function wait_for_jobs( - jobids, + jobids::Vector{Int}, output_dir, iter, experiment_dir, @@ -112,9 +123,10 @@ function wait_for_jobs( module_load_str; verbose, slurm_kwargs, + reruns = 1, ) statuses = map(job_status, jobids) - rerun_jobs = Set{Int}() + rerun_job_count = zeros(length(jobids)) completed_jobs = Set{Int}() try @@ -124,7 +136,7 @@ function wait_for_jobs( if job_failed(status) log_member_error(output_dir, iter, m, verbose) - if !(m in rerun_jobs) + if rerun_job_count[m] < reruns @info "Rerunning ensemble member $m" jobids[m] = sbatch_model_run( @@ -136,7 +148,7 @@ function wait_for_jobs( module_load_str; slurm_kwargs, ) - push!(rerun_jobs, m) + rerun_job_count[m] += 1 else push!(completed_jobs, m) end @@ -162,7 +174,7 @@ end """ log_member_error(output_dir, iteration, member, verbose = false) -Log a warning message when an error occurs in a specific ensemble member during a model run in a Slurm environment. +Log a warning message when an error occurs. If verbose, includes the ensemble member's output. """ function log_member_error(output_dir, iteration, member, verbose = false) @@ -189,14 +201,14 @@ function report_iteration_status(statuses, output_dir, iter) end end -function submit_sbatch_job(sbatch_filepath; debug = false, env = deepcopy(ENV)) +function submit_sbatch_job(sbatch_filepath; env = deepcopy(ENV)) + # Ensure that we don't inherit unwanted environment variables unset_env_vars = ("SLURM_MEM_PER_CPU", "SLURM_MEM_PER_GPU", "SLURM_MEM_PER_NODE") for k in unset_env_vars haskey(env, k) && delete!(env, k) end jobid = readchomp(setenv(`sbatch --parsable $sbatch_filepath`, env)) - debug || rm(sbatch_filepath) return parse(Int, jobid) end diff --git a/test/slurm_unit_tests.jl b/test/slurm_unit_tests.jl index 7e45b649..7661f3b7 100644 --- a/test/slurm_unit_tests.jl +++ b/test/slurm_unit_tests.jl @@ -96,6 +96,7 @@ sleep(1) # Test batch cancellation jobids = ntuple(x -> submit_cmd_helper(test_cmd), 5) + CAL.kill_all_jobs(jobids) for jobid in jobids @test CAL.job_completed(CAL.job_status(jobid))