-
Notifications
You must be signed in to change notification settings - Fork 3
/
sweep.py
37 lines (26 loc) · 853 Bytes
/
sweep.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
"""
Create a Weights & Biases sweep from the .yaml config file given as the first command-line argument,
then start a sweep agent that executes the number of runs given as the second argument.

Usage: python sweep.py <sweep_config.yaml> <num_runs>
"""
# STD
import os
import yaml
import sys
# EXT
import wandb
# PROJECT
from src.constants import PROJECT_NAME
# Load W&B credentials: prefer a local, untracked `secret` module; otherwise
# fall back to environment variables.
try:
    from secret import WANDB_API_KEY, WANDB_USER_NAME
except ModuleNotFoundError:
    # Previously this read WANDB_API_KEY, so the *API key* ended up stored as
    # the user name (and a bare KeyError was raised when it was not exported).
    # wandb reads WANDB_API_KEY from the environment on its own, so only the
    # user name needs to be picked up here; it is None when not set.
    WANDB_USER_NAME = os.environ.get("WANDB_USER_NAME")
else:
    # Export the key so that wandb (and any subprocesses it launches) see it.
    os.environ["WANDB_API_KEY"] = WANDB_API_KEY
if __name__ == "__main__":
    # Validate the command line before touching W&B, so a bad invocation
    # does not create a stray run.
    if len(sys.argv) != 3:
        sys.exit("Usage: python sweep.py <sweep_config.yaml> <num_runs>")
    config_yaml = sys.argv[1]  # path to the sweep .yaml config
    try:
        num_runs = int(sys.argv[2])  # number of runs the agent should execute
    except ValueError:
        sys.exit("num_runs must be an integer, got {!r}".format(sys.argv[2]))

    # Pass the project by keyword: wandb.init's first positional parameter is
    # not `project`, so a positional PROJECT_NAME lands in the wrong setting.
    wandb.init(project=PROJECT_NAME)

    # safe_load is enough for a plain sweep config and refuses arbitrary
    # Python object tags.
    with open(config_yaml, encoding="utf-8") as file:
        config_dict = yaml.safe_load(file)
    if not isinstance(config_dict, dict):
        sys.exit("Sweep config {!r} must contain a YAML mapping".format(config_yaml))

    sweep_id = wandb.sweep(config_dict, project=PROJECT_NAME)
    # Point the agent at the same project the sweep was created in; a bare
    # sweep ID is otherwise resolved against the default project.
    wandb.agent(sweep_id, project=PROJECT_NAME, count=num_runs)