Skip to content

Commit

Permalink
Fix NMS for float16.
Browse files Browse the repository at this point in the history
  • Loading branch information
vcarpani authored and hgaiser committed Dec 7, 2018
1 parent 913776a commit e729013
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
10 changes: 8 additions & 2 deletions keras_retinanet/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,16 @@ def resize_images(images, size, method='bilinear', align_corners=False):
return tensorflow.image.resize_images(images, size, methods[method], align_corners)


def non_max_suppression(*args, **kwargs):
def non_max_suppression(boxes, scores, max_output_size, iou_threshold, **kwargs):
""" See https://www.tensorflow.org/versions/master/api_docs/python/tf/image/non_max_suppression .
"""
return tensorflow.image.non_max_suppression(*args, **kwargs)
return tensorflow.image.non_max_suppression(
boxes=tensorflow.cast(boxes, tensorflow.float32),
scores=tensorflow.cast(scores, tensorflow.float32),
max_output_size=max_output_size,
iou_threshold=iou_threshold,
**kwargs
)


def range(*args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion keras_retinanet/layers/filter_detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _filter_detections(scores, labels):
filtered_scores = keras.backend.gather(scores, indices)[:, 0]

# perform NMS
nms_indices = backend.non_max_suppression(filtered_boxes, filtered_scores, max_output_size=max_detections, iou_threshold=nms_threshold)
nms_indices = backend.non_max_suppression(boxes=filtered_boxes, scores=filtered_scores, max_output_size=max_detections, iou_threshold=nms_threshold)

# filter indices based on NMS
indices = keras.backend.gather(indices, nms_indices)
Expand Down

0 comments on commit e729013

Please sign in to comment.