-
Notifications
You must be signed in to change notification settings - Fork 1
/
recall_precision.py
112 lines (86 loc) · 3.14 KB
/
recall_precision.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#! /usr/bin/env python
"""Compute recall and precision of a YOLO_v2 style detection model on an HDF5 test set."""
import argparse
import colorsys
import imghdr
import os
import os.path as osp
import random
import h5py
import numpy as np
from keras.models import load_model
from PIL import Image, ImageDraw, ImageFont
import tensorflow as tf
from keras import backend as K
from keras.layers import Input, Lambda, Conv2D
from keras.models import load_model, Model
from keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping
from keras.optimizers import Adam
from mobiledet.models.keras_yolo import preprocess_true_boxes
from mobiledet.models.keras_yolo import yolo_eval, yolo_loss, decode_yolo_output, create_model
from mobiledet.models.keras_yolo import yolo_body_darknet, yolo_body_mobilenet
from mobiledet.models.keras_yolo import recall_precision
from mobiledet.utils.draw_boxes import draw_boxes
from mobiledet.utils import read_voc_datasets_train_batch, brightness_augment, augment_image
from mobiledet.models.keras_yolo import yolo_get_detector_mask
from cfg import *
from mobiledet.utils import *
from keras.utils import plot_model
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
import os
import tensorflow as tf
from tensorflow.python.tools.freeze_graph import freeze_graph
import time
# Command-line interface. Help strings use argparse's %(default)s substitution
# so the documented defaults can never drift from the actual ones (previously
# the IOU help claimed .65 while the default was .6, and the anchors help named
# a file that was not the default).
parser = argparse.ArgumentParser(
    description='Calculate YOLOv2 recall and precision on test datasets..')
parser.add_argument(
    '-m',
    '--weight_path',
    required=True,  # _main cannot run without weights; fail with a usage error
    help='path to trained model weight file')
parser.add_argument(
    '-a',
    '--anchors_path',
    help='path to anchors file, defaults to %(default)s',
    default='model_data/uav123_anchors.txt')
parser.add_argument(
    '-c',
    '--classes_path',
    help='path to classes file, defaults to %(default)s',
    default='model_data/drone_classes.txt')
parser.add_argument(
    '-d',
    '--data_path',
    help='path to the HDF5 file which has a "test" group, defaults to %(default)s',
    default='~/data/uav123.hdf5')
parser.add_argument(
    '-s',
    '--score_threshold',
    type=float,
    help='threshold for bounding box scores, default %(default)s',
    default=.6)
parser.add_argument(
    '-iou',
    '--iou_threshold',
    type=float,
    help='threshold for non max suppression IOU, default %(default)s',
    default=.6)
def _main(args):
model_path = os.path.expanduser(args.weight_path)
anchors_path = os.path.expanduser(args.anchors_path)
classes_path = os.path.expanduser(args.classes_path)
data_path = os.path.expanduser(args.data_path)
dataset = h5py.File(data_path, 'r')
class_names = get_classes(classes_path)
anchors = get_anchors(anchors_path)
if SHALLOW_DETECTOR:
anchors = anchors * 2
print(class_names)
print(anchors)
yolo_model, _ = create_model(anchors, class_names, load_pretrained=True,
feature_extractor=FEATURE_EXTRACTOR, pretrained_path=model_path)
hdf5_images = np.array(dataset['test/images'])
recall_precision(np.array(dataset['test/images']), np.array(dataset['test/boxes']),
yolo_model, anchors, class_names, num_samples=2048)
if __name__ == '__main__':
    # Script entry point: parse the command line, then run the evaluation.
    cli_args = parser.parse_args()
    _main(cli_args)