run_trainer.pbs
#!/bin/bash
#PBS -P <your NCI project>
#PBS -q gpuvolta
#PBS -l walltime=15:00:00
#PBS -l mem=382GB
#PBS -l ncpus=48
#PBS -l ngpus=4
#PBS -l jobfs=200GB
#PBS -l storage=gdata/rt52+gdata/<your NCI project>
#PBS -N pbs_fourcastnext_trainer
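## Overview: this head job starts a Ray head process on the GPU node, submits
## CPU-only Ray worker jobs, then runs single-step pre-training followed by
## multi-step fine-tuning with trainer.py.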
set -eu
module load cuda/11.7.0
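## Use the qsub submission directory if available, otherwise the script's own
## directory, and put the bundled python_env on PATH.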
curr_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
curr_path=${PBS_O_WORKDIR:-$curr_dir}
export PATH=$curr_path/python_env/bin:$PATH
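## WORLD_SIZE is set to the number of GPUs allocated by PBS (falling back to 1);
## presumably read by trainer.py to size the distributed process group.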
export WORLD_SIZE=${PBS_NGPUS:-1}
max_pretrain_steps=40000
max_finetune_steps=2000
## resume training from the last checkpoint
## e.g. resume_checkpoint_path="$curr_path/lightning_logs/version_1/checkpoints/model-step=19500.ckpt"
resume_checkpoint_path=""
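## Ray cluster settings: the head listens on this node and port, and
## $num_ray_worker_nodes CPU-only worker jobs are submitted from the generated
## $ray_worker_file below; their job IDs are collected in $job_id_file.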
ray_port=26378
ray_head_host=$(hostname)
num_ray_worker_nodes=3
ray_worker_file=$curr_path/ray_worker.pbs
ray_worker_name=ray_worker
ray_tmp_root=/tmp
mkdir -p $ray_tmp_root
ray_tmp_root=$(mktemp -p $ray_tmp_root -d XXX)
job_id_file=$PBS_JOBFS/job_ids
export OMP_NUM_THREADS=1
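## On exit: remove the generated worker script, qdel any submitted worker jobs,
## delete the Ray temp directory and stop the local Ray head.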
function cleanup {
    set +e
    rm -f $ray_worker_file
    qdel $(cat $job_id_file) &
    rm -rf $ray_tmp_root &
    ray stop -f &
    set -e
    wait
}
trap "cleanup" EXIT
cat > $ray_worker_file <<-EOF
#!/bin/bash
#PBS -P fp0
#PBS -q normal
#PBS -l walltime=15:00:00
#PBS -l mem=96GB
#PBS -l ncpus=24
#PBS -l jobfs=200GB
#PBS -l storage=gdata/rt52+gdata/fr5
#PBS -N $ray_worker_name
#PBS -e /g/data/fr5/jxg900/fourcastnet/pbs_logs
#PBS -o /g/data/fr5/jxg900/fourcastnet/pbs_logs
set -xeu
export PATH=$PATH
export PYTHONPATH=$curr_path
export OMP_NUM_THREADS=1
export RAY_num_heartbeats_timeout=120
mkdir -p $ray_tmp_root
for i in {1..600}
do
    if ray start --address='$ray_head_host:$ray_port' --num-cpus=\$(expr \$PBS_NCPUS - 1) --block --disable-usage-stats
    then
        break
    fi
    echo "restarting ray worker \$i"
    sleep 30
done
EOF
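## Start the Ray head on this node. Three CPUs are left out of the Ray pool
## (presumably for the trainer's own processes), and --num-gpus=0 keeps the
## GPUs available to the trainer rather than registering them with Ray.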
mkdir -p $ray_tmp_root
export RAY_num_heartbeats_timeout=120
ray start --head --port=$ray_port \
    --num-cpus=$(expr $PBS_NCPUS - 3) \
    --num-gpus=0 \
    --disable-usage-stats \
    --include-dashboard=False \
    --temp-dir=$ray_tmp_root
nvidia-smi
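## Submit the Ray worker jobs in the background, recording their job IDs so
## cleanup can qdel them on exit.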
(
    for i in $(seq 1 $num_ray_worker_nodes)
    do
        qsub $ray_worker_file >> $job_id_file
    done
) &
## single-step pre-training
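## Runs trainer.py with a single sampling time step; skipped once
## pre-trained_best_model.txt exists. trainer.py is assumed to write the path
## of its best checkpoint into the file passed as --best-model-path (that path
## is read back during fine-tuning below).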
if [ ! -f $curr_path/pre-trained_best_model.txt ]
then
    (
        cd $curr_path
        python -u trainer.py \
            --max-train-steps=$max_pretrain_steps \
            --base-lr=3e-3 \
            --max-sampling-time-steps=1 \
            --resume-checkpoint-path="$resume_checkpoint_path" \
            --best-model-path=$curr_path/pre-trained_best_model.txt
    )
fi
## multi-step fine-tuning
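## Fine-tunes for sampling time steps 2..4, resuming each step from the best
## checkpoint of the previous one. fine-tuning.progress records the last
## completed step so an interrupted job can pick up where it left off.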
(
    cd $curr_path
    if [ -f $curr_path/fine-tuning.progress ]
    then
        start_step=$(($(cat $curr_path/fine-tuning.progress) + 1))
    else
        start_step=2
    fi
    for step in $(seq $start_step 4)
    do
        if [ $step -eq 2 ]
        then
            best_model_path=$curr_path/pre-trained_best_model.txt
        else
            best_model_path=$curr_path/best_model.txt
        fi
        echo "fine-tuning multi-step: $step"
        python -u trainer.py \
            --max-train-steps=$max_finetune_steps \
            --base-lr=1e-4 \
            --max-sampling-time-steps=$step \
            --resume-checkpoint-path="$(cat $best_model_path)" \
            --best-model-path=$curr_path/best_model.txt
        echo $step > $curr_path/fine-tuning.progress
    done
    rm -f $curr_path/fine-tuning.progress
)