run.py
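# run.py — entry point for training and evaluating the autoencoder variants
# defined in the models module. Hyperparameters are read from the JSON file
# referenced by settings.PATH_PARAMS_JSON; data and checkpoint locations come
# from the settings module.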
import os
import json
import random
import time
import argparse

import numpy as np
import torch

from models import (
    Autoencoder,
    ConvolutionalAutoencoder,
    ConvolutionalVAE,
    DenoisingAutoencoder,
    SparseAutoencoder,
    VariationalAutoencoder,
    DenoisingConvolutionalAutoencoder,
    SparseConvolutionalAutoencoder
)
from settings import settings
from utils.dataloader import get_dataloader
from utils.trainer import train_autoencoder, visualize_reconstructions, load_checkpoint, evaluate_autoencoder
from utils import utils

def get_model_by_type(ae_type=None, input_dim=None, encoding_dim=None, device=None):
    """Build the requested autoencoder, or list all registered types when ae_type is None."""
    models = {
        'ae': lambda: Autoencoder(input_dim, encoding_dim),
        'dae': lambda: DenoisingAutoencoder(input_dim, encoding_dim),
        'sparse': lambda: SparseAutoencoder(input_dim, encoding_dim),
        'vae': lambda: VariationalAutoencoder(input_dim, encoding_dim),
        # The convolutional variants are constructed with their default arguments.
        'conv': ConvolutionalAutoencoder,
        'conv_dae': DenoisingConvolutionalAutoencoder,
        'conv_vae': ConvolutionalVAE,
        'conv_sparse': SparseConvolutionalAutoencoder,
    }
    if ae_type is None:
        return list(models.keys())
    if ae_type not in models:
        raise ValueError(f"Unknown AE type: {ae_type}")
    model = models[ae_type]()
    return model.to(device)
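# Illustrative usage: get_model_by_type() with no arguments returns the list of
# registered keys (e.g. 'ae', 'dae', 'vae', ...), which is what the --test routine
# below iterates over; get_model_by_type('vae', input_dim, encoding_dim, device)
# builds a VariationalAutoencoder and moves it to the given device.
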
def set_seed(seed):
    """Seed every RNG in use and make cuDNN deterministic for reproducible runs."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def load_params(path):
    """Load the training hyperparameters from a JSON file."""
    with open(path, "r", encoding='utf-8') as file:
        params = json.load(file)
    return params

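# Sketch of the expected params file (keys taken from main() below; the values
# shown are illustrative assumptions, not the repository's defaults):
#
# {
#     "ae_type": "conv",
#     "batch_size": 64,
#     "resolution": 64,
#     "encoding_dim": 128,
#     "learning_rate": 0.001,
#     "num_epochs": 20,
#     "save_checkpoint": true
# }
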
def main(load_trained_model, ae_type=None, num_epochs=5, test_mode=True):
    set_seed(1)

    params = load_params(settings.PATH_PARAMS_JSON)
    batch_size = params["batch_size"]
    resolution = params["resolution"]
    encoding_dim = params["encoding_dim"]
    learning_rate = params.get("learning_rate", 0.001)
    save_checkpoint = params["save_checkpoint"]

    # When no ae_type is passed explicitly, fall back to the configuration file
    # and run a full (non-test) session including validation.
    if not ae_type:
        ae_type = params["ae_type"]
        num_epochs = params["num_epochs"]
        test_mode = False

    # Calculate input_dim based on resolution (3 channels, square images).
    input_dim = 3 * resolution * resolution

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataloader = get_dataloader(settings.DATA_PATH, batch_size, resolution)
    model = get_model_by_type(ae_type, input_dim, encoding_dim, device)
    optimizer = torch.optim.Adam(model.parameters())

    try:
        if not load_trained_model:
            start_epoch = 1
            # Resume from an existing checkpoint if one is present.
            if os.path.exists(settings.PATH_SAVED_MODEL):
                model, optimizer, start_epoch = load_checkpoint(
                    model, optimizer, settings.PATH_SAVED_MODEL, device
                )
                print(f"Loaded checkpoint and continuing training from epoch {start_epoch}.")

            start_time = time.time()
            train_autoencoder(
                model,
                dataloader,
                num_epochs=num_epochs,
                learning_rate=learning_rate,
                device=device,
                start_epoch=start_epoch,
                optimizer=optimizer,
                save_checkpoint=save_checkpoint,
                ae_type=ae_type
            )
            elapsed_time = utils.format_time(time.time() - start_time)
            print(f"\nTraining took {elapsed_time}")
            print(f"Training complete up to epoch {num_epochs}!")
    except KeyboardInterrupt:
        print("\nTraining interrupted by user.")

    # Validation and visualization are skipped in test mode.
    if not test_mode:
        valid_dataloader = get_dataloader(settings.VALID_DATA_PATH, batch_size, resolution)
        avg_valid_loss = evaluate_autoencoder(model, valid_dataloader, device)
        print(f"\nAverage validation loss: {avg_valid_loss:.4f}\n")
        visualize_reconstructions(
            model, valid_dataloader, num_samples=10,
            device=device, resolution=resolution
        )

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Training and testing autoencoders.')
    parser.add_argument(
        '--test', action='store_true', help='Run the test routine for all autoencoders.'
    )
    args = parser.parse_args()

    if args.test:
        ae_types = get_model_by_type()
        for ae_type in ae_types:
            print(f"\n===== Training {ae_type} =====\n")
            main(load_trained_model=False, ae_type=ae_type)
    else:
        main(load_trained_model=False)
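# Usage:
#   python run.py          # train the autoencoder configured in the params file,
#                          # then report validation loss and show reconstructions
#   python run.py --test   # run a short training pass for every registered AE type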