trainning_salformer.py
import random
random.seed(42)
import numpy as np
from env import *
import argparse
import torch
from torchvision.utils import save_image
from transformers import SwinModel
from dataset_new import ImagesWithSaliency
from torch.utils.data import DataLoader
from utils import inference
from torch.utils.tensorboard import SummaryWriter
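
# Training script for SalFormer: a Swin-Tiny vision encoder combined with a
# language model (LLaMA-2, BLOOM, or BERT) that predicts saliency maps for
# image/question pairs. The objective is a weighted combination of KL, CC and
# NSS terms; training and validation metrics are logged to TensorBoard.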
def trainning_salformer(Model, device, batch_size, KL, CC, NSS, LR):
    writer = SummaryWriter(comment=f"Model_{Model}_KL_{KL}_CC_{CC}_NSS_{NSS}")
    number_epoch = 250
    eps = 1e-8

    if Model == 'llama':
        from model_llama import SalFormer
        from transformers import LlamaModel
        from tokenizer_llama import padding_fn
        # llm = LlamaModel.from_pretrained("Enoch/llama-7b-hf", low_cpu_mem_usage=True)
        llm = LlamaModel.from_pretrained("daryl149/Llama-2-7b-chat-hf", low_cpu_mem_usage=True)
        neuron_n = 4096
        print("llama loaded")
    elif Model == 'bloom':
        from model_llama import SalFormer
        from transformers import BloomModel
        from tokenizer_bloom import padding_fn
        llm = BloomModel.from_pretrained("bigscience/bloom-3b")
        neuron_n = 2560
        print('BloomModel loaded')
    elif Model == 'bert':
        from model_swin import SalFormer
        from transformers import BertModel
        from tokenizer_bert import padding_fn
        llm = BertModel.from_pretrained("bert-base-uncased")
        print('BertModel loaded')
    else:
        print('model not available, possibilities: llama, bloom, bert')
        return
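
    # LLaMA and BLOOM are used as frozen text encoders; only the BERT variant
    # keeps its language-model weights trainable.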
    if Model != 'bert':
        for param in llm.parameters():
            param.requires_grad = False
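
    # Each dataset sample provides an image, tokenized question ids, a fixation
    # map and a ground-truth heat map (plus the sample name). padding_fn is the
    # tokenizer-specific collate function, which is expected to pad the question
    # tokens of a batch to a common length.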
    train_set = ImagesWithSaliency("data/train.npy", dtype=torch.float32)
    val_set = ImagesWithSaliency("data/val.npy", dtype=torch.float32)
    train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=padding_fn)
    vali_dataloader = DataLoader(val_set, batch_size=batch_size, shuffle=True, collate_fn=padding_fn)

    vit = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
    print('SwinModel loaded')

    if Model == 'bert':
        model = SalFormer(vit, llm).to(device)
    else:
        model = SalFormer(vit, llm, neuron_n=neuron_n).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.0001)
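
    # The loss is KL*kl - CC*cc - NSS*nss: the KL divergence is minimized while
    # the correlation coefficient (CC) and normalized scanpath saliency (NSS)
    # are maximized, hence their negative weights.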
    n_iter = 0
    cc_init = 0
    for epoch in range(number_epoch):
        for batch, (img, input_ids, fix, hm, name) in enumerate(train_dataloader):
            optimizer.zero_grad()
            y, kl, cc, nss = inference(model, device, eps, img, input_ids, fix, hm)
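
            # If any loss term comes back NaN, log the mean parameter norm for
            # debugging and drop that term from this step's loss.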
            if torch.isnan(kl):
                print(np.mean([p.norm().cpu().detach().numpy() for p in model.parameters()]))
                print(kl)
                kl = torch.Tensor([0.0]).to(device)
                print("kl is nan!")
            if torch.isnan(cc):
                print(np.mean([p.norm().cpu().detach().numpy() for p in model.parameters()]))
                print(cc)
                cc = torch.Tensor([0.0]).to(device)
                print("cc is nan!")
            if torch.isnan(nss):
                print(np.mean([p.norm().cpu().detach().numpy() for p in model.parameters()]))
                print(nss)
                nss = torch.Tensor([0.0]).to(device)
                print("nss is nan!")

            loss = KL*kl - CC*cc - NSS*nss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # if batch == len(train_dataloader) - 2:
            #     for i in random.sample(range(0, y.shape[0]), 1):
            #         save_image(y[i].type(torch.float32), f'./results_llm/train/epoch{epoch}_batch{batch}_{i}.png')
            #         save_image(hm[i].type(torch.float32), f'./results_llm/train/epoch{epoch}_batch{batch}_{i}_truth.png')

            writer.add_scalar('Loss/train', loss.item(), n_iter)
            writer.add_scalar('Loss/train/kl', kl.item(), n_iter)
            writer.add_scalar('Loss/train/cc', cc.item(), n_iter)
            writer.add_scalar('Loss/train/nss', nss.item(), n_iter)
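
            # On the last training batch of each epoch, run a full pass over the
            # validation set with gradients disabled.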
            if batch == len(train_dataloader)-1:
                print(f"Epoch {epoch}/{number_epoch} batch {batch}: ")
                print("Training: loss ", loss.item(), "kl ", kl.item(), "cc ", cc.item(), "nss ", nss.item())

                model.eval()
                test_loss = 0
                test_kl, test_cc, test_nss = 0, 0, 0
                for batch, (img, input_ids, fix, hm, name) in enumerate(vali_dataloader):
                    with torch.no_grad():
                        y, kl, cc, nss = inference(model, device, eps, img, input_ids, fix, hm)
                        loss = KL*kl - CC*cc - NSS*nss
                        test_loss += loss.item()/len(vali_dataloader)
                        # if y.shape[0] == batch_size:
                        #     for i in random.sample(range(0, y.shape[0]), 3):
                        #         save_image(y[i].type(torch.float32), f'./results_llm/val/epoch{epoch}_batch{batch}_{i}.png')
                        #         save_image(hm[i].type(torch.float32), f'./results_llm/val/epoch{epoch}_batch{batch}_{i}_truth.png')
                        test_kl += kl.item()/len(vali_dataloader)
                        test_cc += cc.item()/len(vali_dataloader)
                        test_nss += nss.item()/len(vali_dataloader)
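
                # After a 50-epoch warm-up, keep the checkpoint with the best
                # validation CC seen so far.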
                if epoch > 50 and test_cc > cc_init:
                    cc_init = test_cc
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': loss
                    }, f'./ckpt/model_{Model}_{KL}kl_{CC}cc_{NSS}nss.tar')
                model.train(True)

                print("Testing: loss ", test_loss, "kl ", test_kl, "cc ", test_cc, "nss ", test_nss)
                writer.add_scalar('Loss/test', test_loss, epoch)
                writer.add_scalar('Loss/test/kl', test_kl, epoch)
                writer.add_scalar('Loss/test/cc', test_cc, epoch)
                writer.add_scalar('Loss/test/nss', test_nss, epoch)

            n_iter += 1
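

# Example invocation (the argument values shown are the script defaults;
# adjust model, loss weights and learning rate as needed):
#   python trainning_salformer.py --model llama --device cuda --batch_size 16 --kl 10 --cc 5 --nss 2 --lr 0.00006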
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default='llama')
    parser.add_argument("--device", type=str, default='cuda')
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--kl", type=int, default=10)
    parser.add_argument("--cc", type=int, default=5)
    parser.add_argument("--nss", type=int, default=2)
    parser.add_argument("--lr", type=float, default=0.00006)
    args = vars(parser.parse_args())

    trainning_salformer(Model=args['model'], device=args['device'], batch_size=args['batch_size'],
                        KL=args['kl'], CC=args['cc'], NSS=args['nss'], LR=args['lr'])