# train_selfsupervised3D.py
from vq_gan_3d.model.vqgan import Encoder, SamePadConv3d
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from dataset.mrnet_contrastive_pairs import MRNetDatasetContrastivePairs
import torch.nn.functional as F
from scipy.spatial.distance import cdist
#from torchmetrics.functional import pairwise_cosine_similarity
from torchmetrics.utilities.compute import _safe_matmul
import argparse
import os
import json
from monai.data import DataLoader, Dataset
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from tqdm import tqdm
from pytorch_metric_learning.losses import NTXentLoss
from torch.utils.tensorboard import SummaryWriter
# Arguments
parser = argparse.ArgumentParser()
parser.add_argument('--exp', default='exp_3D_1/')  # experiment name
parser.add_argument('--epochs', type=int, default=200)  # training epochs
parser.add_argument('--batch_size', type=int, default=10)  # batch size
parser.add_argument('--save_model_interval', type=int, default=10)  # checkpoint saving interval
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--ckpt_folder', default="ckpt/")
parser.add_argument('--exp_details', default="First trial")  # further experiment details, saved to a text file
parser.add_argument('--data_dir', default="data/")  # directory where the data is located
parser.add_argument('--conv_only', action='store_true')  # if set, skip the dense head and extract only conv features
parser.add_argument('--dataset', type=str, default='mrnet', help='which dataset')  # which dataset; a custom one can be defined
args = parser.parse_args()
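# Example invocation (a minimal sketch using the defaults above; --data_dir must point at a
# directory with the layout expected by MRNetDatasetContrastivePairs):
#   python train_selfsupervised3D.py --data_dir data/ --exp exp_3D_1/ --epochs 200 --batch_size 10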
epochs = args.epochs
batch_size = args.batch_size
exp = args.exp
data = args.dataset
lr = args.lr
ckpt_folder = args.ckpt_folder
exp_details = args.exp_details
data_dir = args.data_dir
conv_only = args.conv_only
save_model_interval = args.save_model_interval
if not os.path.exists(ckpt_folder+exp):
    # Create the experiment directory if it does not exist
    os.makedirs(ckpt_folder+exp)
# Write the experiment arguments to a text file
with open(ckpt_folder+exp+'/arguments.txt', 'w') as f:
    json.dump(args.__dict__, f, indent=2)
writer = SummaryWriter(log_dir= ckpt_folder+exp)
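# Train/validation loss curves are logged to this directory and can be inspected with, e.g.,
# `tensorboard --logdir ckpt/exp_3D_1/` (using the default --ckpt_folder and --exp values).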
#load dataset
if data == 'mrnet':
    dataset = MRNetDatasetContrastivePairs(data_dir, task='acl', plane='sagittal', split='train')
    dataset_val = MRNetDatasetContrastivePairs(data_dir, task='acl', plane='sagittal', split='valid')
#Data loaders
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, persistent_workers=False)
val_loader = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=0, persistent_workers=False)
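# Each batch is expected to be a dict holding two augmented views of the same volume under the
# keys 'data1' and 'data2' (these keys are what the training loop below consumes); the exact
# augmentations are defined inside MRNetDatasetContrastivePairs.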
class Model(nn.Module):
    def __init__(self, conv_only=False, data='mrnet'):
        super().__init__()
        self.conv_only = conv_only  # if True, skip the dense projection head
        self.data = data
        if self.data == 'pccta':
            self.downsample = [16, 16, 16]  # downsampling factor per dimension
        elif self.data == 'mrnet':
            self.downsample = [8, 64, 64]  # MRNet volumes are 32x256x256; dividing by (8, 64, 64) yields a 4x4x4 feature map
        self.n_hiddens = 16  # number of hidden units
        self.image_channel = 1
        self.embedding_dim = 8
        self.encoder = Encoder(downsample=self.downsample, n_hiddens=self.n_hiddens, image_channel=self.image_channel)
        self.enc_out_ch = self.encoder.out_channels
        self.conv3D = SamePadConv3d(self.enc_out_ch, self.embedding_dim, 1)
        self.dense = nn.Sequential(nn.Flatten(),
                                   nn.Linear(self.embedding_dim * 4**3, 128),
                                   nn.ReLU(inplace=True),
                                   nn.Dropout(0.5),
                                   nn.Linear(128, 32))
        self.sup_layer = nn.Sequential(nn.ReLU(inplace=True),
                                       nn.Linear(32, 1))

    def forward(self, x):
        x = self.encoder(x)
        h = self.conv3D(x)
        if not self.conv_only:
            h = self.dense(h)
        fin = None  # placeholder for an optional supervised prediction, unused in this script
        return h, fin
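# Shape sketch (assuming MRNet volumes of shape 1x32x256x256 and the downsampling factors above):
# the encoder reduces the volume to a 4x4x4 grid, conv3D projects it to embedding_dim=8 channels,
# so h is (B, 8, 4, 4, 4); with the dense head enabled this is flattened (8 * 4**3 = 512 features)
# and projected to a (B, 32) embedding used for the contrastive loss.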
model = Model(conv_only=conv_only, data=data)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
#Training
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min = lr/2.0)
LossTr = NTXentLoss()
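# NTXentLoss from pytorch_metric_learning treats embeddings that share a label as positives.
# Concatenating the two views and repeating the labels [0, ..., B-1] therefore makes view i of
# volume i a positive pair with view i+B, while all other samples in the batch act as negatives.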
LossStats = np.zeros((epochs, int(np.ceil(len(dataset) / batch_size))))
LossVal = np.zeros((epochs, int(np.ceil(len(dataset_val) / batch_size))))
for Epoch in range(epochs):
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=110)
    progress_bar.set_description(f"Epoch {Epoch}")
    model.train()
    for BatchNo, Batch in progress_bar:
        optimizer.zero_grad()
        Pos11 = Batch['data1']; Pos12 = Batch['data2']
        PosEmb11, _ = model(Pos11.to(device)); PosEmb12, Pred = model(Pos12.to(device))
        Labels = torch.arange(PosEmb11.shape[0])
        LossPos1 = LossTr(torch.cat((PosEmb11, PosEmb12), dim=0), torch.cat((Labels, Labels), dim=0))
        LTotal = LossPos1
        LTotal.backward()
        LossStats[Epoch, BatchNo] = LossPos1.item()
        optimizer.step()
    scheduler.step()  # anneal the learning rate once per epoch (T_max = epochs)

    # Validation loss on held-out contrastive pairs
    model.eval()
    for val_step, Batch in enumerate(val_loader):
        Pos11 = Batch['data1']; Pos12 = Batch['data2']
        with torch.no_grad():
            # predictions
            PosEmb11, _ = model(Pos11.to(device)); PosEmb12, _ = model(Pos12.to(device))
            # loss
            Labels = torch.arange(PosEmb11.shape[0])
            LossVal[Epoch, val_step] = LossTr(torch.cat((PosEmb11, PosEmb12), dim=0), torch.cat((Labels, Labels), dim=0)).item()
    torch.cuda.empty_cache()

    if (Epoch + 1) % save_model_interval == 0:
        torch.save(model.state_dict(), ckpt_folder+exp+'/model_' + str(Epoch))

    writer.add_scalar('Train/Loss', LossStats[Epoch, :].mean(), Epoch)
    writer.add_scalar('Val/Loss', LossVal[Epoch, :].mean(), Epoch)
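# Reloading a saved checkpoint for downstream feature extraction (a minimal sketch; the epoch
# index in the filename and the `volume` tensor are illustrative and must match the training run):
#   model = Model(conv_only=conv_only, data=data)
#   model.load_state_dict(torch.load(ckpt_folder+exp+'/model_199', map_location=device))
#   model.eval()
#   with torch.no_grad():
#       features, _ = model(volume.to(device))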