run_experiment.py
# general
import os
import wandb
import ml_collections
import sys
import copy
# torch
import numpy as np
import torch
# project
from path_handler import model_path
from model import get_model
import dataset
import trainer
import tester
# args
from absl import app
from absl import flags
from ml_collections.config_flags import config_flags
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", default="config.py")
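# Note: DEFINE_config_file registers a --config flag pointing at a Python config file
# (config.py by default); individual fields can then be overridden on the command line
# as --config.<field>=<value>, which is standard ml_collections.config_flags behaviour.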
def main(_):
if "absl.logging" in sys.modules:
import absl.logging
absl.logging.set_verbosity("info")
absl.logging.set_stderrthreshold("info")
config = FLAGS.config
print(config)
# Set the seed
torch.manual_seed(config.seed)
np.random.seed(config.seed)
    # Initialize Weights & Biases (wandb)
    if not config.train:
        os.environ["WANDB_MODE"] = "dryrun"
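        # "dryrun" keeps wandb from syncing this run to the server when the model is only evaluated.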
    tags = [
        config.model,
        config.dataset,
        config.kernelnet_activation_function,
        "seq_length={}".format(config.seq_length),
    ]
    if config.dataset == "MNIST":
        tags.append(str(config.permuted))
    wandb.init(
        project="ckconv",
        config=copy.deepcopy(dict(config)),
        group=config.dataset,
        entity="vu_uva_team",
        tags=tags,
        # save_code=True,
        # job_type=config.function,
    )
    # Define the device to be used and move model to that device
    config["device"] = (
        "cuda:0" if (config.device == "cuda" and torch.cuda.is_available()) else "cpu"
    )
    # Define transforms and create dataloaders
    dataloaders, test_loader = dataset.get_dataset(config, num_workers=4)

    # Define model
    model = get_model(config)
    # WandB: wandb.watch() automatically fetches all layer dimensions, gradients and model
    # parameters, and logs them to your dashboard. With log="all", histograms of parameter
    # values are logged in addition to gradients.
    # wandb.watch(model, log="all", log_freq=200)  # disabled: a wandb bug made runs in Sweeps crash
    # Create model directory and instantiate config.path
    model_path(config)

    if config.pretrained:
        # Load model state dict
        model.module.load_state_dict(torch.load(config.path), strict=False)
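        # strict=False lets the checkpoint load even if its keys do not exactly match
        # the current model's state dict.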
    # Train the model
    if config.train:
        # Print arguments (Sanity check)
        print(config)

        # Train the model
        import datetime

        print(datetime.datetime.now())
        trainer.train(model, dataloaders, config, test_loader)

    # Select test function
    tester.test(model, test_loader, config)
if __name__ == "__main__":
app.run(main)
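# Example invocation (a sketch: it assumes a config.py next to this script defines the fields
# referenced above, e.g. seed, dataset, model, train, device; the values shown are placeholders):
#
#   python run_experiment.py --config=config.py --config.train=True --config.device=cuda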