-
Notifications
You must be signed in to change notification settings - Fork 0
/
submission.py
72 lines (58 loc) · 2.73 KB
/
submission.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from __future__ import print_function, division

import argparse
import os
import time

from datasets import __datasets__
from models import __models__
from utils import *

# Third-party imports stay after the ``utils`` star import so these explicit
# bindings take precedence over any same-named objects ``utils`` re-exports.
# numpy / torch / nn were previously reachable only through that star import.
import numpy as np
import skimage
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from skimage import io
from torch.utils.data import DataLoader
# Module-level setup: parse CLI options, build the test loader, build the
# network and load its checkpoint. All of this runs at import time.

# Let cuDNN benchmark convolution algorithms and reuse the fastest one.
cudnn.benchmark = True
# NOTE(review): the path defaults below are machine-specific absolute paths;
# pass --datapath / --testlist / --loadckpt explicitly on any other machine.
parser = argparse.ArgumentParser(description='CGFNet')
parser.add_argument('--model', default='cgfnet', help='select a model structure', choices=__models__.keys())
parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity')
parser.add_argument('--dataset', default='kitti', help='dataset name', choices=__datasets__.keys())
parser.add_argument('--datapath', default='/home/wangqingyu/KITTI/2012', help='data path')
parser.add_argument('--testlist', default='/media/wangqingyu/机械硬盘1/##model/CGFNet/filenames/kitti12_test.txt', help='testing list')
parser.add_argument('--loadckpt', default='/media/wangqingyu/机械硬盘1/##model/CGFNet/0.5:1/sf-1215-12.ckpt', help='load the weights from a specific checkpoint')
# parse arguments
args = parser.parse_args()
# dataset, dataloader
# The third positional argument (False) is presumably the dataset's
# "training" flag, i.e. test mode -- confirm against the dataset class.
StereoDataset = __datasets__[args.dataset]
test_dataset = StereoDataset(args.datapath, args.testlist, False)
# Batch size 1, fixed order, and drop_last=False so every listed image is
# processed.
TestImgLoader = DataLoader(test_dataset, 1, shuffle=False, num_workers=4, drop_last=False)
# model (no optimizer is needed for inference)
# The two positional False flags are model-constructor options whose meaning
# is not visible here -- see the model definition.
model = __models__[args.model](args.maxdisp, False, False)
# Wrapped in DataParallel before loading, presumably because the checkpoint
# was saved from a DataParallel model ('module.'-prefixed keys) -- verify.
model = nn.DataParallel(model)
model.cuda()
# load parameters
print("loading model {}".format(args.loadckpt))
# NOTE: torch.load unpickles the file, so only load trusted checkpoints.
# Without map_location, tensors are restored to the device they were saved
# from, which can fail on a machine with a different GPU layout.
state_dict = torch.load(args.loadckpt)
# The checkpoint is a dict; its 'model' entry holds the network weights.
model.load_state_dict(state_dict['model'])
def test(save_dir='submission1'):
    """Run the network over the whole test loader and save disparity PNGs.

    For each batch from ``TestImgLoader`` the predicted disparity map is
    cropped back by removing the first ``top_pad`` rows and the last
    ``right_pad`` columns (both read from the batch). It is then scaled by
    256, rounded and written as a uint16 image named after the left input
    file.

    Args:
        save_dir: Output directory, created if missing. Defaults to
            ``'submission1'``, the directory previously hard-coded here.
    """
    os.makedirs(save_dir, exist_ok=True)
    for batch_idx, sample in enumerate(TestImgLoader):
        start_time = time.time()
        disp_est_np = tensor2numpy(test_sample(sample))
        top_pad_np = tensor2numpy(sample["top_pad"])
        right_pad_np = tensor2numpy(sample["right_pad"])
        left_filenames = sample["left_filename"]
        # '.3f' means 3 decimal places; the former '{:3f}' only set a minimum
        # field width of 3 and printed 6 decimals.
        print('Iter {}/{}, time = {:.3f}'.format(batch_idx, len(TestImgLoader),
                                                time.time() - start_time))
        for disp_est, top_pad, right_pad, fn in zip(disp_est_np, top_pad_np,
                                                    right_pad_np, left_filenames):
            assert len(disp_est.shape) == 2
            # Crop with an explicit end column: the previous ``:-right_pad``
            # slice turned into ``:-0`` (an empty slice) when right_pad == 0.
            width = disp_est.shape[1]
            disp_est = np.array(disp_est[top_pad:, :width - right_pad],
                                dtype=np.float32)
            fn = os.path.join(save_dir, os.path.basename(fn))
            print("saving to", fn, disp_est.shape)
            # Fixed-point encoding: disparity * 256, rounded, stored as uint16.
            disp_est_uint = np.round(disp_est * 256).astype(np.uint16)
            skimage.io.imsave(fn, disp_est_uint)
@make_nograd_func
def test_sample(sample):
    """Predict the disparity for a single batch with the network in eval mode.

    The function is wrapped by ``make_nograd_func`` from ``utils``;
    presumably that runs it with gradient tracking disabled -- confirm there.

    Args:
        sample: Batch dict from ``TestImgLoader``. Its ``'left'`` and
            ``'right'`` tensors are moved to the GPU with ``.cuda()``.

    Returns:
        The last element of the model's output sequence.
    """
    model.eval()
    left_img = sample['left'].cuda()
    right_img = sample['right'].cuda()
    predictions = model(left_img, right_img)
    return predictions[-1]
if __name__ == '__main__':
    # Only run inference when this file is executed as a script.
    test()