-
Notifications
You must be signed in to change notification settings - Fork 27
/
run.py
196 lines (168 loc) · 7.33 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import argparse
import math
import torch
import cv2
import numpy as np
#from tqdm import tqdm
from skimage import color, io
from model import *
from evaluators.draw_graph import draw_graph
# Default checkpoint paths: detection-only model and full (pairing) model.
DETECTOR_TRAINED_MODEL = "saved/detector/checkpoint-iteration150000.pth"
TRAINED_MODEL = "saved/pairing/checkpoint-iteration125000.pth"
SCALE_IMAGE_DEFAULT = 0.52 # proportion of the original size the image is resized to (between 0 and 1)
INCLUDE_THRESHOLD_DEFAULT = 0.9 # minimum confidence for keeping a bounding box (0 to 1)
PAIR_THRESHOLD_DEFAULT = 0.7 # minimum confidence for keeping a relationship between boxes (0 to 1)
def getCorners(xyrhw):
    """Return the four integer corner points of a rotated box.

    ``xyrhw`` is indexed at positions 0-4 as: x centre, y centre, rotation
    (radians, as passed to ``math.cos``/``math.sin``), h and w.  h and w are
    applied as offsets on either side of the centre and are each capped at
    30000 before use.

    Returns the corners as ``(tl, tr, br, bl)``, each an ``(x, y)`` tuple of
    ints truncated with ``int()``.
    """
    xc, yc, rot = xyrhw[0], xyrhw[1], xyrhw[2]
    # cap the extents so the corner coordinates stay bounded
    h = min(30000, xyrhw[3])
    w = min(30000, xyrhw[4])

    cos_r = math.cos(rot)
    sin_r = math.sin(rot)

    tl = (int(-w * cos_r - h * sin_r + xc), int(-w * sin_r + h * cos_r + yc))
    tr = (int(w * cos_r - h * sin_r + xc), int(w * sin_r + h * cos_r + yc))
    br = (int(w * cos_r + h * sin_r + xc), int(w * sin_r - h * cos_r + yc))
    bl = (int(-w * cos_r + h * sin_r + xc), int(-w * sin_r - h * cos_r + yc))
    return tl, tr, br, bl
def plotRect(img,color,xyrhw,lineW=1):
    """Draw the outline of the rotated box ``xyrhw`` onto ``img``.

    The corners come from ``getCorners``; the four edges are drawn with
    ``cv2.line`` using ``color`` and a thickness of ``lineW``.
    """
    corners = getCorners(xyrhw)
    # pair every corner with the next one, wrapping the last back to the first
    following = corners[1:] + corners[:1]
    for start, end in zip(corners, following):
        cv2.line(img, start, end, color, lineW)
def detect_boxes(run_img, np_img, include_threshold=INCLUDE_THRESHOLD_DEFAULT, output_image=None, model_checkpoint=DETECTOR_TRAINED_MODEL, use_gpu=None, image_path=None):
    """Run the detection model on one image and return the confident boxes.

    Args:
        run_img: Image tensor passed to the model's forward call.
        np_img: Image array the kept boxes are drawn on (modified in place)
            when ``output_image`` is set.
        include_threshold: Predictions whose first value (the confidence) is
            below this are skipped.
        output_image: Path the annotated ``np_img`` is saved to, or a falsy
            value to skip drawing and saving.
        model_checkpoint: Path of the checkpoint read with ``torch.load``.
        use_gpu: GPU number, or None for CPU.  Any non-None value (including
            0) selects the current CUDA device.
        image_path: Optional label for the image, used only in the progress
            messages.

    Returns:
        A list with one dict per kept box, holding 'poly_points' (four
        ``[x, y]`` float pairs), 'type', 'textPred' and 'fieldPred'.
    """
    # Decide from the argument, not from a module-level global, so the
    # function also works when imported from another module.
    device = "cuda" if use_gpu is not None else "cpu"
    print(f"Using {device} device")
    label = image_path if image_path is not None else "<in-memory image>"

    # fetch the model (weights are deserialised onto the CPU first)
    checkpoint = torch.load(model_checkpoint, map_location=lambda storage, location: storage)
    print(f"Using {checkpoint['arch']}")
    # NOTE(security): the class name is taken from the checkpoint and passed
    # to eval(); only load checkpoints from a trusted source.
    model = eval(checkpoint['arch'])(checkpoint['config']['model'])
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)

    # run the image through the model; gradients are not needed because the
    # result is immediately converted to plain Python lists
    print(f"Run image through model: {label}")
    run_img = run_img.to(device)
    with torch.no_grad():
        result = model(run_img)

    # produce the output
    predictions = result[0].tolist()
    output = []
    print(f"Process bounding boxes: {label}")
    for pred in predictions[0]:
        if pred[0] < include_threshold:
            continue
        print(pred)
        corners = getCorners(pred[1:])  # (tl, tr, br, bl)
        bb = {
            'poly_points': [[float(x), float(y)] for x, y in corners],
            'type': 'detectorPrediction',
            'textPred': float(pred[7]),
            'fieldPred': float(pred[8]),
        }
        output.append(bb)
        if output_image:
            # (0,0,255) when the text score wins, (255,0,0) otherwise
            colour = (0, 0, 255) if bb['textPred'] > bb['fieldPred'] else (255, 0, 0)
            plotRect(np_img, colour, pred[1:6])

    if output_image:
        print(f"Saving output: {output_image}")
        io.imsave(output_image, np_img)
    return output
def detect_boxes_and_pairs(run_img, np_img, output_image=None, model_checkpoint=TRAINED_MODEL, pair_threshold=PAIR_THRESHOLD_DEFAULT, use_gpu=None, image_path=None):
    """Run the full (pairing) model on one image and draw the predicted graph.

    Args:
        run_img: Image tensor passed to the model's forward call.
        np_img: Image array handed to ``draw_graph`` for rendering.
        output_image: Path the rendered image is saved to, or a falsy value
            to skip saving.
        model_checkpoint: Path of the checkpoint read with ``torch.load``.
        pair_threshold: Threshold forwarded to ``draw_graph`` together with
            the sigmoid of the relationship predictions.
        use_gpu: GPU number, or None for CPU.  Any non-None value (including
            0) selects the current CUDA device.
        image_path: Optional label for the image, used only in the progress
            messages.

    Returns:
        The model's raw result, which unpacks as
        ``(outputBoxes, outputOffsets, relPred, relIndexes, bbPred)``.
    """
    # Decide from the argument, not from a module-level global, so the
    # function also works when imported from another module.
    device = "cuda" if use_gpu is not None else "cpu"
    print(f"Using {device} device")
    label = image_path if image_path is not None else "<in-memory image>"

    # fetch the model (weights are deserialised onto the CPU first)
    checkpoint = torch.load(model_checkpoint, map_location=lambda storage, location: storage)
    print(f"Using {checkpoint['arch']}")
    # NOTE(security): the class name is taken from the checkpoint and passed
    # to eval(); only load checkpoints from a trusted source.
    model = eval(checkpoint['arch'])(checkpoint['config']['model'])
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)

    # run the image through the model
    print(f"Run image through model: {label}")
    run_img = run_img.to(device)
    result = model(run_img)
    outputBoxes, outputOffsets, relPred, relIndexes, bbPred = result

    # turn the relationship predictions into 0..1 scores before thresholding
    relPred = torch.sigmoid(relPred)
    np_img = draw_graph(outputBoxes, relPred, relIndexes, np_img, pair_threshold)

    if output_image:
        print(f"Saving output: {output_image}")
        io.imsave(output_image, np_img)
    return result
def main(imagePath, scale_image, detection, checkpoint, detect_threshold, output_image, pair_threshold, gpu=None):
    """Load an image, prepare it for the network and run the selected model.

    Args:
        imagePath: Path of the image to read with ``cv2.imread``.
        scale_image: Factor both image dimensions are multiplied by before
            resizing.
        detection: When true run the detection model, otherwise the full
            (pairing) model.
        checkpoint: Checkpoint path, or None to use the default for the
            selected model.
        detect_threshold: Box confidence threshold (detection model only).
        output_image: Path for the annotated output image.
        pair_threshold: Relationship threshold (pairing model only).
        gpu: GPU number forwarded as ``use_gpu``, or None for CPU.

    Returns:
        Whatever the selected model function returns.

    Raises:
        OSError: If the image cannot be read.
    """
    print(f"Loading image: {imagePath}")
    np_img = cv2.imread(imagePath, cv2.IMREAD_COLOR)
    if np_img is None:
        # cv2.imread signals a missing/unreadable file by returning None
        raise OSError(f"Could not read image: {imagePath}")
    # NOTE(review): cv2 loads colour images in BGR order while the model
    # functions save np_img through skimage's io.imsave - confirm the channel
    # order of the saved image is what is wanted.

    print(f"Transforming image: {imagePath}")
    width = int(np_img.shape[1] * scale_image)
    height = int(np_img.shape[0] * scale_image)
    np_img = cv2.resize(np_img, (width, height))

    # grayscale -> (1, 1, H, W) float32 tensor, then 1.0 - x/128.0
    img = cv2.cvtColor(np_img, cv2.COLOR_BGR2GRAY)
    img = img[None, None, :, :]
    img = img.astype(np.float32)
    img = torch.from_numpy(img)
    img = 1.0 - img / 128.0

    if detection:
        if checkpoint is None:
            checkpoint = DETECTOR_TRAINED_MODEL
        # use this function's own threshold argument, not the global `args`
        result = detect_boxes(
            img,
            np_img,
            include_threshold=detect_threshold,
            output_image=output_image,
            model_checkpoint=checkpoint,
            use_gpu=gpu,
        )
    else:
        if checkpoint is None:
            checkpoint = TRAINED_MODEL
        # the pairing path hands draw_graph a float image scaled by 1/255
        np_img = np_img.astype(np.float32) / 255
        result = detect_boxes_and_pairs(
            img,
            np_img,
            output_image=output_image,
            pair_threshold=pair_threshold,
            model_checkpoint=checkpoint,
            use_gpu=gpu,
        )
    return result
if __name__ == "__main__":
    # Command-line entry point: parse the options and run on a single image.
    parser = argparse.ArgumentParser(description='Run on a single image')
    parser.add_argument('image', type=str, help='Path to the image to convert')
    parser.add_argument('output_image', type=str, help="A path to save a version of the original image with form boxes overlaid")
    parser.add_argument('--scale-image', type=float, default=SCALE_IMAGE_DEFAULT,
                        help='Scale the image by this proportion (between 0 and 1). 0.52 for pretrained model on NAF images')
    parser.add_argument('--detect-threshold', type=float, default=INCLUDE_THRESHOLD_DEFAULT,
                        help='Include boxes where the confidence is above this threshold (between 0 and 1)')
    # The pair threshold has its own default (0.7, as the help text states);
    # it previously reused the box-inclusion default by mistake.
    parser.add_argument('--pair-threshold', type=float, default=PAIR_THRESHOLD_DEFAULT,
                        help='Include relationships where the confidence is above this threshold (between 0 and 1) default: 0.7')
    parser.add_argument('-c', '--checkpoint', default=None, type=str,
                        help='path to checkpoint (default: pretrained model)')
    parser.add_argument('-d', '--detection', action='store_true',
                        help='Run detection model. Default is full (pairing) model')
    parser.add_argument('-g', '--gpu', default=None, type=int,
                        help='gpu number (default: cpu only)')
    args = parser.parse_args()

    imagePath = args.image
    output_image = args.output_image
    scale_image = args.scale_image
    checkpoint = args.checkpoint
    detection = args.detection
    detect_threshold = args.detect_threshold
    pair_threshold = args.pair_threshold
    gpu = args.gpu

    main_args = (imagePath, scale_image, detection, checkpoint, detect_threshold, output_image, pair_threshold, gpu)
    if gpu is not None:
        # make the requested GPU the current CUDA device for the whole run
        with torch.cuda.device(gpu):
            main(*main_args)
    else:
        main(*main_args)