hyper_params_search.py
"""
the main script that do hyper parameter search
Author: Abdelkarim eljandoubi
date: Nov 2023
"""
import json
import os
from set_trainer import lora_trainer


def hp_space(trial):
    """Define the Optuna hyperparameter search space."""
    return {
        # 1 to 10 training epochs
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 10),
        # learning rate sampled log-uniformly between 1e-6 and 1e-3
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-3,
                                             log=True),
        # per-device batch size: 32, 64, or 128
        "per_device_train_batch_size": trial.suggest_categorical(
            "per_device_train_batch_size", [2**i for i in range(5, 8)]),
        # gradient accumulation steps: 1, 2, 4, 8, or 16
        "gradient_accumulation_steps": trial.suggest_categorical(
            "gradient_accumulation_steps", [2**i for i in range(5)]),
    }
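

# For illustration only: a minimal sketch of how the space above could be
# sampled with a standalone Optuna study, outside the Hugging Face Trainer.
# The objective below is a hypothetical placeholder, not part of this
# project; a real objective would train and evaluate the model.
def _demo_optuna_sampling(n_trials=3):
    """Sample a few configurations from hp_space with a plain Optuna study."""
    import optuna  # local import; assumed available since the search uses it

    def _dummy_objective(trial):
        params = hp_space(trial)  # same space the Trainer search consumes
        print(params)
        return 0.0  # placeholder metric

    study = optuna.create_study(direction="minimize")
    study.optimize(_dummy_objective, n_trials=n_trials)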


def search(model_checkpoint, n_trials):
    """Execute the hyperparameter search."""
    # skip if this step has already been executed
    if os.path.isfile("optimal.json"):
        return
    # build the LoRA trainer for the given model checkpoint
    trainer = lora_trainer(model_checkpoint)
    # search for the best hyperparameters over n_trials runs
    best_run = trainer.hyperparameter_search(
        n_trials=n_trials, hp_space=hp_space)
    # persist the best hyperparameters for later use
    with open("optimal.json", "w", encoding="utf-8") as fp:
        json.dump(best_run.hyperparameters, fp)
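

# A minimal usage sketch; the checkpoint name and trial count below are
# illustrative assumptions, not values prescribed by this project.
if __name__ == "__main__":
    search(model_checkpoint="some-org/some-model-checkpoint", n_trials=10)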