-
Notifications
You must be signed in to change notification settings - Fork 24
/
nms.py
73 lines (71 loc) · 3.05 KB
/
nms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def overlapping_area(detection_1, detection_2):
    '''
    Compute the overlap ratio (intersection over union) of two detections.

    `detection_1` and `detection_2` are the 2 detections whose overlap
    needs to be found out.
    Each detection is a sequence in the format ->
    [x-top-left, y-top-left, confidence-of-detections, width-of-detection, height-of-detection]

    The function returns a float between 0 and 1: the area shared by the
    two rectangles divided by the area covered by either of them.
    0 is no overlap and 1 is complete overlap. If the union of the two
    rectangles has no area (both are degenerate), 0.0 is returned.

    Intersection formula taken from ->
    http://math.stackexchange.com/questions/99565/simplest-way-to-calculate-the-intersect-area-of-two-rectangles
    '''
    x1_tl, y1_tl, width_1, height_1 = (detection_1[0], detection_1[1],
                                       detection_1[3], detection_1[4])
    x2_tl, y2_tl, width_2, height_2 = (detection_2[0], detection_2[1],
                                       detection_2[3], detection_2[4])

    # Bottom-right corners follow from the top-left corner plus the size.
    x1_br, y1_br = x1_tl + width_1, y1_tl + height_1
    x2_br, y2_br = x2_tl + width_2, y2_tl + height_2

    # The extent of the intersection along each axis is clamped at 0 so
    # that disjoint rectangles yield an empty intersection rather than a
    # negative one.
    x_overlap = max(0, min(x1_br, x2_br) - max(x1_tl, x2_tl))
    y_overlap = max(0, min(y1_br, y2_br) - max(y1_tl, y2_tl))
    overlap_area = x_overlap * y_overlap

    # Each area must use the width AND height of the same detection.
    area_1 = width_1 * height_1
    area_2 = width_2 * height_2
    total_area = area_1 + area_2 - overlap_area

    # Two zero-area rectangles have an empty union; report "no overlap"
    # instead of dividing by zero.
    if total_area <= 0:
        return 0.0

    # float() keeps true division on Python 2 when all inputs are ints.
    return overlap_area / float(total_area)
def nms(detections, threshold=.5):
    '''
    This function performs Non-Maxima Suppression.

    `detections` consists of a list of detections.
    Each detection is in the format ->
    [x-top-left, y-top-left, confidence-of-detections, width-of-detection, height-of-detection]

    Detections are visited from highest to lowest confidence. A detection
    is kept only if its overlap (as computed by `overlapping_area`) with
    every detection already kept is not greater than `threshold`;
    otherwise it is suppressed in favour of the higher-confidence one.

    The output is a new list of the kept detections, ordered by
    decreasing confidence. The input list is not modified. An empty
    input yields an empty list.
    '''
    # sorted() builds a new list, so the caller's list is left untouched.
    ranked = sorted(detections, key=lambda detection: detection[2],
                    reverse=True)

    # Detections that survived suppression so far.
    kept = []
    for candidate in ranked:
        # Every candidate must be examined exactly once. Deleting from
        # the list being iterated (as an earlier version did) makes the
        # iterator skip elements, silently dropping valid detections.
        suppressed = any(
            overlapping_area(candidate, winner) > threshold
            for winner in kept
        )
        if not suppressed:
            kept.append(candidate)
    return kept
if __name__ == "__main__":
    # Example of how to use the NMS Module: the first two detections
    # describe the same rectangle with different confidence scores, the
    # third one is a separate rectangle.
    detections = [[31, 31, .9, 10, 10], [31, 31, .12, 10, 10], [100, 34, .8, 10, 10]]
    # A single parenthesised argument prints the same text under the
    # Python 2 print statement and the Python 3 print function, so the
    # module stays importable on both.
    print("Detections before NMS = {}".format(detections))
    print("Detections after NMS = {}".format(nms(detections)))