-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
141 lines (121 loc) · 5.51 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from config import *
from utils.DatasetHandler import DatasetHandler
from models.HybridNet import HybridNet
from qc.QiskitCircuit import QiskitCircuit
from utils.utils import *
from sklearn.metrics import confusion_matrix, classification_report
import torch.optim as optim
import torch.nn as nn
import torch
import os
# Suppressing warnings: install a catch-all "ignore" filter (any category,
# any message, any module) so no warning is shown for the rest of the run.
import warnings
warnings.simplefilter('ignore')
#=======================================================================================================================
# Choose the compute device: the first CUDA GPU when torch reports one as
# available, otherwise the CPU. `gpu` records which choice was made.
# NOTE(review): the network created in the next section is never moved to
# `device`, so this currently only affects what gets printed.
print('\n[%] Checking for the availability of GPUs')
gpu = torch.cuda.is_available()
device = torch.device('cuda:0' if gpu else 'cpu')
print('\t [*] Running on device: {}'.format(device))
#=======================================================================================================================
# Build the hybrid classical/quantum model and its optimizer.
print('\n[%] Initialize Quantum Hybrid Neural Network')
# The GPU and CPU cases construct the network identically, so no branch is
# needed here.
# NOTE(review): the model is never moved to `device`; presumably the
# Qiskit-backed layer expects CPU tensors. Confirm before adding `.to(device)`.
network = HybridNet()
# torch.optim.Adam has no `momentum` argument (it uses `betas` instead), so
# only the learning rate from config is passed.
optimizer = optim.Adam(network.parameters(), lr=LEARNING_RATE)
#=======================================================================================================================
# Build a QiskitCircuit from the config constants and print its text diagram.
# NOTE(review): Qiskit documents `scale` for the mpl/latex drawers; it likely
# has no effect on `output='text'`. Confirm.
print('\n[%] Printing Quantum Circuit')
circuit = QiskitCircuit(NUM_QUBITS, SIMULATOR, NUM_SHOTS)
circuit_diagram = circuit.circuit.draw(output='text', scale=1 / NUM_LAYERS)
print(circuit_diagram)
#=======================================================================================================================
# Echo the quantum-circuit settings (the NUM_* constants), one per line.
print('\n[%] Printing Quantum Circuit Parameters')
circuit_param_lines = (
    '\t [*] Number of Qubits: {}'.format(NUM_QUBITS),
    '\t [*] Number of R Layers: {}'.format(NUM_LAYERS),
    '\t [*] Number of Outputs: {}'.format(NUM_QC_OUTPUTS),
    '\t [*] Number of Shots: {}'.format(NUM_SHOTS),
)
print(*circuit_param_lines, sep='\n')
#=======================================================================================================================
# Load image paths/labels for the training and validation splits and wrap
# each split's data loader in an iterator consumed by the training loop.
print('\n[%] Loading Dataset')
handler_train = DatasetHandler(training_root)
handler_val = DatasetHandler(validation_root)
# Class names are the last path component of each entry in
# handler_train.classes, sorted so the label order is deterministic. The same
# list is passed to both splits (presumably so they share one label order).
# os.path.basename also splits on '/' under Windows, unlike os.path.sep.
classes = sorted(os.path.basename(class_dir) for class_dir in handler_train.classes)
print('\t [*] Training classes: ', classes)
train_imgs, train_labels = handler_train.load_paths_labels(training_root, classes=classes)
val_imgs, val_labels = handler_val.load_paths_labels(validation_root, classes=classes)
# NOTE(review): the training loop draws len(train_labels) / len(val_imgs)
# batches from these iterators in *every* epoch. qcnn_data_loader therefore
# presumably yields indefinitely. Confirm; otherwise the second epoch raises
# StopIteration.
train_loader = iter(handler_train.qcnn_data_loader(train_imgs, train_labels, batch_size=1, img_shape=(64, 64, 3)))
test_loader = iter(handler_val.qcnn_data_loader(val_imgs, val_labels, batch_size=1, img_shape=(64, 64, 3)))
print('\t [*] Training size: ', len(train_imgs))
print('\t [*] Validation size:', len(val_imgs))
#=======================================================================================================================
# Optionally restore model/optimizer state from MODEL_SAVE_PATH. Loading is
# best-effort: any failure is reported and training starts from scratch.
print('\n[%] Starting Training')
if LOAD_CHECKPOINT:
    print('\t[%] Loading Checkpoint')
    try:
        # map_location lets a checkpoint written on a CUDA machine be read on
        # a machine where `device` is the CPU.
        checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device)
        network.load_state_dict(checkpoint['model_state_dict'])
        # NOTE: if this optimizer load fails, the model weights restored on
        # the previous line are kept.
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']
    except FileNotFoundError:
        print('\t\t [!] Checkpoint not found, training from scratch')
    except Exception as err:
        # Corrupt or incompatible checkpoints are reported with their cause
        # instead of being mislabelled as missing. KeyboardInterrupt and
        # SystemExit are not caught.
        print('\t\t [!] Checkpoint could not be loaded ({}), training from scratch'.format(err))
    else:
        # NOTE(review): the training loop iterates range(EPOCHS) from 0, so
        # the restored epoch is reported here but not resumed from.
        print('\t\t [*] Checkpoint loaded, starting from:')
        print('\t\t\t - Epoch {}'.format(epoch))
        print('\t\t\t - Loss {}'.format(loss))
else:
    print('\t [!] Checkpoint not activated, training from scratch')
#=======================================================================================================================
# Train for EPOCHS epochs. Each epoch runs one optimisation pass over the
# training set and one no-grad pass over the validation set. Mean losses are
# recorded in train_loss_list / val_loss_list and a checkpoint is written to
# MODEL_SAVE_PATH.
if TRAINING:
    print('\t [*] Training ...')
    train_loss_list = []
    val_loss_list = []
    loss_func = nn.CrossEntropyLoss()
    # Fixed label set for the sklearn metrics. Targets are assumed to be class
    # indices 0..len(classes)-1 (TODO confirm against DatasetHandler). It keeps
    # the confusion matrix len(classes) x len(classes) and aligned with
    # target_names even when a class is missing from an epoch's validation
    # targets and predictions.
    metric_labels = list(range(len(classes)))
    n_train = len(train_labels)
    for epoch in range(EPOCHS):
        # ---- training pass ----
        # train() re-enables training behaviour after the previous epoch's
        # eval() (affects dropout/batch-norm layers, if HybridNet has any).
        network.train()
        total_loss = []
        for batch_idx in range(n_train):
            data, target = next(train_loader)
            optimizer.zero_grad()
            # Forward pass
            output = network(data)
            # Calculating loss
            loss = loss_func(output, target)
            # Backward pass
            loss.backward()
            # Optimize the weights
            optimizer.step()
            total_loss.append(loss.item())
            print('\r\t\t [*] [Epoch %d/%d] [Batch %d/%d] [Train Loss %f] ' % (epoch, EPOCHS, batch_idx, n_train - 1, loss.item()),
                  end='\t\t')
        # ---- validation pass ----
        network.eval()
        with torch.no_grad():
            val_loss = []
            targets = []
            predictions = []
            for _ in range(len(val_imgs)):
                data, target = next(test_loader)
                output = network(data)
                loss = loss_func(output, target)
                val_loss.append(loss.item())
                targets.append(target.item())
                # NOTE(review): predict() presumably runs a second forward pass
                # (including the quantum circuit). If it equals output.argmax(1),
                # reusing `output` would halve validation cost. Confirm.
                predictions.append(network.predict(data).item())
        train_loss_list.append(sum(total_loss) / len(total_loss))
        val_loss_list.append(sum(val_loss) / len(val_loss))
        print('[Val Loss %f] ' % (val_loss_list[-1]))
        # Detailed validation metrics on epochs 1, 4, 7, ...
        if epoch % 3 == 1:
            cf = confusion_matrix(targets, predictions, labels=metric_labels, normalize='true')
            cr = classification_report(targets, predictions, labels=metric_labels, target_names=classes, digits=4)
            print('\t\t [*] Confusion Matrix:')
            print_CF(cf, classes)
            print('\t\t [*] Classification Report:')
            print(cr)
        # Checkpoint at the end of every epoch. The stored loss is this epoch's
        # mean training loss.
        # NOTE(review): move under the `if` above if checkpoints were meant
        # only every third epoch.
        torch.save({
            'epoch': epoch,
            'model_state_dict': network.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss_list[-1],
        }, MODEL_SAVE_PATH)