-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
156 lines (116 loc) · 6.92 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import time
import datetime
import pytz
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
from model.net import BasicModel, DomainAdversarialNet
from model.dataloader import Dataloaders
from model.layers import grad_reverse
from evaluate import evaluate
from utils import *
class Trainer:
    """Train the sketch and image embedding networks with a domain-adversarial head.

    Wraps the project ``Dataloaders`` built for ``data_dir`` and exposes
    ``train_and_evaluate``, which runs the optimisation loop and writes one
    checkpoint per epoch.
    """

    def __init__(self, data_dir):
        """Create the dataloaders and keep references to their train/test dicts.

        Args:
            data_dir: Dataset location, forwarded unchanged to ``Dataloaders``.
        """
        self.dataloaders = Dataloaders(data_dir)
        self.train_dict = self.dataloaders.train_dict
        self.test_dict = self.dataloaders.test_dict

    @staticmethod
    def _grl_weight(epoch, threshold_epoch, warmup_epochs=5):
        """Return the gradient-reversal weight to use during ``epoch``.

        The weight is 0 for the first ``warmup_epochs`` epochs, then grows as
        ``epoch / threshold_epoch`` and is held at 1 from ``threshold_epoch``
        onwards.

        Args:
            epoch: Zero-based epoch index.
            threshold_epoch: Epoch at which the weight reaches 1.
            warmup_epochs: Number of initial epochs with the weight pinned to 0.

        Returns:
            A float in ``[0, 1]``.
        """
        if epoch < warmup_epochs:
            return 0.0
        if epoch < threshold_epoch:
            # Assign the ramp value directly. The previous in-place ``*=`` kept
            # multiplying the 0 left over from the warm-up, so the weight never
            # left 0 before the threshold epoch.
            return epoch / threshold_epoch
        return 1.0

    def train_and_evaluate(self, config, checkpoint=None):
        """Run the training loop and save a checkpoint after every epoch.

        Each iteration embeds sketches with one ``BasicModel`` and images with
        another, applies a triplet loss to the embeddings, and adds a binary
        domain-classification loss computed by ``DomainAdversarialNet`` on the
        gradient-reversed embeddings. All three networks share one Adam
        optimizer. NOTE(review): despite the method name, no evaluation step is
        run here.

        Args:
            config: Mapping with the keys ``batch_size``, ``lr``, ``epochs``,
                ``triplet_loss_ratio``, ``domain_loss_ratio``,
                ``grl_threshold_epoch``, ``print_every`` and ``checkpoint_dir``.
                Numeric entries may be strings (as produced by untyped argparse
                options); they are converted here.
            checkpoint: Optional value passed to ``load_checkpoint`` to restore
                the models and optimizer before training. Ignored when falsy.

        Raises:
            KeyError: If a required key is missing from ``config``.
            ValueError: If the training dataloader yields no batches.
        """
        batch_size = config['batch_size']
        # Convert up front so bad or missing settings fail before any model is built.
        learning_rate = float(config['lr'])
        num_epochs = int(config['epochs'])
        triplet_loss_ratio = float(config['triplet_loss_ratio'])
        domain_loss_ratio = float(config['domain_loss_ratio'])
        grl_threshold_epoch = float(config['grl_threshold_epoch'])
        print_every = int(config['print_every'])
        checkpoint_dir = config['checkpoint_dir']

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        train_dataloader = self.dataloaders.get_train_dataloader(batch_size=batch_size, shuffle=True)
        num_batches = len(train_dataloader)
        if num_batches == 0:
            # Without this the loop body never runs and the checkpoint code
            # below would fail on the undefined ``iteration``.
            raise ValueError('Training dataloader is empty; nothing to train on.')

        image_model = BasicModel().to(device)
        sketch_model = BasicModel().to(device)
        domain_net = DomainAdversarialNet().to(device)

        # One optimizer over every trainable parameter of the three networks.
        params = []
        for model in (image_model, sketch_model, domain_net):
            params.extend(param for param in model.parameters() if param.requires_grad)
        optimizer = torch.optim.Adam(params, lr=learning_rate)

        criterion = nn.TripletMarginLoss(margin=1.0, p=2)
        # BCELoss expects probabilities, so domain_net presumably ends in a
        # sigmoid -- TODO confirm against model/net.py.
        domain_criterion = nn.BCELoss()

        if checkpoint:
            load_checkpoint(checkpoint, image_model, sketch_model, domain_net, optimizer)

        print('Training...')

        for epoch in range(num_epochs):
            accumulated_triplet_loss = RunningAverage()
            accumulated_iteration_time = RunningAverage()
            accumulated_image_domain_loss = RunningAverage()
            accumulated_sketch_domain_loss = RunningAverage()

            epoch_start_time = time.time()

            image_model.train()
            sketch_model.train()
            domain_net.train()

            # The reversal weight depends only on the epoch, so compute it once here.
            grl_weight = self._grl_weight(epoch, grl_threshold_epoch)

            for iteration, batch in enumerate(train_dataloader):
                time_start = time.time()

                # GETTING THE DATA
                # Only the first three entries are used by this loop.
                anchors, positives, negatives, _label_embeddings, _positive_label_idxs, _negative_label_idxs = batch
                anchors = anchors.to(device)
                positives = positives.to(device)
                negatives = negatives.to(device)
                current_batch_size = anchors.shape[0]

                # MAIN NET INFERENCE AND LOSS
                pred_sketch_features = sketch_model(anchors)
                pred_positives_features = image_model(positives)
                pred_negatives_features = image_model(negatives)

                triplet_loss = triplet_loss_ratio * criterion(
                    pred_sketch_features, pred_positives_features, pred_negatives_features)
                # Log plain floats: accumulating the tensors themselves would keep
                # every iteration's autograd graph alive.
                triplet_loss_value = triplet_loss.item()
                accumulated_triplet_loss.update(triplet_loss_value, current_batch_size)

                # DOMAIN ADVERSARIAL TRAINING
                # Vanilla generator for now. Later: add randomness in the outputs
                # of the generator, or lower the label.

                # DEFINE TARGETS: images are labelled 1, sketches 0.
                image_domain_targets = torch.full((current_batch_size, 1), 1, dtype=torch.float, device=device)
                sketch_domain_targets = torch.full((current_batch_size, 1), 0, dtype=torch.float, device=device)

                # GET DOMAIN NET PREDICTIONS FOR INPUTS WITH G.R.L.
                domain_pred_p_images = domain_net(grad_reverse(pred_positives_features, grl_weight))
                domain_pred_n_images = domain_net(grad_reverse(pred_negatives_features, grl_weight))
                domain_pred_sketches = domain_net(grad_reverse(pred_sketch_features, grl_weight))

                # DOMAIN LOSS
                domain_loss_images = domain_loss_ratio * (
                    domain_criterion(domain_pred_p_images, image_domain_targets)
                    + domain_criterion(domain_pred_n_images, image_domain_targets))
                accumulated_image_domain_loss.update(domain_loss_images.item(), current_batch_size)

                domain_loss_sketches = domain_loss_ratio * domain_criterion(
                    domain_pred_sketches, sketch_domain_targets)
                accumulated_sketch_domain_loss.update(domain_loss_sketches.item(), current_batch_size)

                total_domain_loss = domain_loss_images + domain_loss_sketches

                # OPTIMIZATION W.R.T. BOTH LOSSES
                optimizer.zero_grad()
                total_loss = triplet_loss + total_domain_loss
                total_loss.backward()
                optimizer.step()

                # LOGGER
                time_end = time.time()
                accumulated_iteration_time.update(time_end - time_start)

                if iteration % print_every == 0:
                    eta_cur_epoch = str(datetime.timedelta(
                        seconds=int(accumulated_iteration_time() * (num_batches - iteration))))
                    print(datetime.datetime.now(pytz.timezone('Asia/Kolkata')).replace(microsecond=0), end=' ')
                    print('Epoch: %d [%d / %d] ; eta: %s' % (epoch, iteration, num_batches, eta_cur_epoch))
                    print('Average Triplet loss: %f(%f);' % (triplet_loss_value, accumulated_triplet_loss()))
                    print('Sketch domain loss: %f; Image Domain loss: %f' % (
                        accumulated_sketch_domain_loss(), accumulated_image_domain_loss()))

            # END OF EPOCH
            epoch_end_time = time.time()
            print('Epoch %d complete, time taken: %s' % (
                epoch, str(datetime.timedelta(seconds=int(epoch_end_time - epoch_start_time)))))
            torch.cuda.empty_cache()

            save_checkpoint({'iteration': iteration + epoch * num_batches,
                             'image_model': image_model.state_dict(),
                             'sketch_model': sketch_model.state_dict(),
                             # The network is bound to ``domain_net``; the old
                             # ``domain_model`` name did not exist in this scope.
                             'domain_model': domain_net.state_dict(),
                             'optim_dict': optimizer.state_dict()},
                            checkpoint_dir=checkpoint_dir)
            print('Saved epoch!')
            print('\n\n\n')
def build_arg_parser():
    """Build the command-line parser for the training script.

    Every numeric option declares a ``type`` so that values given on the
    command line arrive as numbers rather than strings, and ``--lr`` is
    provided because the trainer reads ``config['lr']``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description='Training of SBIR')
    parser.add_argument('--data_dir', help='Data directory path. Directory should contain two folders - sketches and photos, along with 2 .txt files for the labels', required=True)
    parser.add_argument('--batch_size', type=int, help='Batch size to process the train sketches/photos', required=True)
    parser.add_argument('--checkpoint_dir', help='Directory to save checkpoints', required=True)
    parser.add_argument('--epochs', type=int, help='Number of epochs', required=True)
    # NOTE(review): default chosen to match torch.optim.Adam's own default.
    parser.add_argument('--lr', type=float, help='Learning rate for the Adam optimizer', default=1e-3)
    parser.add_argument('--domain_loss_ratio', type=float, help='Domain loss weight', default=0.5)
    parser.add_argument('--triplet_loss_ratio', type=float, help='Triplet loss weight', default=1.0)
    parser.add_argument('--grl_threshold_epoch', type=int, help='Threshold epoch for GRL lambda', default=25)
    parser.add_argument('--print_every', type=int, help='Logging interval in iterations', default=10)
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()

    trainer = Trainer(args.data_dir)
    # The parsed namespace doubles as the trainer's config mapping.
    trainer.train_and_evaluate(vars(args))