This repository has been archived by the owner on Mar 19, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
133 lines (114 loc) · 4.91 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import argparse
import torch
from tqdm import tqdm
import os
import glob
import numpy as np
from torch.optim.lr_scheduler import StepLR
from torch_geometric.datasets import MoleculeNet
from torch_geometric.loader import DataLoader
from gnn_model import GCN
from utils import save_plot, filtered_result
# ---------------------------------------------------------------------------
# Command-line interface, dataset, model and optimizer setup (module level).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description="Graph Neural Networks for estimating water solubility of a molecule structure.")
# type= casts ensure CLI-supplied values arrive as numbers, not strings:
# previously `-lr 0.01` handed the string "0.01" straight to Adam.
parser.add_argument('-lr', '--learning_rate', type=float, default=4e-3)
parser.add_argument('-ep', '--epoch', type=int, default=2000)
parser.add_argument('-m', '--mode', default="train")
parser.add_argument('-g', '--num_graphs_per_batch', type=int, default=6)
args = parser.parse_args()

lr = args.learning_rate
total_epoch = args.epoch
MODE = args.mode.lower()
num_graphs_per_batch = args.num_graphs_per_batch

# ESOL regression dataset (water solubility targets in `data.y`).
data = MoleculeNet(root="./dataset/", name="ESOL")
num_features = data.num_features  # number of features per node (atom)

gcn_model = GCN(num_features)
# Mean squared error on the solubility regression target.
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(gcn_model.parameters(), lr)
# Multiply the learning rate by 0.05 every 250 epochs.
lr_scheduler = StepLR(optimizer, step_size=250, gamma=0.05)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = gcn_model.to(device)

# 80 / 10 / 10 train / validation / test split.
data_size = len(data)
train_loader = DataLoader(data[:int(data_size * 0.8)],
                          batch_size=num_graphs_per_batch, shuffle=True)
valid_loader = DataLoader(data[int(data_size * 0.8):int(data_size * 0.9)],
                          batch_size=num_graphs_per_batch, shuffle=True)
test_loader = DataLoader(data[int(data_size * 0.9):],
                         batch_size=num_graphs_per_batch, shuffle=True)
def run_training():
    """Train the model for one full epoch over ``train_loader``.

    Returns:
        tuple: (epoch-mean training loss as a 0-dim tensor,
                node embedding of the last processed batch, or None if
                the loader was empty).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    embedding = None
    for batch in train_loader:
        # Move the batch's tensors to the training device.
        batch.to(device)
        # Reset gradients accumulated from the previous step.
        optimizer.zero_grad()
        # Forward pass: node features, connectivity, and graph membership.
        pred, embedding = model(batch.x.float(), batch.edge_index, batch.batch)
        loss = loss_fn(pred, batch.y)
        loss.backward()
        optimizer.step()
        # Detach so accumulation does not keep the autograd graph alive.
        total_loss += loss.detach()
        num_batches += 1
    # Previously only the *last* batch's loss was returned, making the
    # reported training loss noisy; report the epoch mean instead.
    return total_loss / max(num_batches, 1), embedding
def run_validation():
    """Evaluate the model on ``valid_loader`` without updating weights.

    Returns:
        tuple: (epoch-mean validation loss as a 0-dim tensor,
                node embedding of the last processed batch, or None if
                the loader was empty).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    embedding = None
    # No gradients are needed for validation (the old comment said
    # "Reset gradients", which was wrong).
    with torch.no_grad():
        for batch in valid_loader:
            batch.to(device)
            pred, embedding = model(batch.x.float(), batch.edge_index, batch.batch)
            total_loss += loss_fn(pred, batch.y)
            num_batches += 1
    # Report the epoch mean rather than only the last batch's loss.
    return total_loss / max(num_batches, 1), embedding
def run_testing():
    """Collect ground-truth and predicted values over ``test_loader``.

    Returns:
        tuple: (list of true target values, list of predicted values)
    """
    model.eval()
    targets, predictions = [], []
    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        for graphs in test_loader:
            graphs.to(device)
            targets.extend(graphs.y.tolist())
            out, _ = model(graphs.x.float(), graphs.edge_index, graphs.batch)
            predictions.extend(out.detach().tolist())
    return targets, predictions
if MODE == "train":
    # Training loop: one optimizer epoch + one validation pass per epoch,
    # checkpointing every 100 epochs and plotting the loss curves at the end.
    print("Starting training...")
    train_losses, valid_losses = [], []
    # Ensure the checkpoint directory exists before deleting/saving in it.
    os.makedirs("./weights", exist_ok=True)
    for epoch in tqdm(range(1, total_epoch + 1), desc="Training Epoch"):
        train_loss, h = run_training()
        lr_scheduler.step()
        valid_loss, valid_h = run_validation()
        # .cpu() is required before .numpy() when training on CUDA;
        # calling .numpy() on a CUDA tensor raises.
        train_loss = train_loss.detach().cpu().numpy()
        valid_loss = valid_loss.detach().cpu().numpy()
        train_losses.append(np.float32(train_loss))
        valid_losses.append(np.float32(valid_loss))
        if epoch % 100 == 0:
            # Keep a single checkpoint: delete any old weight files
            # portably (os.system("rm ...") fails on Windows), then save.
            for old_checkpoint in glob.glob("./weights/*.pt"):
                os.remove(old_checkpoint)
            torch.save(model.state_dict(), "./weights/" + str(epoch) + ".pt")
            print("Weight saved at epoch: ", epoch)
            print(f"Epoch {epoch} | Train Loss {train_loss} | Valid Loss {valid_loss}")
    # Down-sample the loss curves before plotting.
    train_losses, collect_at_each = filtered_result(train_losses)
    valid_losses, collect_at_each = filtered_result(valid_losses)
    save_plot(train_loss_list=train_losses, valid_loss_list=valid_losses, filter_bucket=collect_at_each)
if MODE in ["test", "train"]:
    # Evaluation: load the single saved checkpoint and plot predictions
    # against ground truth on the held-out test split.
    weight_filename_list = glob.glob("./weights/*.pt")
    # Exactly one checkpoint is expected (training keeps only the latest).
    if len(weight_filename_list) != 1:
        print("Error: Weight file not present inside ./weights/")
        # Non-zero exit status so callers/scripts can detect the failure
        # (the old bare exit() reported success).
        raise SystemExit(1)
    weight_filename = weight_filename_list[0]
    print("Loading weight file for testing: ", weight_filename)
    # map_location lets CUDA-trained weights load on a CPU-only machine.
    model.load_state_dict(torch.load(weight_filename, map_location=device))
    y_real_list, y_pred_list = run_testing()
    save_plot(test_loss_list=[y_real_list, y_pred_list])
print("Code Executed Successfully.")