"""
_summary_
"""

import torch
from matplotlib import pyplot as plt
import tqdm.auto as tqdm

from utils import load_data
from train import train_epoch
from models.r3d import R3D


if __name__ == "__main__":
    # Use the GPU when available, otherwise fall back to the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Dataset locations and input/output resolutions
    TRAIN_DATA_PATH = "data/shapenetcore/train_imgs"
    TEST_DATA_PATH = "data/shapenetcore/test_imgs"
    VOXEL_SIZE = 64           # target voxel grid resolution
    pixel_shape = (192, 256)  # input image resolution in pixels

    # Load the training and test splits onto the target device
    train_2d, train_3d = load_data(
        TRAIN_DATA_PATH,
        voxel_size=VOXEL_SIZE,
        pixel_shape=pixel_shape,
        device=device,
    )
    test_2d, test_3d = load_data(
        TEST_DATA_PATH,
        voxel_size=VOXEL_SIZE,
        pixel_shape=pixel_shape,
        device=device,
    )

    model = R3D().to(device)

    # Hyperparameters for both optimizers
    LR = 1e-3
    WD = 0.2
    betas = (0.9, 0.98)

    # Separate AdamW optimizers for the VAE and GAN parameter groups
    vae_optim = torch.optim.AdamW(
        model.vae_parameters(),
        lr=LR,
        weight_decay=WD,
        betas=betas,
    )
    gan_optim = torch.optim.AdamW(
        model.gan_parameters(),
        lr=LR,
        weight_decay=WD,
        betas=betas,
    )

    # Train for NUM_EPOCHS epochs; train_epoch returns the (VAE loss, GAN loss) pair for each epoch
    NUM_EPOCHS = 25
    train_loss = []
    for _ in tqdm.tqdm(range(NUM_EPOCHS)):
        train_loss.append(train_epoch(
            train_2d, train_3d, model, vae_optim, gan_optim
        ))
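
    # The test split is loaded above but never used in this script. A minimal
    # evaluation sketch is given below, assuming R3D can be called directly on
    # a batch of 2D images and returns voxel predictions shaped like test_3d;
    # that forward signature is an assumption rather than something the
    # original code confirms, so the sketch is left commented out.
    # model.eval()
    # with torch.no_grad():
    #     test_pred = model(test_2d)
    #     test_mse = torch.mean((test_pred - test_3d) ** 2)
    #     print(f"Test reconstruction MSE: {test_mse.item():.4f}")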

    # Plotting training losses
    plt.figure(figsize=(16, 6))
    plt.plot(range(len(train_loss)), [tloss[0] for tloss in train_loss], label='Training VAE loss')
    plt.plot(range(len(train_loss)), [tloss[1] for tloss in train_loss], label='Training GAN loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Model Performance')
    plt.legend()
    plt.show()
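
    # A small follow-up not in the original script: persist the trained
    # weights so the model can be reloaded later. The file name is an
    # arbitrary, illustrative choice.
    torch.save(model.state_dict(), "r3d_weights.pt")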