-
Notifications
You must be signed in to change notification settings - Fork 2
/
findBestEpoch.py
86 lines (76 loc) · 2.33 KB
/
findBestEpoch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Sat Nov 25 13:33:24 2017
@author: victoria
"""
from yllab import *
# Stage switches for this script; flip either one to False to skip the stage.
training = True   # when True, run trainInterface.train() before evaluation
findBest = True   # when True, sweep saved checkpoints and pick the best epoch
if training:
    pred('training.......')
    # Star-import kept deliberately: downstream code may rely on names
    # that the train module exports.
    from train import *
    import trainInterface
    trainInterface.train()
from val import *
import val
#import inferenceInterface
#c.reload = inferenceInterface
import predictInterface
# Attach the prediction module to the shared config object `c`
# (presumably consumed by autoFindBestEpoch below -- confirm in yllab/val).
c.predictInterface = predictInterface
# Metric function plus the result column used to rank epochs and report scores.
evaluFun, sortKey = accEvalu, 'acc'
if findBest:
    pred('auto Find Best Epoch .......')
    # The epoch sweep runs with step=None; the refine stage below sets it to .2.
    c.args.step = None
    # Epoch number = last integer found in each checkpoint file name.
    epochs = sorted(findints(filename(p))[-1]
                    for p in glob(c.weightsPrefix + '*'))
    df = autoFindBestEpoch(c, evaluFun, sortkey=sortKey,
                           savefig='epoch.png', epochs=epochs)
    # idxmax returns the index *label* of the best row, which is what .loc
    # needs. Series.argmax returns a *position* in pandas >= 1.0 (and is a
    # deprecated idxmax alias in 0.21-0.25), so it can select the wrong row
    # whenever df's index is not a default RangeIndex.
    row = df.loc[df[sortKey].idxmax()]
    restore = int(row.restore)
    # Record the whole sweep; the file name carries the best score and epoch.
    df.to_csv('%s:%s_epoch:%s.csv' % (sortKey, row[sortKey], restore))
else:
    # No sweep: fall back to -1 (presumably "latest checkpoint" -- confirm).
    restore = -1
#%%
if __name__ == '__main__':
    import os

    pred('refine .......')
    c.args.restore = restore
    c.args.step = .2
    # Reload so module-level state in predictInterface picks up the updated
    # c.args (presumably read at import time -- confirm in predictInterface).
    reload(predictInterface)
    inference = predictInterface.predict
    e = Evalu(evaluFun,
              valNames=c.names,
              sortkey=sortKey,
              saveResoult=False,  # keyword spelling is Evalu's own API
              )
    # Process images in order of ground-truth height (shape[0]).
    c.names.sort(key=lambda x: readgt(x).shape[0])

    # Colour-coded predictions are written here; imsave does not create it.
    outDir = 'val'
    if not os.path.isdir(outDir):
        os.makedirs(outDir)

    for name in c.names[:]:
        gt = readgt(name)
        prob = inference(name)
        # Class index per pixel = argmax over axis 2
        # (assumes prob is (H, W, classes) -- TODO confirm).
        predLabel = prob.argmax(2)
        e.evalu(predLabel, gt, name)
        # For visual debugging, load readimg(name) and compare
        # labelToColor(gt, colors) with the saved colour map.
        imsave('val/%s.tif' % name, uint8(labelToColor(predLabel, colors)))

    # NOTE(review): this was `args.restore`, a name this file never defines
    # (it relied on a star import); c.args.restore is the value set above.
    print('%s %s' % (c.args.restore, e[sortKey].mean()))