forked from karpathy/nanoGPT
-
Notifications
You must be signed in to change notification settings - Fork 18
/
run_experiments.py
165 lines (137 loc) · 7.26 KB
/
run_experiments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import json
import subprocess
import os
import sys
import pandas as pd
import argparse
from datetime import datetime
from itertools import product
from rich import print
from rich.console import Console
from rich.table import Table
def parse_args():
    """Parse the experiment runner's command-line options.

    Reads ``sys.argv`` through :mod:`argparse` and returns the resulting
    ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser(
        description="Run experiments based on a json configuration file."
    )

    # Where the sweep definition comes from and where its outputs go.
    cli.add_argument('-c', "--config", type=str, required=True,
                     help="Path to the configuration JSON file.")
    cli.add_argument('-o', "--output_dir", type=str, default="out",
                     help="Directory to place the set of output checkpoints.")
    cli.add_argument("--csv_ckpt_dir", type=str, default="",
                     help="Directory to place the set of csv checkpoints in csv_logs.")

    # Run-naming options.
    cli.add_argument("--prefix", type=str, default='',
                     help="Optional prefix for tensorboard_run_name and out_dir.")
    cli.add_argument("--add_names", action="store_true",
                     help="Include names of values of the configuration parameters in addition to values (may cause too long a file name).")

    # Warm start from the best checkpoint of an earlier sweep.
    cli.add_argument("--use-best-val-loss-from", nargs=2, metavar=('csv_dir', 'output_dir'),
                     type=str, default=['', ''],
                     help="Grab the best val loss of the run given by the csv_dir. Then, use the corresponding ckpt from the matching output_dir")

    # Per-invocation overrides, applied to every generated run.
    for flag, flag_type in (('--override_max_iters', int),
                            ('--override_dataset', str),
                            ('--override_block_size', int)):
        cli.add_argument(flag, default=None, type=flag_type)

    return cli.parse_args()
def find_best_val_loss(csv_dir, output_dir):
    """Return the checkpoint path of the run with the lowest logged validation loss.

    Every regular file in ``csv_dir`` is read as a header-less CSV, and the
    minimum of its last column (NaNs dropped) is taken as that run's best
    validation loss. The CSV with the overall minimum is then matched to the
    first entry (in sorted order) of ``output_dir`` whose name contains the part
    of the CSV stem after its second ``-``. If the stem has fewer than two
    dashes, the whole stem is used as the match key.

    Args:
        csv_dir: Directory containing the per-run CSV logs.
        output_dir: Directory containing the candidate checkpoint entries.

    Returns:
        ``os.path.join(output_dir, <matched entry>)``. If no entry matches, the
        best CSV stem is used as the entry name and a warning is printed.

    Raises:
        ValueError: If no file in ``csv_dir`` yields a non-NaN loss.
    """
    best_ckpt_loss = float("inf")
    best_ckpt_name = None
    # Sort so that ties and matches do not depend on filesystem listing order.
    for fname in sorted(os.listdir(csv_dir)):
        path = os.path.join(csv_dir, fname)
        if not os.path.isfile(path):
            continue
        try:
            df = pd.read_csv(path, header=None)
        except pd.errors.EmptyDataError:
            # e.g. a run that has not logged anything yet
            continue
        best_valid_loss = df.iloc[:, -1].dropna().min()
        # An all-NaN column gives NaN, which never compares less, so it is skipped.
        if best_valid_loss < best_ckpt_loss:
            best_ckpt_loss = best_valid_loss
            best_ckpt_name = os.path.splitext(fname)[0]

    if best_ckpt_name is None:
        raise ValueError(f"No validation losses found in {csv_dir!r}")

    # Computed once: the key does not change while scanning output_dir.
    parts = best_ckpt_name.split('-', 2)
    match_key = parts[2] if len(parts) == 3 else best_ckpt_name

    ckpt_name = best_ckpt_name
    for entry in sorted(os.listdir(output_dir)):
        if match_key in entry:
            ckpt_name = entry
            break
    else:
        print(f"warning: no entry in {output_dir} contains {match_key!r}; using {best_ckpt_name!r}")

    print("best_valid_loss: ", best_ckpt_loss)
    print("best_valid_chpt: ", ckpt_name)
    return os.path.join(output_dir, ckpt_name)
def check_conditions(conditions, combo_dict):
    """Report whether *combo_dict* satisfies every condition in *conditions*.

    Each condition is indexed as ``condition[0]`` (parameter name) and
    ``condition[1]`` (required value). A name missing from *combo_dict* is
    looked up as ``None``. An empty *conditions* sequence is vacuously satisfied.
    """
    for condition in conditions:
        name = condition[0]
        matches = combo_dict.get(name) == condition[1]
        if not matches:
            return False
    return True
def expand_range(value):
    """Expand a ``{"range": {"start", "end"[, "step"]}}`` spec into an explicit list.

    Any other value is returned unchanged.

    The range includes ``end``. ``step`` defaults to 1 when ``start`` is an int
    and to 0.1 otherwise. If ``start``, ``end`` and ``step`` are all ints, the
    result comes from ``range``; otherwise the values are ``start + i * step``.
    A negative step produces a descending sequence. A step pointing away from
    ``end`` yields an empty list.

    Raises:
        ValueError: If ``step`` is zero.
    """
    if not (isinstance(value, dict) and 'range' in value):
        return value

    range_def = value['range']
    start, end = range_def['start'], range_def['end']
    step = range_def.get('step', 1 if isinstance(start, int) else 0.1)
    if step == 0:
        raise ValueError("range step must be non-zero")

    if all(isinstance(x, int) for x in (start, end, step)):
        # range() excludes its stop, so move it one past `end` in the step direction.
        return list(range(start, end + (1 if step > 0 else -1), step))

    # Count steps with a small tolerance so float error such as
    # (0.3 - 0.1) / 0.1 == 1.9999999999999998 does not drop the endpoint.
    n_steps = (end - start) / step
    if n_steps < -1e-9:
        return []
    return [start + i * step for i in range(int(n_steps + 1e-9) + 1)]
def generate_combinations(config):
    """Yield every run configuration described by one entry of the JSON config.

    How *config* is interpreted:

    * Plain values and lists are swept. A list contributes one run per element,
      and a scalar is treated as a one-element list. ``{"range": {...}}`` dicts
      are expanded with ``expand_range`` first.
    * Dicts containing ``"conditions"`` are conditional parameters. For every
      partial combination that satisfies ``conditions`` (via
      ``check_conditions``), one combination per entry of ``"options"`` is
      produced. Combinations that do not satisfy it pass through unchanged.
    * Other dict values are ignored.
    * ``"parameter_groups"`` (optional) is a list of dicts. The sweep is repeated
      once per group, with that group's entries added to or overriding the base
      parameters. Group values are normalised exactly like base values.

    *config* itself is not modified.

    Yields:
        A dict mapping each parameter name to a single value.
    """
    print("Generating parameter combinations...")

    def as_options(value):
        # Turn one parameter spec into the list of values it sweeps over.
        if isinstance(value, dict) and 'range' in value:
            value = expand_range(value)
        return value if isinstance(value, list) else [value]

    # Read rather than pop, so the caller's dict is left intact. A missing,
    # null or empty list all mean "one group with no extra parameters".
    parameter_groups = config.get('parameter_groups') or [{}]
    params = {k: v for k, v in config.items() if k != 'parameter_groups'}

    base_params = {
        k: as_options(v)
        for k, v in params.items()
        if not isinstance(v, dict) or 'range' in v
    }
    conditional_params = {
        k: v for k, v in params.items()
        if isinstance(v, dict) and 'conditions' in v
    }

    for group in parameter_groups:
        # Normalise group values too, so a scalar string is a single value
        # rather than a sequence of characters.
        group_params = {k: as_options(v) for k, v in group.items()}
        current_base_params = {**base_params, **group_params}
        names = list(current_base_params)
        for values in product(*(current_base_params[name] for name in names)):
            valid_combos = [dict(zip(names, values))]
            for cond_param, spec in conditional_params.items():
                new_combos = []
                for combo in valid_combos:
                    if check_conditions(spec['conditions'], combo):
                        options = spec['options'] if isinstance(spec['options'], list) else [spec['options']]
                        new_combos.extend({**combo, cond_param: option} for option in options)
                    else:
                        new_combos.append(combo)
                valid_combos = new_combos
            yield from valid_combos
def format_config_name(config, config_basename, prefix, add_names):
    """Build a run name: ``<prefix><config_basename>-<v1>-<v2>-...``.

    When *add_names* is true, each component is rendered as ``<key>_<value>``
    instead of just ``<value>``. Components follow the iteration order of
    *config*.
    """
    def render(key, value):
        return f"{key}_{value}" if add_names else f"{value}"

    joined = "-".join(render(key, value) for key, value in config.items())
    return f"{prefix}{config_basename}-{joined}"
def run_command(config, config_basename, output_dir, csv_ckpt_dir, prefix, add_names,
                best_val_loss_from, override_max_iters, override_dataset, override_block_size):
    """Launch ``python3 train.py`` for one parameter combination and wait for it.

    *config* is modified in place: ``tensorboard_run_name``, ``out_dir`` and any
    requested overrides are written into it before it is turned into flags.
    The run name is computed before overrides are applied.

    Args:
        config: Parameter dict for this run.
        config_basename: Stem of the JSON config file, used in the run name.
        prefix: Prepended to the run name.
        add_names: Include parameter names in the run name.
        output_dir: Parent directory of this run's ``out_dir``.
        csv_ckpt_dir: Forwarded as ``--csv_ckpt_dir`` when non-empty.
        best_val_loss_from: ``[csv_dir, output_dir]``. When both are non-empty,
            the run gets ``--init_from prev_run`` plus the checkpoint chosen by
            ``find_best_val_loss``.
        override_max_iters: Replaces ``max_iters`` when not None.
        override_dataset: Replaces ``dataset`` when non-empty.
        override_block_size: Replaces ``block_size`` when not None.

    Returns:
        The exit status of the ``train.py`` subprocess. A non-zero status is
        reported but not raised, so a sweep keeps going after a failed run.
    """
    formatted_name = format_config_name(config, config_basename, prefix, add_names)
    config['tensorboard_run_name'] = formatted_name
    timestamp_prefix = datetime.now().strftime('%Y%m%d_%H%M%S')
    config['out_dir'] = os.path.join(output_dir, f"{timestamp_prefix}_{formatted_name}")

    # Numeric overrides are compared against None so that an explicit 0 still applies.
    if override_max_iters is not None:
        config['max_iters'] = str(override_max_iters)
    if override_dataset:
        config['dataset'] = override_dataset
    if override_block_size is not None:
        config['block_size'] = str(override_block_size)

    # Print the entered arguments before each run.
    _print_config_table(config)

    command = ["python3", "train.py", "--timestamp", timestamp_prefix]
    command.extend(_config_to_flags(config))

    if best_val_loss_from[0] and best_val_loss_from[1]:
        command.extend(["--init_from", "prev_run"])
        ckpt_path = find_best_val_loss(best_val_loss_from[0], best_val_loss_from[1])
        command.extend(["--prev_run_ckpt", ckpt_path])

    if csv_ckpt_dir:
        command.extend(["--csv_ckpt_dir", csv_ckpt_dir])

    print(f"Running command: {' '.join(command)}")
    result = subprocess.run(command)
    if result.returncode != 0:
        print(f"train.py exited with status {result.returncode} for run {formatted_name}")
    return result.returncode


def _print_config_table(config):
    """Render *config* as a two-column rich table on the console."""
    console = Console()
    table = Table(title="Entered Arguments", show_header=True, header_style="bold magenta")
    table.add_column("Argument", style="cyan")
    table.add_column("Value", style="green")
    for key, value in config.items():
        table.add_row(key, str(value))
    console.print(table)


def _config_to_flags(config):
    """Translate a parameter dict into ``train.py`` command-line tokens.

    Booleans become ``--key`` or ``--no-key``. Lists repeat ``--key <item>``
    once per item. Everything else becomes ``--key <str(value)>``.
    """
    flags = []
    for key, value in config.items():
        if isinstance(value, bool):
            flags.append(f"--{'' if value else 'no-'}{key}")
        elif isinstance(value, list):
            for item in value:
                flags.extend([f"--{key}", str(item)])
        else:
            flags.extend([f"--{key}", str(value)])
    return flags
def main():
    """Entry point: expand every configuration in the JSON file and run each combination.

    The JSON file may contain either a list of configuration objects or a single
    configuration object.
    """
    args = parse_args()
    config_basename = os.path.splitext(os.path.basename(args.config))[0]
    with open(args.config, 'r', encoding='utf-8') as file:
        original_configurations = json.load(file)

    # A bare object would otherwise be iterated key by key.
    if isinstance(original_configurations, dict):
        original_configurations = [original_configurations]

    for config in original_configurations:
        for combination in generate_combinations(config):
            run_command(combination, config_basename, args.output_dir, args.csv_ckpt_dir,
                        args.prefix, args.add_names, args.use_best_val_loss_from,
                        args.override_max_iters, args.override_dataset, args.override_block_size)
# Run the sweep only when executed as a script, not when imported.
if __name__ == '__main__':
    main()