diff --git a/src/ClimaCalibrate.jl b/src/ClimaCalibrate.jl
index 1c7130c6..2a78affa 100644
--- a/src/ClimaCalibrate.jl
+++ b/src/ClimaCalibrate.jl
@@ -2,6 +2,7 @@ module ClimaCalibrate
 include("ekp_interface.jl")
 include("model_interface.jl")
+include("slurm.jl")
 include("backends.jl")
 include("emulate_sample.jl")
diff --git a/src/backends.jl b/src/backends.jl
index b68f6b3f..eb417d31 100644
--- a/src/backends.jl
+++ b/src/backends.jl
@@ -97,7 +97,8 @@
 include(joinpath(experiment_dir, "generate_data.jl"))
 include(joinpath(experiment_dir, "observation_map.jl"))
 include(model_interface)
-eki = calibrate(CaltechHPC, experiment_dir; time_limit = 3, model_interface);
+slurm_kwargs = kwargs(time = 3)
+eki = calibrate(CaltechHPC, experiment_dir; model_interface, slurm_kwargs);
 ```
 """
 function calibrate(
@@ -115,12 +116,13 @@ function calibrate(
     model_interface = abspath(
         joinpath(experiment_dir, "..", "..", "model_interface.jl"),
     ),
-    time_limit = 60,
+    verbose = false,
+    slurm_kwargs =
+        kwargs(time = 60,
     ntasks = 1,
     cpus_per_task = 1,
     gpus_per_task = 0,
-    partition = gpus_per_task > 0 ? "gpu" : "expansion",
-    verbose = false,
+    partition = gpus_per_task > 0 ? "gpu" : "expansion"),
 )
     # ExperimentConfig is created from a YAML file within the experiment_dir
     (; n_iterations, output_dir, ensemble_size) = config
@@ -132,17 +134,13 @@
         @info "Iteration $iter"
         jobids = map(1:ensemble_size) do member
             @info "Running ensemble member $member"
-            sbatch_model_run(;
-                output_dir,
+            sbatch_model_run(
                 iter,
                 member,
-                time_limit,
-                ntasks,
-                partition,
-                cpus_per_task,
-                gpus_per_task,
+                output_dir,
                 experiment_dir,
-                model_interface,
+                model_interface;
+                slurm_kwargs,
             )
         end

@@ -150,14 +148,10 @@
             jobids,
             output_dir,
             iter,
-            time_limit,
-            ntasks,
-            partition,
-            cpus_per_task,
-            gpus_per_task,
             experiment_dir,
-            model_interface,
+            model_interface;
             verbose,
+            slurm_kwargs,
         )
         report_iteration_status(statuses, output_dir, iter)
         @info "Completed iteration $iter, updating ensemble"
@@ -167,263 +161,3 @@ end
     end
     return eki
 end
-
-"""
-    log_member_error(output_dir, iteration, member, verbose = false)
-
-Log a warning message when an error occurs in a specific ensemble member during a model run in a Slurm environment.
-If verbose, includes the ensemble member's output.
-"""
-function log_member_error(output_dir, iteration, member, verbose = false)
-    member_log = joinpath(
-        path_to_ensemble_member(output_dir, iteration, member),
-        "model_log.txt",
-    )
-    warn_str = "Ensemble member $member raised an error. See model log at $abspath(member_log) for stacktrace"
-    if verbose
-        stacktrace = replace(readchomp(member_log), "\\n" => "\n")
-        warn_str = warn_str * ": \n$stacktrace"
-    end
-    @warn warn_str
-end
-
-function generate_sbatch_script(
-    output_dir,
-    iter,
-    member,
-    time_limit,
-    ntasks,
-    partition,
-    cpus_per_task,
-    gpus_per_task,
-    experiment_dir,
-    model_interface,
-    module_load = """
-    export MODULEPATH=/groups/esm/modules:\$MODULEPATH
-    module purge
-    module load climacommon/2024_04_30
-    """,
-)
-    member_log = joinpath(
-        path_to_ensemble_member(output_dir, iter, member),
-        "model_log.txt",
-    )
-    sbatch_contents = """
-    #!/bin/bash
-    #SBATCH --job-name=run_$(iter)_$(member)
-    #SBATCH --time=$(format_slurm_time(time_limit))
-    #SBATCH --ntasks=$ntasks
-    #SBATCH --partition=$partition
-    #SBATCH --cpus-per-task=$cpus_per_task
-    #SBATCH --gpus-per-task=$gpus_per_task
-    #SBATCH --output=$member_log
-
-    $module_load
-
-    srun --output=$member_log --open-mode=append julia --project=$experiment_dir -e '
-    import ClimaCalibrate as CAL
-    iteration = $iter; member = $member
-    model_interface = "$model_interface"; include(model_interface)
-
-    experiment_dir = "$experiment_dir"
-    experiment_config = CAL.ExperimentConfig(experiment_dir)
-    experiment_id = experiment_config.id
-    physical_model = CAL.get_forward_model(Val(Symbol(experiment_id)))
-    CAL.run_forward_model(physical_model, CAL.get_config(physical_model, member, iteration, experiment_dir))
-    @info "Forward Model Run Completed" experiment_id physical_model iteration member'
-    """
-    return sbatch_contents
-end
-
-"""
-    sbatch_model_run(;
-        output_dir,
-        iter,
-        member,
-        time_limit,
-        ntasks,
-        partition,
-        cpus_per_task,
-        gpus_per_task,
-        experiment_dir,
-        model_interface,
-        verbose,
-    )
-
-Construct and execute a command to run a model simulation on a Slurm cluster for a single ensemble member.
-"""
-function sbatch_model_run(;
-    output_dir,
-    iter,
-    member,
-    time_limit,
-    ntasks,
-    partition,
-    cpus_per_task,
-    gpus_per_task,
-    experiment_dir,
-    model_interface,
-)
-    sbatch_contents = generate_sbatch_script(
-        output_dir,
-        iter,
-        member,
-        time_limit,
-        ntasks,
-        partition,
-        cpus_per_task,
-        gpus_per_task,
-        experiment_dir,
-        model_interface,
-    )
-
-    sbatch_filepath, io = mktemp(output_dir)
-    write(io, sbatch_contents)
-    close(io)
-
-    return submit_sbatch_job(sbatch_filepath)
-end
-
-function wait_for_jobs(
-    jobids,
-    output_dir,
-    iter,
-    time_limit,
-    ntasks,
-    partition,
-    cpus_per_task,
-    gpus_per_task,
-    experiment_dir,
-    model_interface,
-    verbose,
-)
-    statuses = map(job_status, jobids)
-    rerun_jobs = Set{Int}()
-    completed_jobs = Set{Int}()
-
-    try
-        while !all(job_completed, statuses)
-            for (m, status) in enumerate(statuses)
-                m in completed_jobs && continue
-
-                if job_failed(status)
-                    log_member_error(output_dir, iter, m, verbose)
-                    if !(m in rerun_jobs)
-
-                        @info "Rerunning ensemble member $m"
-                        jobids[m] = sbatch_model_run(;
-                            output_dir,
-                            iter,
-                            member = m,
-                            time_limit,
-                            ntasks,
-                            partition,
-                            cpus_per_task,
-                            gpus_per_task,
-                            experiment_dir,
-                            model_interface,
-                        )
-                        push!(rerun_jobs, m)
-                    else
-                        push!(completed_jobs, m)
-                    end
-                elseif job_success(status)
-                    @info "Ensemble member $m complete"
-                    push!(completed_jobs, m)
-                end
-            end
-            sleep(5)
-            statuses = map(job_status, jobids)
-        end
-        return statuses
-    catch e
-        kill_all_jobs(jobids)
-        if !(e isa InterruptException)
-            @error "Pipeline crashed outside of a model run. Stacktrace for failed simulation" exception =
-                (e, catch_backtrace())
-        end
-        return map(job_status, jobids)
-    end
-end
-
-function report_iteration_status(statuses, output_dir, iter)
-    all(job_completed.(statuses)) || error("Some jobs are not complete")
-    if all(job_failed, statuses)
-        error(
-            "Full ensemble for iteration $iter has failed. See model logs in $(abspath(path_to_iteration(output_dir, iter))) for details.",
-        )
-    elseif any(job_failed, statuses)
-        @warn "Failed ensemble members: $(findall(job_failed, statuses))"
-    end
-end
-
-function submit_sbatch_job(sbatch_filepath; debug = false, env = ENV)
-    jobid = readchomp(setenv(`sbatch --parsable $sbatch_filepath`, env))
-    debug || rm(sbatch_filepath)
-    return parse(Int, jobid)
-end
-
-job_running(status) = status == "RUNNING"
-job_success(status) = status == "COMPLETED"
-job_failed(status) = status == "FAILED"
-job_completed(status) = job_failed(status) || job_success(status)
-
-"""
-    job_status(jobid)
-
-Parse the slurm jobid's state and return one of three status strings: "COMPLETED", "FAILED", or "RUNNING"
-"""
-function job_status(jobid)
-    failure_statuses = ("FAILED", "CANCELLED+", "CANCELLED")
-    output = readchomp(`sacct -j $jobid --format=State --noheader`)
-    # Jobs usually have multiple statuses
-    statuses = strip.(split(output, "\n"))
-    if all(s -> s == "COMPLETED", statuses)
-        return "COMPLETED"
-    elseif any(s -> s in failure_statuses, statuses)
-        return "FAILED"
-    else
-        return "RUNNING"
-    end
-end
-
-"""
-    kill_all_jobs(jobids)
-
-Takes a list of slurm job IDs and runs `scancel` on them.
-"""
-function kill_all_jobs(jobids)
-    for jobid in jobids
-        try
-            kill_slurm_job(jobid)
-            println("Cancelling slurm job $jobid")
-        catch e
-            println("Failed to cancel slurm job $jobid: ", e)
-        end
-    end
-end
-
-kill_slurm_job(jobid) = run(`scancel $jobid`)
-
-function format_slurm_time(minutes::Int)
-    days, remaining_minutes = divrem(minutes, (60 * 24))
-    hours, remaining_minutes = divrem(remaining_minutes, 60)
-    # Format the string according to Slurm's time format
-    if days > 0
-        return string(
-            days,
-            "-",
-            lpad(hours, 2, '0'),
-            ":",
-            lpad(remaining_minutes, 2, '0'),
-            ":00",
-        )
-    else
-        return string(
-            lpad(hours, 2, '0'),
-            ":",
-            lpad(remaining_minutes, 2, '0'),
-            ":00",
-        )
-    end
-end
diff --git a/src/slurm.jl b/src/slurm.jl
new file mode 100644
index 00000000..0d2a83b0
--- /dev/null
+++ b/src/slurm.jl
@@ -0,0 +1,253 @@
+
+kwargs(; kwargs...) = kwargs
+
+"""
+    generate_sbatch_script(iter, member, output_dir, experiment_dir, model_interface; module_load, slurm_kwargs)
+
+Generate the contents of an sbatch script that runs the forward model for a single ensemble member.
+"""
+function generate_sbatch_script(
+    iter,
+    member,
+    output_dir,
+    experiment_dir,
+    model_interface;
+    module_load = """
+    export MODULEPATH=/groups/esm/modules:\$MODULEPATH
+    module purge
+    module load climacommon/2024_04_30
+    """,
+    slurm_kwargs = Dict(
+        :time => 45,
+        :ntasks => 1,
+        :cpus_per_task => 1,
+    )
+)
+
+    member_log = joinpath(
+        path_to_ensemble_member(output_dir, iter, member),
+        "model_log.txt",
+    )
+
+    # Copy into a Dict{Symbol, Any} and format time in minutes to a Slurm time string
+    slurm_kwargs = Dict{Symbol, Any}(pairs(slurm_kwargs))
+    slurm_kwargs[:time] = format_slurm_time(slurm_kwargs[:time])
+
+    # Underscores in keys are written as dashes to match sbatch's option names
+    slurm_directives = map(kv -> "#SBATCH --$(replace(string(kv[1]), "_" => "-"))=$(kv[2])", collect(slurm_kwargs))
+    slurm_directives_str = join(slurm_directives, "\n")
+
+    sbatch_contents = """
+    #!/bin/bash
+    #SBATCH --job-name=run_$(iter)_$(member)
+    #SBATCH --output=$member_log
+    $slurm_directives_str
+
+    $module_load
+
+    srun --output=$member_log --open-mode=append julia --project=$experiment_dir -e '
+    import ClimaCalibrate as CAL
+    iteration = $iter; member = $member
+    model_interface = "$model_interface"; include(model_interface)
+
+    experiment_dir = "$experiment_dir"
+    experiment_config = CAL.ExperimentConfig(experiment_dir)
+    experiment_id = experiment_config.id
+    physical_model = CAL.get_forward_model(Val(Symbol(experiment_id)))
+    CAL.run_forward_model(physical_model, CAL.get_config(physical_model, member, iteration, experiment_dir))
+    @info "Forward Model Run Completed" experiment_id physical_model iteration member'
+    """
+    return sbatch_contents
+end
+
+"""
+    sbatch_model_run(
+        iter,
+        member,
+        output_dir,
+        experiment_dir,
+        model_interface;
+        slurm_kwargs,
+        kwargs...,
+    )
+
+Construct and execute a command to run a model simulation on a Slurm cluster for a single ensemble member.
+"""
+function sbatch_model_run(
+    iter,
+    member,
+    output_dir,
+    experiment_dir,
+    model_interface;
+    slurm_kwargs = Dict(),
+    kwargs...
+)
+    sbatch_contents = generate_sbatch_script(
+        iter,
+        member,
+        output_dir,
+        experiment_dir,
+        model_interface;
+        slurm_kwargs,
+        kwargs...
+    )
+
+    sbatch_filepath, io = mktemp(output_dir)
+    write(io, sbatch_contents)
+    close(io)
+
+    return submit_sbatch_job(sbatch_filepath)
+end
+
+function wait_for_jobs(
+    jobids,
+    output_dir,
+    iter,
+    experiment_dir,
+    model_interface;
+    verbose,
+    slurm_kwargs,
+)
+    statuses = map(job_status, jobids)
+    rerun_jobs = Set{Int}()
+    completed_jobs = Set{Int}()
+
+    try
+        while !all(job_completed, statuses)
+            for (m, status) in enumerate(statuses)
+                m in completed_jobs && continue
+
+                if job_failed(status)
+                    log_member_error(output_dir, iter, m, verbose)
+                    if !(m in rerun_jobs)
+
+                        @info "Rerunning ensemble member $m"
+                        jobids[m] = sbatch_model_run(
+                            iter,
+                            m,
+                            output_dir,
+                            experiment_dir,
+                            model_interface;
+                            slurm_kwargs,
+                        )
+                        push!(rerun_jobs, m)
+                    else
+                        push!(completed_jobs, m)
+                    end
+                elseif job_success(status)
+                    @info "Ensemble member $m complete"
+                    push!(completed_jobs, m)
+                end
+            end
+            sleep(5)
+            statuses = map(job_status, jobids)
+        end
+        return statuses
+    catch e
+        kill_all_jobs(jobids)
+        if !(e isa InterruptException)
+            @error "Pipeline crashed outside of a model run. Stacktrace for failed simulation" exception =
+                (e, catch_backtrace())
+        end
+        return map(job_status, jobids)
+    end
+end
+
+"""
+    log_member_error(output_dir, iteration, member, verbose = false)
+
+Log a warning message when an error occurs in a specific ensemble member during a model run in a Slurm environment.
+If verbose, includes the ensemble member's output.
+""" +function log_member_error(output_dir, iteration, member, verbose = false) + member_log = joinpath( + path_to_ensemble_member(output_dir, iteration, member), + "model_log.txt", + ) + warn_str = "Ensemble member $member raised an error. See model log at $abspath(member_log) for stacktrace" + if verbose + stacktrace = replace(readchomp(member_log), "\\n" => "\n") + warn_str = warn_str * ": \n$stacktrace" + end + @warn warn_str +end + +function report_iteration_status(statuses, output_dir, iter) + all(job_completed.(statuses)) || error("Some jobs are not complete") + if all(job_failed, statuses) + error( + "Full ensemble for iteration $iter has failed. See model logs in $(abspath(path_to_iteration(output_dir, iter))) for details.", + ) + elseif any(job_failed, statuses) + @warn "Failed ensemble members: $(findall(job_failed, statuses))" + end +end + +function submit_sbatch_job(sbatch_filepath; debug = false, env = ENV) + jobid = readchomp(setenv(`sbatch --parsable $sbatch_filepath`, env)) + debug || rm(sbatch_filepath) + return parse(Int, jobid) +end + +job_running(status) = status == "RUNNING" +job_success(status) = status == "COMPLETED" +job_failed(status) = status == "FAILED" +job_completed(status) = job_failed(status) || job_success(status) + +""" + job_status(jobid) + +Parse the slurm jobid's state and return one of three status strings: "COMPLETED", "FAILED", or "RUNNING" +""" +function job_status(jobid) + failure_statuses = ("FAILED", "CANCELLED+", "CANCELLED") + output = readchomp(`sacct -j $jobid --format=State --noheader`) + # Jobs usually have multiple statuses + statuses = strip.(split(output, "\n")) + if all(s -> s == "COMPLETED", statuses) + return "COMPLETED" + elseif any(s -> s in failure_statuses, statuses) + return "FAILED" + else + return "RUNNING" + end +end + +""" + kill_all_jobs(jobids) + +Takes a list of slurm job IDs and runs `scancel` on them. +""" +function kill_all_jobs(jobids) + for jobid in jobids + try + kill_slurm_job(jobid) + println("Cancelling slurm job $jobid") + catch e + println("Failed to cancel slurm job $jobid: ", e) + end + end +end + +kill_slurm_job(jobid) = run(`scancel $jobid`) + +function format_slurm_time(minutes::Int) + days, remaining_minutes = divrem(minutes, (60 * 24)) + hours, remaining_minutes = divrem(remaining_minutes, 60) + # Format the string according to Slurm's time format + if days > 0 + return string( + days, + "-", + lpad(hours, 2, '0'), + ":", + lpad(remaining_minutes, 2, '0'), + ":00", + ) + else + return string( + lpad(hours, 2, '0'), + ":", + lpad(remaining_minutes, 2, '0'), + ":00", + ) + end +end diff --git a/test/Project.toml b/test/Project.toml index 7e0a695b..77ce91b0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,6 @@ [deps] -ClimaCalibrate = "4347a170-ebd6-470c-89d3-5c705c0cacc2" CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3" +ClimaCalibrate = "4347a170-ebd6-470c-89d3-5c705c0cacc2" ClimaParams = "5c42b081-d73a-476f-9059-fd94b934656c" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d"