# test.py (forked from z65451/SVR)
import os
import torch
# from pycocotools import coco
import queue
import threading
from model_video import build_model, weights_init, MyEnsemble
from tools import custom_print
from train import train_finetune_with_flow, train_finetune
from val import validation
import time
import datetime
import collections
from torch.utils.data import DataLoader
import argparse
torch.autograd.set_detect_anomaly(True)
torch.backends.cudnn.benchmark = True
from model_video2_seam import build_model2
from test_trained import test_finetune
import cv2
import sys
sys.path.append('NonUniformBlurKernelEstimation/')
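# Fine-tuning / test entry point: builds the two UFO sub-networks, loads their
# checkpoints, wraps them in MyEnsemble, freezes the parameters, and then calls
# train_finetune (or train_finetune_with_flow when --use_flow is set).
#
# Example invocation (a sketch based on the argparse defaults below; adjust the
# checkpoint and dataset paths to your own layout):
#   python test.py --model ./models/image_best.pth \
#       --train_datapath_small datasets/15/small/testing/ \
#       --train_datapath_large datasets/15/large/testning/ \
#       --val_datapath datasets/15/small/training/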
if __name__ == '__main__':
    # train_val_config
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='./models/image_best.pth', help="checkpoint to restore")
    parser.add_argument('--use_flow', default=False, help="use the optical-flow variant (note: any non-empty CLI value is treated as True)")
    parser.add_argument('--img_size', default=224, type=int, help="size of input image")
    parser.add_argument('--lr', default=1e-3, type=float, help="learning rate")
    parser.add_argument('--lr_de', default=20000, type=int, help="learning rate decay")
    parser.add_argument('--batch_size', default=1, type=int, help="batch size")
    parser.add_argument('--group_size', default=4, type=int, help="group size")
    parser.add_argument('--epochs', default=10000000, type=int, help="number of epochs")
    # parser.add_argument('--train_datapath', default='DAVIS_FBMS_flow/DAVIS_FBMS_flow/', help="training dataset")
    parser.add_argument('--train_datapath_small', default='datasets/15/small/testing/', help="training dataset (small)")
    parser.add_argument('--train_datapath_large', default='datasets/15/large/testning/', help="training dataset (large)")
    parser.add_argument('--val_datapath', default='datasets/15/small/training/', help="validation dataset")
    args = parser.parse_args()
    train_datapath_small = args.train_datapath_small
    train_datapath_large = args.train_datapath_large
    val_datapath = [args.val_datapath]
    # project config
    project_name = 'UFO'
    device = torch.device('cuda:0')
    img_size = args.img_size
    lr = args.lr
    lr_de = args.lr_de
    epochs = args.epochs
    batch_size = args.batch_size
    group_size = args.group_size
    log_interval = 1
    val_interval = 1000
    use_flow = args.use_flow
    if use_flow:
        from model_video_flow import build_model, weights_init
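    # NOTE: when --use_flow is set, this import shadows the build_model and
    # weights_init imported from model_video at the top of the file, so net1
    # below is built from the flow-aware model instead.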
    # create log dir
    log_root = './logs'
    if not os.path.exists(log_root):
        os.makedirs(log_root)
    # create log txt
    log_txt_file = os.path.join(log_root, project_name + '_log.txt')
    custom_print(project_name, log_txt_file, 'w')
    # create model save dir
    models_root = './models'
    if not os.path.exists(models_root):
        os.makedirs(models_root)
    models_train_last = os.path.join(models_root, project_name + '_last_ft.pth')
    models_train_best = os.path.join(models_root, project_name + '_best_ft.pth')
    # continue: load checkpoint
    model_path = args.model
    gpu_id = 'cuda:0'
    device = torch.device(gpu_id)
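    # Build the two sub-networks of the ensemble: net1 comes from model_video
    # (or model_video_flow when --use_flow is set) and is frozen after its
    # checkpoint is loaded; net2 is the SEAM-based model from model_video2_seam.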
    net1 = build_model(device)  # .to(device)
    net2 = build_model2(device)  # .to(device)
    # netEnsemble = MyEnsemble(net, net2) #.to(device)
    cc = 0
    # for param in netEnsemble.parameters():
    #     param.requires_grad = True
    #     cc += 1
    # print(sum(p.numel() for p in netEnsemble.parameters() if p.requires_grad))
    # for p in net.sp1[0].parameters():
    #     p.requires_grad = False
    # for p in net.sp2[0].parameters():
    #     p.requires_grad = False
    # for p in net.cls[0].parameters():
    #     p.requires_grad = False
    # for p in net.cls_m[0].parameters():
    #     p.requires_grad = False
    net1 = net1.to(device)
    net1 = torch.nn.DataParallel(net1)
    state_dict = torch.load(model_path, map_location=gpu_id)
    net1.load_state_dict(state_dict)
    # net.train()
    net2 = net2.to(device)
    net2 = torch.nn.DataParallel(net2)
    print(sum(p.numel() for p in net2.parameters() if p.requires_grad))  # trainable params in net2
    for param in net2.parameters():
        param.requires_grad = True
    for param in net1.parameters():
        param.requires_grad = False
    print(sum(p.numel() for p in net1.parameters() if p.requires_grad))  # trainable params in net1 (0 after freezing)
    net = MyEnsemble(net1, net2)
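    # Load the checkpoint of the full ensemble and freeze every parameter; the
    # training call below therefore runs with a fully frozen network
    # (presumably for evaluation, given this file is test.py).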
    main_model_path = "./models/UFO_last_ftTrainedStereo.pth"
    # net = torch.nn.DataParallel(net)
    state_dict = torch.load(main_model_path, map_location=gpu_id)
    net.load_state_dict(state_dict)
    for param in net.parameters():
        param.requires_grad = False
    print(sum(p.numel() for p in net.parameters() if p.requires_grad))  # should print 0
    net.to(device)
    # net.eval()
    net.train()
    if not use_flow:
        train_finetune(net, train_datapath_small, train_datapath_large, device, batch_size, log_txt_file,
                       val_datapath, models_train_best, models_train_last, lr, lr_de, epochs,
                       log_interval, val_interval)
    else:
        # NOTE: train_datapath is undefined here because the --train_datapath argument
        # is commented out above; this branch raises a NameError unless it is restored.
        train_finetune_with_flow(net, train_datapath, device, batch_size, log_txt_file,
                                 val_datapath, models_train_best, models_train_last, lr, lr_de, epochs,
                                 log_interval, val_interval)