diff --git a/keras_retinanet/backend/tensorflow_backend.py b/keras_retinanet/backend/tensorflow_backend.py index 7de5b2ea9..a0969feea 100644 --- a/keras_retinanet/backend/tensorflow_backend.py +++ b/keras_retinanet/backend/tensorflow_backend.py @@ -68,16 +68,10 @@ def resize_images(images, size, method='bilinear', align_corners=False): return tensorflow.image.resize_images(images, size, methods[method], align_corners) -def non_max_suppression(boxes, scores, max_output_size, iou_threshold, **kwargs): +def non_max_suppression(*args, **kwargs): """ See https://www.tensorflow.org/versions/master/api_docs/python/tf/image/non_max_suppression . """ - return tensorflow.image.non_max_suppression( - boxes=tensorflow.cast(boxes, tensorflow.float32), - scores=tensorflow.cast(scores, tensorflow.float32), - max_output_size=max_output_size, - iou_threshold=iou_threshold, - **kwargs - ) + return tensorflow.image.non_max_suppression(*args, **kwargs) def range(*args, **kwargs): diff --git a/keras_retinanet/layers/filter_detections.py b/keras_retinanet/layers/filter_detections.py index 210cce43d..f73e918b2 100644 --- a/keras_retinanet/layers/filter_detections.py +++ b/keras_retinanet/layers/filter_detections.py @@ -57,7 +57,7 @@ def _filter_detections(scores, labels): filtered_scores = keras.backend.gather(scores, indices)[:, 0] # perform NMS - nms_indices = backend.non_max_suppression(boxes=filtered_boxes, scores=filtered_scores, max_output_size=max_detections, iou_threshold=nms_threshold) + nms_indices = backend.non_max_suppression(filtered_boxes, filtered_scores, max_output_size=max_detections, iou_threshold=nms_threshold) # filter indices based on NMS indices = keras.backend.gather(indices, nms_indices) diff --git a/keras_retinanet/utils/image.py b/keras_retinanet/utils/image.py index 6be9cd7ee..b372cf493 100644 --- a/keras_retinanet/utils/image.py +++ b/keras_retinanet/utils/image.py @@ -49,7 +49,7 @@ def preprocess_image(x, mode='caffe'): # mostly identical to "https://github.com/keras-team/keras-applications/blob/master/keras_applications/imagenet_utils.py" # except for converting RGB -> BGR since we assume BGR already - #covert always to float32 to keep compatibility with opencv + # covert always to float32 to keep compatibility with opencv x = x.astype(np.float32) if mode == 'tf': @@ -62,6 +62,7 @@ def preprocess_image(x, mode='caffe'): return x + def cast_image_to_floatx(x): """ Convert an image to the actual keras floatx. @@ -75,6 +76,7 @@ def cast_image_to_floatx(x): return x + def adjust_transform_for_image(transform, image, relative_translation): """ Adjust a transformation for a specific image.