Skip to content

Commit

Permalink
Clean up slurm pipeline (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
nefrathenrici authored Mar 25, 2024
1 parent ee60582 commit 50ad5e4
Show file tree
Hide file tree
Showing 12 changed files with 67 additions and 43 deletions.
2 changes: 1 addition & 1 deletion docs/src/experiment_setup_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ observation_map(::Val(Symbol(experiment_id)), iteration)
This function must load the model diagnostics for each ensemble member in the iteration and construct an array `arr = Array{Float64}(undef, dims..., ensemble_size)` such that
`arr[:, i]` returns the i-th ensemble member's observation map output. Note that `Float64` precision is required for the EKI update step.

In the update step of EKI, the array will be saved in a JLD2 file named `observation_map.jld2` in the iteration folder of the output directory.
In the update step of EKI, the array will be saved in a JLD2 file named `G_ensemble.jld2` in the iteration folder of the output directory.

As an example, in `observation_map(iteration)` in the `sphere_held_suarez_rhoe_equilmoist` experiment, we have the following sequence:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#SBATCH --ntasks=64
#SBATCH --cpus-per-task=8
#SBATCH --partition=expansion
#SBATCH --output="experiments/sphere_held_suarez_rhoe_equilmoist/truth_simulation/model_log.out"
#SBATCH --output="experiments/sphere_held_suarez_rhoe_equilmoist/truth_simulation/model_log.txt"
#SBATCH --partition=expansion

# Configure the environment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,15 @@ using Statistics
import YAML
import EnsembleKalmanProcesses: TOMLInterface
import JLD2
import CalibrateAtmos: observation_map
import CalibrateAtmos: observation_map, get_ekp_config
using ClimaAnalysis
export observation_map

function observation_map(::Val{:sphere_held_suarez_rhoe_equilmoist}, iteration)
experiment_id = "sphere_held_suarez_rhoe_equilmoist"
config =
YAML.load_file(joinpath("experiments", experiment_id, "ekp_config.yml"))
config = get_ekp_config(experiment_id)
output_dir = config["output_dir"]
ensemble_size = config["ensemble_size"]
model_output = "ta_60d_average.nc"

dims = 1
G_ensemble = Array{Float64}(undef, dims..., ensemble_size)
Expand Down
4 changes: 2 additions & 2 deletions experiments/surface_fluxes_perfect_model/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.10.1"
julia_version = "1.10.2"
manifest_format = "2.0"
project_hash = "8b52a1f87337958a3f0ec6e731cccb862b78de20"

Expand Down Expand Up @@ -300,7 +300,7 @@ uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a"
version = "1.16.1+1"

[[deps.CalibrateAtmos]]
deps = ["CalibrateEmulateSample", "ClimaComms", "ClimaCore", "ClimaParams", "Distributions", "EnsembleKalmanProcesses", "JLD2", "PrecompileTools", "Random", "SciMLBase", "TOML", "YAML"]
deps = ["CalibrateEmulateSample", "ClimaComms", "ClimaParams", "Distributions", "EnsembleKalmanProcesses", "JLD2", "PrecompileTools", "Random", "SciMLBase", "TOML", "YAML"]
path = "../.."
uuid = "4347a170-ebd6-470c-89d3-5c705c0cacc2"
version = "0.1.0"
Expand Down
2 changes: 0 additions & 2 deletions experiments/surface_fluxes_perfect_model/generate_truth.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# generate_truth: generate true y, noise and x_inputs
using Pkg
experiment_id = "surface_fluxes_perfect_model"
Pkg.activate("experiments/$experiment_id")

import SurfaceFluxes as SF
import SurfaceFluxes.Parameters as SFPP
Expand Down
5 changes: 2 additions & 3 deletions experiments/surface_fluxes_perfect_model/observation_map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using Statistics
import YAML
import EnsembleKalmanProcesses: TOMLInterface
import JLD2
import CalibrateAtmos: observation_map
import CalibrateAtmos: observation_map, get_ekp_config

"""
observation_map(::Val{:surface_fluxes_perfect_model}, iteration)
Expand All @@ -12,8 +12,7 @@ as specified by process_member_data, for the given iteration.
"""
function observation_map(::Val{:surface_fluxes_perfect_model}, iteration)
experiment_id = "surface_fluxes_perfect_model"
config =
YAML.load_file(joinpath("experiments", experiment_id, "ekp_config.yml"))
config = get_ekp_config(experiment_id)
output_dir = config["output_dir"]
ensemble_size = config["ensemble_size"]
model_output = "model_ustar_array.jld2"
Expand Down
2 changes: 1 addition & 1 deletion pipeline.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ fi
init_id=$(sbatch --parsable \
--output=$logfile \
--partition=$partition \
--export=generate_data=$generate_data \
slurm/initialize.sbatch $experiment_id)
echo -e "Initialization job_id: $init_id\n"

Expand All @@ -21,7 +22,6 @@ dependency="afterok:$init_id"
for i in $(seq 0 $((n_iterations - 1)))
do
echo "Scheduling iteration $i"

ensemble_array_id=$(
sbatch --dependency=$dependency --kill-on-invalid-dep=yes --parsable \
--job=model-$i \
Expand Down
10 changes: 9 additions & 1 deletion slurm/initialize.sbatch
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,15 @@ experiment_id=$1
JULIA_NUM_PRECOMPILE_TASKS=8

echo "Initializing calibration for experiment: $experiment_id"
julia --color=no --project=experiments/$experiment_id -e 'using Pkg; Pkg.instantiate(;verbose=true)'

julia --color=no --project=experiments/$experiment_id -e '
using Pkg; Pkg.instantiate(;verbose=true)
'

if [ "$generate_data" = true ] ; then
echo "Generating observations"
julia --project=experiments/$experiment_id experiments/$experiment_id/generate_truth.jl
fi

julia --color=no --project=experiments/$experiment_id -e '
import CalibrateAtmos
Expand Down
25 changes: 15 additions & 10 deletions slurm/model_run.sbatch
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,24 @@

# Extract command-line arguments
experiment_id=$1
iteration=$2
i=$2

# Find output directory
format_i=$(printf "iteration_%03d" "$iteration")
format_i=$(printf "iteration_%03d" "$i")
member=$(printf "member_%03d" "$SLURM_ARRAY_TASK_ID")
output=output/$experiment_id/$format_i/$member/model_log.out
output=output/$experiment_id/$format_i/$member/model_log.txt

# Run the forward model
srun --output=$output julia --color=no --project=experiments/$experiment_id -e "
import CalibrateAtmos
include(\"experiments/$experiment_id/model_interface.jl\")
srun --output=$output julia --color=no --project=experiments/$experiment_id -e '
import CalibrateAtmos as CAL
experiment_id = "'$experiment_id'"
i = '$i'
member = '$SLURM_ARRAY_TASK_ID'
physical_model = CalibrateAtmos.get_forward_model(Val(:$experiment_id))
config = CalibrateAtmos.get_config(physical_model, $SLURM_ARRAY_TASK_ID, $iteration, \"$experiment_id\")
CalibrateAtmos.run_forward_model(physical_model, config)
"
include("experiments/$experiment_id/model_interface.jl")
physical_model = CAL.get_forward_model(Val(Symbol(experiment_id)))
config = CAL.get_config(physical_model, member, i, experiment_id)
CAL.run_forward_model(physical_model, config)
'
11 changes: 8 additions & 3 deletions slurm/parse_commandline.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ slurm_time="2:00:00"
slurm_ntasks="1"
slurm_cpus_per_task="1"
slurm_gpus_per_task="0"

generate_data=false
help_message="Usage:
./pipeline.sh [options] experiment_id
Expand All @@ -12,13 +12,14 @@ Options:
-n, --ntasks: Set number of tasks to launch (default: 1).
-c, --cpus_per_task: Set CPU cores per task (mutually exclusive with -g, default: 8).
-g, --gpus_per_task: Set GPUs per task (mutually exclusive with -c, default: 0).
--generate_data: If set, generates observational data for use in the calibration.
-h, --help: Display this help message.
Arguments:
experiment_id: A unique identifier for your experiment (required)."

# Parse arguments using getopt
VALID_ARGS=$(getopt -o h,t:,n:,c:,g: --long help,time:,ntasks:,cpus_per_task:,gpus_per_task: -- "$@")
VALID_ARGS=$(getopt -o h,t:,n:,c:,g: --long help,time:,ntasks:,cpus_per_task:,gpus_per_task:,generate_data -- "$@")
if [[ $? -ne 0 ]]; then
exit 1;
fi
Expand All @@ -44,6 +45,10 @@ while [ : ]; do
slurm_gpus_per_task="$2"
shift 2
;;
--generate_data)
generate_data=true
shift 1
;;
-h | --help)
printf "%s\n" "$help_message"
exit 0
Expand All @@ -62,7 +67,7 @@ fi
ensemble_size=$(grep "ensemble_size:" experiments/$experiment_id/ekp_config.yml | awk '{print $2}')
n_iterations=$(grep "n_iterations:" experiments/$experiment_id/ekp_config.yml | awk '{print $2}')
output=$(grep "output_dir:" experiments/$experiment_id/ekp_config.yml | awk '{print $2}')
logfile=$output/experiment_log.out
logfile=$output/experiment_log.txt

# Set partition
if [[ $slurm_gpus_per_task -gt 0 ]]; then
Expand Down
12 changes: 4 additions & 8 deletions slurm/update.sbatch
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,13 @@ i=$2
echo "Running update step after iteration $i"

julia --color=no --project=experiments/$experiment_id -e '
import YAML, JLD2
import CalibrateAtmos
import CalibrateAtmos as CAL
experiment_id = "'$experiment_id'"
i = '$i'
include("experiments/'$experiment_id'/model_interface.jl")
G_ensemble = CalibrateAtmos.observation_map(Val(Symbol(experiment_id)), i)
config = YAML.load_file(joinpath("experiments", experiment_id, "ekp_config.yml"))
output_dir = config["output_dir"]
iter_path = CalibrateAtmos.path_to_iteration(output_dir, i)
JLD2.save_object(joinpath(iter_path, "observation_map.jld2"), G_ensemble)
CalibrateAtmos.update_ensemble(experiment_id, i)
G_ensemble = CAL.observation_map(Val(Symbol(experiment_id)), i)
CAL.save_G_ensemble(experiment_id, i, G_ensemble)
CAL.update_ensemble(experiment_id, i)
'
echo "Update step for iteration $i complete"
29 changes: 22 additions & 7 deletions src/ekp_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,25 @@ function get_prior(param_dict::AbstractDict; names = nothing)
return prior
end

"""
    get_ekp_config(experiment_id)

Load the EKP configuration for the experiment `experiment_id`.

The configuration is read from `ekp_config.yml` inside the
`experiments/<experiment_id>` folder. The path is relative, so it is resolved
against the current working directory.
"""
function get_ekp_config(experiment_id)
    config_file = joinpath("experiments", experiment_id, "ekp_config.yml")
    return YAML.load_file(config_file)
end

"""
    save_G_ensemble(experiment_id, iteration, G_ensemble)

Save an ensemble's observation map output, `G_ensemble`, to a JLD2 file named
`G_ensemble.jld2`.

The file is written under the path returned by `path_to_iteration` for the
`output_dir` entry of the experiment's EKP configuration and the given
`iteration`.
"""
function save_G_ensemble(experiment_id, iteration, G_ensemble)
    output_dir = get_ekp_config(experiment_id)["output_dir"]
    target = joinpath(path_to_iteration(output_dir, iteration), "G_ensemble.jld2")
    return JLD2.save_object(target, G_ensemble)
end

"""
initialize(
experiment_id;
Expand Down Expand Up @@ -109,7 +128,7 @@ function update_ensemble(
eki = JLD2.load_object(eki_path)

# Load data from the ensemble
G_ens = JLD2.load_object(joinpath(iter_path, "observation_map.jld2"))
G_ens = JLD2.load_object(joinpath(iter_path, "G_ensemble.jld2"))

# Update
EKP.update_ensemble!(eki, G_ens)
Expand Down Expand Up @@ -154,8 +173,7 @@ eki = CalibrateAtmos.calibrate(experiment_id)
```
"""
function calibrate(experiment_id; device = ClimaComms.device())
ekp_config =
YAML.load_file(joinpath("experiments", experiment_id, "ekp_config.yml"))
ekp_config = get_ekp_config(experiment_id)
# initialize the CalibrateAtmos
initialize(experiment_id)

Expand All @@ -179,10 +197,7 @@ function calibrate(experiment_id; device = ClimaComms.device())

# update EKP with the ensemble output and update calibrated parameters
G_ensemble = observation_map(Val(Symbol(experiment_id)), i)
JLD2.save_object(
joinpath(path_to_iteration(output_dir, i), "observation_map.jld2"),
G_ensemble,
)
save_G_ensemble(experiment_id, i, G_ensemble)
eki = update_ensemble(experiment_id, i)
end
return eki
Expand Down

0 comments on commit 50ad5e4

Please sign in to comment.