diff --git a/experiments/surface_fluxes_perfect_model/Manifest.toml b/experiments/surface_fluxes_perfect_model/Manifest.toml index c8bde98b..8e87f984 100644 --- a/experiments/surface_fluxes_perfect_model/Manifest.toml +++ b/experiments/surface_fluxes_perfect_model/Manifest.toml @@ -42,10 +42,10 @@ weakdeps = ["StaticArrays"] AdaptStaticArraysExt = "StaticArrays" [[deps.AliasTables]] -deps = ["Random"] -git-tree-sha1 = "07591db28451b3e45f4c0088a2d5e986ae5aa92d" +deps = ["PtrArrays", "Random"] +git-tree-sha1 = "9876e1e164b144ca45e9e3198d0b689cadfed9ff" uuid = "66dad0bd-aa9a-41b7-9441-69ab47430ed8" -version = "1.1.1" +version = "1.1.3" [[deps.Animations]] deps = ["Colors"] @@ -167,9 +167,15 @@ version = "0.11.10" [[deps.Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "a4c43f59baa34011e303e76f5c8c91bf58415aaf" +git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd" uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" -version = "1.18.0+1" +version = "1.18.0+2" + +[[deps.Calculus]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" +uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" +version = "0.5.1" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra"] @@ -219,9 +225,9 @@ version = "0.4.0" [[deps.ColorSchemes]] deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "67c1f244b991cad9b0aa4b7540fb758c2488b129" +git-tree-sha1 = "4b270d6465eb21ae89b732182c20dc165f8bf9f2" uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.24.0" +version = "3.25.0" [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] @@ -241,9 +247,9 @@ weakdeps = ["SpecialFunctions"] [[deps.Colors]] deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] -git-tree-sha1 = 
"fc08e5930ee9a4e03f84bfb5211cb54e7769758a" +git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.10" +version = "0.12.11" [[deps.Combinatorics]] git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860" @@ -258,9 +264,9 @@ version = "0.3.0" [[deps.Compat]] deps = ["TOML", "UUIDs"] -git-tree-sha1 = "c955881e3c981181362ae4088b35995446298b80" +git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.14.0" +version = "4.15.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -362,6 +368,12 @@ deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" version = "1.6.0" +[[deps.DualNumbers]] +deps = ["Calculus", "NaNMath", "SpecialFunctions"] +git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" +uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" +version = "0.6.8" + [[deps.EarCut_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "e3290f2d49e661fbd94046d7e3726ffcb2d41053" @@ -416,9 +428,9 @@ version = "3.3.10+0" [[deps.FastGaussQuadrature]] deps = ["LinearAlgebra", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "58d83dd5a78a36205bdfddb82b1bb67682e64487" +git-tree-sha1 = "0f478d8bad6f52573fb7658a263af61f3d96e43a" uuid = "442a2c76-b920-505d-bb47-c5924d526838" -version = "0.4.9" +version = "0.5.1" [[deps.FileIO]] deps = ["Pkg", "Requires", "UUIDs"] @@ -442,13 +454,14 @@ version = "0.9.21" uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] -deps = ["LinearAlgebra", "Random"] -git-tree-sha1 = "35f0c0f345bff2c6d636f95fdb136323b5a796ef" +deps = ["LinearAlgebra"] +git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.7.0" -weakdeps = ["SparseArrays", "Statistics"] +version = "1.11.0" +weakdeps = ["PDMats", "SparseArrays", "Statistics"] [deps.FillArrays.extensions] + 
FillArraysPDMatsExt = "PDMats" FillArraysSparseArraysExt = "SparseArrays" FillArraysStatisticsExt = "Statistics" @@ -470,9 +483,9 @@ version = "2.23.1" [[deps.FixedPointNumbers]] deps = ["Statistics"] -git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" +git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.8.4" +version = "0.8.5" [[deps.Fontconfig_jll]] deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Zlib_jll"] @@ -524,10 +537,10 @@ deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[deps.GaussianRandomFields]] -deps = ["Arpack", "FFTW", "FastGaussQuadrature", "LinearAlgebra", "RecipesBase", "SpecialFunctions", "Statistics"] -git-tree-sha1 = "d9c335f2c06424029b2addf9abf602e0feb2f53e" +deps = ["Arpack", "FFTW", "FastGaussQuadrature", "LinearAlgebra", "Random", "RecipesBase", "SpecialFunctions", "Statistics", "StatsBase"] +git-tree-sha1 = "055849d7a602c31eda477a0b0b86c9473a3e4fb9" uuid = "e4b2fa32-6e09-5554-b718-106ed5adafe9" -version = "2.1.6" +version = "2.2.4" [[deps.GeoInterface]] deps = ["Extents"] @@ -549,9 +562,9 @@ version = "0.21.0+0" [[deps.Glib_jll]] deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE2_jll", "Zlib_jll"] -git-tree-sha1 = "359a1ba2e320790ddbe4ee8b4d54a305c0ea2aff" +git-tree-sha1 = "7c82e6a6cd34e9d935e9aa4051b66c6ff3af59ba" uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" -version = "2.80.0+0" +version = "2.80.2+0" [[deps.Graphics]] deps = ["Colors", "LinearAlgebra", "NaNMath"] @@ -582,6 +595,12 @@ git-tree-sha1 = "129acf094d168394e80ee1dc4bc06ec835e510a3" uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" version = "2.8.1+1" +[[deps.HypergeometricFunctions]] +deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] +git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" +uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" +version = 
"0.3.23" + [[deps.ImageAxes]] deps = ["AxisArrays", "ImageBase", "ImageCore", "Reexport", "SimpleTraits"] git-tree-sha1 = "2e4520d67b0cef90865b3ef727594d2a58e0e1f8" @@ -602,9 +621,9 @@ version = "0.10.2" [[deps.ImageIO]] deps = ["FileIO", "IndirectArrays", "JpegTurbo", "LazyModules", "Netpbm", "OpenEXR", "PNGFiles", "QOI", "Sixel", "TiffImages", "UUIDs"] -git-tree-sha1 = "bca20b2f5d00c4fbc192c3212da8fa79f4688009" +git-tree-sha1 = "437abb322a41d527c197fa800455f79d414f0a3c" uuid = "82e4d734-157c-48bb-816b-45c225c6df19" -version = "0.6.7" +version = "0.6.8" [[deps.ImageMetadata]] deps = ["AxisArrays", "ImageAxes", "ImageBase", "ImageCore"] @@ -678,20 +697,10 @@ weakdeps = ["Random", "RecipesBase", "Statistics"] IntervalSetsRecipesBaseExt = "RecipesBase" IntervalSetsStatisticsExt = "Statistics" -[[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "e7cbed5032c4c397a6ac23d1493f3289e01231c4" -uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.14" -weakdeps = ["Dates"] - - [deps.InverseFunctions.extensions] - DatesExt = "Dates" - [[deps.IrrationalConstants]] -git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.1.1" +version = "0.2.2" [[deps.Isoband]] deps = ["isoband_jll"] @@ -735,15 +744,15 @@ version = "0.1.5" [[deps.JpegTurbo_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "3336abae9a713d2210bb57ab484b1e065edd7d23" +git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637" uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" -version = "3.0.2+0" +version = "3.0.3+0" [[deps.KernelAbstractions]] deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "ed7167240f40e62d97c1f5f7735dea6de3cc5c49" +git-tree-sha1 = "db02395e4c374030c53dc28f3c1d33dec35f7272" uuid = 
"63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.18" +version = "0.9.19" [deps.KernelAbstractions.extensions] EnzymeExt = "EnzymeCore" @@ -866,15 +875,15 @@ version = "1.17.0+0" [[deps.Libmount_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "4b683b19157282f50bfd5dcaa2efe5295814ea22" +git-tree-sha1 = "0c4f9c4f1a50d8f35048fa0532dabbadf702f81e" uuid = "4b2f31a3-9ecc-558c-b454-b3730dcb73e9" -version = "2.40.0+0" +version = "2.40.1+0" [[deps.Libuuid_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "27fd5cc10be85658cacfe11bb81bee216af13eda" +git-tree-sha1 = "5ee6203157c120d79034c748a2acba45b82b8807" uuid = "38a345b3-de98-5d2b-a5d3-14cd9215e700" -version = "2.40.0+0" +version = "2.40.1+0" [[deps.LightXML]] deps = ["Libdl", "XML2_jll"] @@ -1111,10 +1120,10 @@ uuid = "efcefdf7-47ab-520b-bdef-62a2eaa19f15" version = "10.42.0+1" [[deps.PDMats]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"] -git-tree-sha1 = "95a4038d1011dfdbde7cecd2ad0ac411e53ab1bc" +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.10.1" +version = "0.11.31" [[deps.PNGFiles]] deps = ["Base64", "CEnum", "ImageCore", "IndirectArrays", "OffsetArrays", "libpng_jll"] @@ -1160,9 +1169,9 @@ version = "0.4.21" [[deps.Pixman_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "Libdl"] -git-tree-sha1 = "64779bc4c9784fee475689a1752ef4d5747c5e87" +git-tree-sha1 = "35621f10a7531bc8fa58f74610b1bfb70a3cfc6b" uuid = "30392449-352a-5448-841d-b1acce4e97dc" -version = "0.42.2+0" +version = "0.43.4+0" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] @@ -1237,6 +1246,11 @@ git-tree-sha1 = "763a8ceb07833dd51bb9e3bbca372de32c0605ad" uuid = 
"92933f4c-e287-5a05-a399-4b506db050ca" version = "1.10.0" +[[deps.PtrArrays]] +git-tree-sha1 = "077664975d750757f30e739c870fbbdc01db7913" +uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" +version = "1.1.0" + [[deps.QOI]] deps = ["ColorTypes", "FileIO", "FixedPointNumbers"] git-tree-sha1 = "18e8f4d1426e965c7b532ddd260599e1510d26ce" @@ -1346,6 +1360,12 @@ version = "3.2.0+0" uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" +[[deps.SIMD]] +deps = ["PrecompileTools"] +git-tree-sha1 = "2803cab51702db743f3fda07dd1745aadfbf43bd" +uuid = "fdea26ae-647d-5447-a871-4b548cad5224" +version = "3.5.0" + [[deps.Scratch]] deps = ["Dates"] git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" @@ -1435,9 +1455,9 @@ version = "1.10.0" [[deps.SpecialFunctions]] deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" +git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.3.1" +version = "2.4.0" weakdeps = ["ChainRulesCore"] [deps.SpecialFunctions.extensions] @@ -1478,15 +1498,23 @@ version = "1.7.0" [[deps.StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "d1bf48bfcc554a3761a133fe3a9bb01488e06916" +git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.21" +version = "0.34.3" [[deps.StatsFuns]] -deps = ["ChainRulesCore", "InverseFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "5950925ff997ed6fb3e985dcce8eb1ba42a0bbe7" +deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] +git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a" uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" 
-version = "0.9.18" +version = "1.3.1" + + [deps.StatsFuns.extensions] + StatsFunsChainRulesCoreExt = "ChainRulesCore" + StatsFunsInverseFunctionsExt = "InverseFunctions" + + [deps.StatsFuns.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" [[deps.StringEncodings]] deps = ["Libiconv_jll"] @@ -1574,15 +1602,15 @@ weakdeps = ["ClimaParams"] CreateParametersExt = "ClimaParams" [[deps.TiffImages]] -deps = ["ColorTypes", "DataStructures", "DocStringExtensions", "FileIO", "FixedPointNumbers", "IndirectArrays", "Inflate", "Mmap", "OffsetArrays", "PkgVersion", "ProgressMeter", "UUIDs"] -git-tree-sha1 = "34cc045dd0aaa59b8bbe86c644679bc57f1d5bd0" +deps = ["ColorTypes", "DataStructures", "DocStringExtensions", "FileIO", "FixedPointNumbers", "IndirectArrays", "Inflate", "Mmap", "OffsetArrays", "PkgVersion", "ProgressMeter", "SIMD", "UUIDs"] +git-tree-sha1 = "bc7fd5c91041f44636b2c134041f7e5263ce58ae" uuid = "731e570b-9d59-4bfa-96dc-6df516fadf69" -version = "0.6.8" +version = "0.10.0" [[deps.TranscodingStreams]] -git-tree-sha1 = "71509f04d045ec714c4748c785a59045c3736349" +git-tree-sha1 = "5d54d076465da49d6746c647022f3b3674e64156" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.10.7" +version = "0.10.8" weakdeps = ["Random", "Test"] [deps.TranscodingStreams.extensions] diff --git a/src/ClimaCalibrate.jl b/src/ClimaCalibrate.jl index 1c7130c6..2a78affa 100644 --- a/src/ClimaCalibrate.jl +++ b/src/ClimaCalibrate.jl @@ -2,6 +2,7 @@ module ClimaCalibrate include("ekp_interface.jl") include("model_interface.jl") +include("slurm.jl") include("backends.jl") include("emulate_sample.jl") diff --git a/src/backends.jl b/src/backends.jl index b68f6b3f..486084de 100644 --- a/src/backends.jl +++ b/src/backends.jl @@ -97,7 +97,8 @@ include(joinpath(experiment_dir, "generate_data.jl")) include(joinpath(experiment_dir, "observation_map.jl")) include(model_interface) -eki = calibrate(CaltechHPC, 
experiment_dir; time_limit = 3, model_interface); +slurm_kwargs = kwargs(time = 3) +eki = calibrate(CaltechHPC, experiment_dir; model_interface, slurm_kwargs); ``` """ function calibrate( @@ -115,12 +116,8 @@ function calibrate( model_interface = abspath( joinpath(experiment_dir, "..", "..", "model_interface.jl"), ), - time_limit = 60, - ntasks = 1, - cpus_per_task = 1, - gpus_per_task = 0, - partition = gpus_per_task > 0 ? "gpu" : "expansion", verbose = false, + slurm_kwargs = Dict{Symbol, Any}(:time => 45, :ntasks => 1, :cpus_per_task => 1), ) # ExperimentConfig is created from a YAML file within the experiment_dir (; n_iterations, output_dir, ensemble_size) = config @@ -132,17 +129,13 @@ function calibrate( @info "Iteration $iter" jobids = map(1:ensemble_size) do member @info "Running ensemble member $member" - sbatch_model_run(; - output_dir, + sbatch_model_run( iter, member, - time_limit, - ntasks, - partition, - cpus_per_task, - gpus_per_task, + output_dir, experiment_dir, - model_interface, + model_interface; + slurm_kwargs, ) end @@ -150,14 +143,10 @@ function calibrate( jobids, output_dir, iter, - time_limit, - ntasks, - partition, - cpus_per_task, - gpus_per_task, experiment_dir, - model_interface, + model_interface; verbose, + slurm_kwargs, ) report_iteration_status(statuses, output_dir, iter) @info "Completed iteration $iter, updating ensemble" @@ -167,263 +156,3 @@ function calibrate( end return eki end - -""" - log_member_error(output_dir, iteration, member, verbose = false) - -Log a warning message when an error occurs in a specific ensemble member during a model run in a Slurm environment. -If verbose, includes the ensemble member's output. -""" -function log_member_error(output_dir, iteration, member, verbose = false) - member_log = joinpath( - path_to_ensemble_member(output_dir, iteration, member), - "model_log.txt", - ) - warn_str = "Ensemble member $member raised an error.
See model log at $abspath(member_log) for stacktrace" - if verbose - stacktrace = replace(readchomp(member_log), "\\n" => "\n") - warn_str = warn_str * ": \n$stacktrace" - end - @warn warn_str -end - -function generate_sbatch_script( - output_dir, - iter, - member, - time_limit, - ntasks, - partition, - cpus_per_task, - gpus_per_task, - experiment_dir, - model_interface, - module_load = """ - export MODULEPATH=/groups/esm/modules:\$MODULEPATH - module purge - module load climacommon/2024_04_30 - """, -) - member_log = joinpath( - path_to_ensemble_member(output_dir, iter, member), - "model_log.txt", - ) - sbatch_contents = """ - #!/bin/bash - #SBATCH --job-name=run_$(iter)_$(member) - #SBATCH --time=$(format_slurm_time(time_limit)) - #SBATCH --ntasks=$ntasks - #SBATCH --partition=$partition - #SBATCH --cpus-per-task=$cpus_per_task - #SBATCH --gpus-per-task=$gpus_per_task - #SBATCH --output=$member_log - - $module_load - - srun --output=$member_log --open-mode=append julia --project=$experiment_dir -e ' - import ClimaCalibrate as CAL - iteration = $iter; member = $member - model_interface = "$model_interface"; include(model_interface) - - experiment_dir = "$experiment_dir" - experiment_config = CAL.ExperimentConfig(experiment_dir) - experiment_id = experiment_config.id - physical_model = CAL.get_forward_model(Val(Symbol(experiment_id))) - CAL.run_forward_model(physical_model, CAL.get_config(physical_model, member, iteration, experiment_dir)) - @info "Forward Model Run Completed" experiment_id physical_model iteration member' - """ - return sbatch_contents -end - -""" - sbatch_model_run(; - output_dir, - iter, - member, - time_limit, - ntasks, - partition, - cpus_per_task, - gpus_per_task, - experiment_dir, - model_interface, - verbose, - ) - -Construct and execute a command to run a model simulation on a Slurm cluster for a single ensemble member. 
-""" -function sbatch_model_run(; - output_dir, - iter, - member, - time_limit, - ntasks, - partition, - cpus_per_task, - gpus_per_task, - experiment_dir, - model_interface, -) - sbatch_contents = generate_sbatch_script( - output_dir, - iter, - member, - time_limit, - ntasks, - partition, - cpus_per_task, - gpus_per_task, - experiment_dir, - model_interface, - ) - - sbatch_filepath, io = mktemp(output_dir) - write(io, sbatch_contents) - close(io) - - return submit_sbatch_job(sbatch_filepath) -end - -function wait_for_jobs( - jobids, - output_dir, - iter, - time_limit, - ntasks, - partition, - cpus_per_task, - gpus_per_task, - experiment_dir, - model_interface, - verbose, -) - statuses = map(job_status, jobids) - rerun_jobs = Set{Int}() - completed_jobs = Set{Int}() - - try - while !all(job_completed, statuses) - for (m, status) in enumerate(statuses) - m in completed_jobs && continue - - if job_failed(status) - log_member_error(output_dir, iter, m, verbose) - if !(m in rerun_jobs) - - @info "Rerunning ensemble member $m" - jobids[m] = sbatch_model_run(; - output_dir, - iter, - member = m, - time_limit, - ntasks, - partition, - cpus_per_task, - gpus_per_task, - experiment_dir, - model_interface, - ) - push!(rerun_jobs, m) - else - push!(completed_jobs, m) - end - elseif job_success(status) - @info "Ensemble member $m complete" - push!(completed_jobs, m) - end - end - sleep(5) - statuses = map(job_status, jobids) - end - return statuses - catch e - kill_all_jobs(jobids) - if !(e isa InterruptException) - @error "Pipeline crashed outside of a model run. Stacktrace for failed simulation" exception = - (e, catch_backtrace()) - end - return map(job_status, jobids) - end -end - -function report_iteration_status(statuses, output_dir, iter) - all(job_completed.(statuses)) || error("Some jobs are not complete") - if all(job_failed, statuses) - error( - "Full ensemble for iteration $iter has failed. 
See model logs in $(abspath(path_to_iteration(output_dir, iter))) for details.", - ) - elseif any(job_failed, statuses) - @warn "Failed ensemble members: $(findall(job_failed, statuses))" - end -end - -function submit_sbatch_job(sbatch_filepath; debug = false, env = ENV) - jobid = readchomp(setenv(`sbatch --parsable $sbatch_filepath`, env)) - debug || rm(sbatch_filepath) - return parse(Int, jobid) -end - -job_running(status) = status == "RUNNING" -job_success(status) = status == "COMPLETED" -job_failed(status) = status == "FAILED" -job_completed(status) = job_failed(status) || job_success(status) - -""" - job_status(jobid) - -Parse the slurm jobid's state and return one of three status strings: "COMPLETED", "FAILED", or "RUNNING" -""" -function job_status(jobid) - failure_statuses = ("FAILED", "CANCELLED+", "CANCELLED") - output = readchomp(`sacct -j $jobid --format=State --noheader`) - # Jobs usually have multiple statuses - statuses = strip.(split(output, "\n")) - if all(s -> s == "COMPLETED", statuses) - return "COMPLETED" - elseif any(s -> s in failure_statuses, statuses) - return "FAILED" - else - return "RUNNING" - end -end - -""" - kill_all_jobs(jobids) - -Takes a list of slurm job IDs and runs `scancel` on them. 
-""" -function kill_all_jobs(jobids) - for jobid in jobids - try - kill_slurm_job(jobid) - println("Cancelling slurm job $jobid") - catch e - println("Failed to cancel slurm job $jobid: ", e) - end - end -end - -kill_slurm_job(jobid) = run(`scancel $jobid`) - -function format_slurm_time(minutes::Int) - days, remaining_minutes = divrem(minutes, (60 * 24)) - hours, remaining_minutes = divrem(remaining_minutes, 60) - # Format the string according to Slurm's time format - if days > 0 - return string( - days, - "-", - lpad(hours, 2, '0'), - ":", - lpad(remaining_minutes, 2, '0'), - ":00", - ) - else - return string( - lpad(hours, 2, '0'), - ":", - lpad(remaining_minutes, 2, '0'), - ":00", - ) - end -end diff --git a/src/ekp_interface.jl b/src/ekp_interface.jl index 3eb7ed1e..51c486e0 100644 --- a/src/ekp_interface.jl +++ b/src/ekp_interface.jl @@ -85,6 +85,16 @@ Constructs the path to an ensemble member's directory for a given iteration and path_to_ensemble_member(output_dir, iteration, member) = EKP.TOMLInterface.path_to_ensemble_member(output_dir, iteration, member) +""" + path_to_model_log(output_dir, iteration, member) + +Constructs the path to an ensemble member's forward model log for a given iteration and member number. +""" +path_to_model_log(output_dir, iteration, member) = joinpath( + path_to_ensemble_member(output_dir, iteration, member), + "model_log.txt", +) + """ path_to_iteration(output_dir, iteration) diff --git a/src/slurm.jl b/src/slurm.jl new file mode 100644 index 00000000..00b4e91e --- /dev/null +++ b/src/slurm.jl @@ -0,0 +1,250 @@ + +kwargs(; kwargs...) = Dict{Symbol, Any}(kwargs...) 
+ +""" + generate_sbatch_script(iter, member, output_dir, experiment_dir, model_interface; module_load, slurm_kwargs) + +Generate the contents of the sbatch script that runs the forward model for one ensemble member. Each entry of `slurm_kwargs` becomes an `#SBATCH` directive (underscores in keys are written as hyphens, values are written verbatim); an integer `:time` is interpreted as minutes, and if `:time` is absent no `--time` directive is emitted. `slurm_kwargs` is not modified. +""" +function generate_sbatch_script( + iter, + member, + output_dir, + experiment_dir, + model_interface; + module_load = """ + export MODULEPATH=/groups/esm/modules:\$MODULEPATH + module purge + module load climacommon/2024_04_30 + """, + slurm_kwargs = Dict{Symbol, Any}( + :time => 45, + :ntasks => 1, + :cpus_per_task => 1, + ), +) + member_log = path_to_model_log(output_dir, iter, member) + + # Format :time (minutes) for Slurm on a merged copy so the caller's dict is never mutated; tolerate a missing :time + slurm_kwargs = haskey(slurm_kwargs, :time) ? merge(slurm_kwargs, Dict{Symbol, Any}(:time => format_slurm_time(slurm_kwargs[:time]))) : slurm_kwargs + + slurm_directives = map(collect(slurm_kwargs)) do (k, v) + "#SBATCH --$(replace(string(k), "_" => "-"))=$v" + end + slurm_directives_str = join(slurm_directives, "\n") + + sbatch_contents = """ + #!/bin/bash + #SBATCH --job-name=run_$(iter)_$(member) + #SBATCH --output=$member_log + $slurm_directives_str + + $module_load + + srun --output=$member_log --open-mode=append julia --project=$experiment_dir -e ' + import ClimaCalibrate as CAL + iteration = $iter; member = $member + model_interface = "$model_interface"; include(model_interface) + + experiment_dir = "$experiment_dir" + experiment_config = CAL.ExperimentConfig(experiment_dir) + experiment_id = experiment_config.id + physical_model = CAL.get_forward_model(Val(Symbol(experiment_id))) + CAL.run_forward_model(physical_model, CAL.get_config(physical_model, member, iteration, experiment_dir)) + @info "Forward Model Run Completed" experiment_id physical_model iteration member' + """ + return sbatch_contents +end + +""" + sbatch_model_run( + iter, + member, + output_dir, + experiment_dir, + model_interface; + slurm_kwargs, + kwargs..., + ) + +Construct and execute a command to run a model simulation on a Slurm cluster for a single ensemble member.
+""" +function sbatch_model_run( + iter, + member, + output_dir, + experiment_dir, + model_interface; + slurm_kwargs = Dict{Symbol, Any}(), + kwargs..., +) + sbatch_contents = generate_sbatch_script( + iter, + member, + output_dir, + experiment_dir, + model_interface; + slurm_kwargs, + kwargs..., + ) + + sbatch_filepath, io = mktemp(output_dir) + write(io, sbatch_contents) + close(io) + + return submit_sbatch_job(sbatch_filepath) +end + +function wait_for_jobs( + jobids, + output_dir, + iter, + experiment_dir, + model_interface; + verbose, + slurm_kwargs, +) + statuses = map(job_status, jobids) + rerun_jobs = Set{Int}() + completed_jobs = Set{Int}() + + try + while !all(job_completed, statuses) + for (m, status) in enumerate(statuses) + m in completed_jobs && continue + + if job_failed(status) + log_member_error(output_dir, iter, m, verbose) + if !(m in rerun_jobs) + + @info "Rerunning ensemble member $m" + jobids[m] = sbatch_model_run( + iter, + m, + output_dir, + experiment_dir, + model_interface; + slurm_kwargs, + ) + push!(rerun_jobs, m) + else + push!(completed_jobs, m) + end + elseif job_success(status) + @info "Ensemble member $m complete" + push!(completed_jobs, m) + end + end + sleep(5) + statuses = map(job_status, jobids) + end + return statuses + catch e + kill_all_jobs(jobids) + if !(e isa InterruptException) + @error "Pipeline crashed outside of a model run. Stacktrace for failed simulation" exception = + (e, catch_backtrace()) + end + return map(job_status, jobids) + end +end + +""" + log_member_error(output_dir, iteration, member, verbose = false) + +Log a warning message when an error occurs in a specific ensemble member during a model run in a Slurm environment. +If verbose, includes the ensemble member's output. +""" +function log_member_error(output_dir, iteration, member, verbose = false) + member_log = path_to_model_log(output_dir, iteration, member) + warn_str = "Ensemble member $member raised an error. 
See model log at $(abspath(member_log)) for stacktrace" + if verbose + stacktrace = replace(readchomp(member_log), "\\n" => "\n") + warn_str = warn_str * ": \n$stacktrace" + end + @warn warn_str +end + +function report_iteration_status(statuses, output_dir, iter) + all(job_completed.(statuses)) || error("Some jobs are not complete") + if all(job_failed, statuses) + error( + "Full ensemble for iteration $iter has failed. See model logs in $(abspath(path_to_iteration(output_dir, iter))) for details.", + ) + elseif any(job_failed, statuses) + @warn "Failed ensemble members: $(findall(job_failed, statuses))" + end +end + +function submit_sbatch_job(sbatch_filepath; debug = false, env = ENV) + jobid = readchomp(setenv(`sbatch --parsable $sbatch_filepath`, env)) + debug || rm(sbatch_filepath) + return parse(Int, jobid) +end + +job_running(status) = status == "RUNNING" +job_success(status) = status == "COMPLETED" +job_failed(status) = status == "FAILED" +job_completed(status) = job_failed(status) || job_success(status) + +""" + job_status(jobid) + +Parse the slurm jobid's state and return one of three status strings: "COMPLETED", "FAILED", or "RUNNING" +""" +function job_status(jobid) + failure_statuses = ("FAILED", "CANCELLED+", "CANCELLED") + output = readchomp(`sacct -j $jobid --format=State --noheader`) + # Jobs usually have multiple statuses + statuses = strip.(split(output, "\n")) + if all(s -> s == "COMPLETED", statuses) + return "COMPLETED" + elseif any(s -> s in failure_statuses, statuses) + return "FAILED" + else + return "RUNNING" + end +end + +""" + kill_all_jobs(jobids) + +Takes a list of slurm job IDs and runs `scancel` on them. 
+""" +function kill_all_jobs(jobids) + for jobid in jobids + try + kill_slurm_job(jobid) + println("Cancelling slurm job $jobid") + catch e + println("Failed to cancel slurm job $jobid: ", e) + end + end +end + +kill_slurm_job(jobid) = run(`scancel $jobid`) + +function format_slurm_time(minutes::Int) + days, remaining_minutes = divrem(minutes, (60 * 24)) + hours, remaining_minutes = divrem(remaining_minutes, 60) + # Format the string according to Slurm's time format + if days > 0 + return string( + days, + "-", + lpad(hours, 2, '0'), + ":", + lpad(remaining_minutes, 2, '0'), + ":00", + ) + else + return string( + lpad(hours, 2, '0'), + ":", + lpad(remaining_minutes, 2, '0'), + ":00", + ) + end +end + +format_slurm_time(str::AbstractString) = str diff --git a/test/Project.toml b/test/Project.toml index 7e0a695b..77ce91b0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,6 @@ [deps] -ClimaCalibrate = "4347a170-ebd6-470c-89d3-5c705c0cacc2" CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3" +ClimaCalibrate = "4347a170-ebd6-470c-89d3-5c705c0cacc2" ClimaParams = "5c42b081-d73a-476f-9059-fd94b934656c" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d" diff --git a/test/caltech_hpc_e2e.jl b/test/caltech_hpc_e2e.jl index 19b3e247..29fdabc2 100644 --- a/test/caltech_hpc_e2e.jl +++ b/test/caltech_hpc_e2e.jl @@ -3,7 +3,7 @@ # And include this file import ClimaCalibrate: - get_backend, CaltechHPC, JuliaBackend, calibrate, get_prior + get_backend, CaltechHPC, JuliaBackend, calibrate, get_prior, kwargs using Test import EnsembleKalmanProcesses: get_ϕ_mean_final, get_g_mean_final @@ -40,8 +40,8 @@ backend = get_backend() eki = calibrate( backend, experiment_dir; - time_limit = 5, model_interface, + slurm_kwargs = kwargs(time = 5), verbose = true, ) test_sf_calibration_output(eki, prior) diff --git a/test/slurm_unit_tests.jl b/test/slurm_unit_tests.jl index acf026fd..715a15fc 100644 
--- a/test/slurm_unit_tests.jl +++ b/test/slurm_unit_tests.jl @@ -13,6 +13,13 @@ const GPUS_PER_TASK = 1 const EXPERIMENT_DIR = "exp/dir" const MODEL_INTERFACE = "model_interface.jl" +const slurm_kwargs = CAL.kwargs( + time = TIME_LIMIT, + partition = PARTITION, + cpus_per_task = CPUS_PER_TASK, + gpus_per_task = GPUS_PER_TASK, +) + # Time formatting tests @test CAL.format_slurm_time(TIME_LIMIT) == "01:30:00" @test CAL.format_slurm_time(1) == "00:01:00" @@ -21,27 +28,22 @@ const MODEL_INTERFACE = "model_interface.jl" # Generate and validate sbatch file contents sbatch_file = CAL.generate_sbatch_script( - OUTPUT_DIR, ITER, MEMBER, - TIME_LIMIT, - NTASKS, - PARTITION, - CPUS_PER_TASK, - GPUS_PER_TASK, + OUTPUT_DIR, EXPERIMENT_DIR, - MODEL_INTERFACE, + MODEL_INTERFACE; + slurm_kwargs, ) expected_sbatch_contents = """ #!/bin/bash #SBATCH --job-name=run_1_1 -#SBATCH --time=01:30:00 -#SBATCH --ntasks=1 +#SBATCH --output=test/iteration_001/member_001/model_log.txt #SBATCH --partition=expansion -#SBATCH --cpus-per-task=16 #SBATCH --gpus-per-task=1 -#SBATCH --output=test/iteration_001/member_001/model_log.txt +#SBATCH --cpus-per-task=16 +#SBATCH --time=01:30:00 export MODULEPATH=/groups/esm/modules:\$MODULEPATH module purge