forked from xuarehere/yolo_series_deepsort_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ped_det_server.py
executable file
·155 lines (128 loc) · 5.98 KB
/
ped_det_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""
This module takes a video as input and outputs a JSON file
with the coordinates of the bounding boxes in the video.
"""
from os.path import basename, splitext, join, isfile, isdir, dirname
from os import makedirs
from tqdm import tqdm
import cv2
import argparse
import torch
from detector import build_detector
from deep_sort import build_tracker
from utils.tools import tik_tok, is_video
from utils.draw import compute_color_for_labels
from utils.parser import get_config
from utils.json_logger import BboxToJsonLogger
import warnings
def parse_args():
    """Parse and validate the command-line options of the script.

    Options:
        --VIDEO_PATH: path of the input video (default ``./demo/ped.avi``).
        --config_detection: detector configuration file.
        --config_deepsort: DeepSORT configuration file.
        --write-fps: frame rate of the written video (stored as ``write_fps``).
        --frame_interval: process every N-th frame; values below 1 become 1.
        --save_path: directory receiving the output files.
        --cpu: store ``use_cuda=False`` (``use_cuda`` defaults to True).

    Returns:
        argparse.Namespace: the parsed options.

    Raises:
        SystemExit: via ``parser.error`` (exit status 2) when the video file
            does not exist or ``is_video`` rejects it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--VIDEO_PATH", type=str, default="./demo/ped.avi")
    parser.add_argument("--config_detection", type=str, default="./configs/yolov3.yaml")
    parser.add_argument("--config_deepsort", type=str, default="./configs/deep_sort.yaml")
    parser.add_argument("--write-fps", type=int, default=20)
    parser.add_argument("--frame_interval", type=int, default=1)
    parser.add_argument("--save_path", type=str, default="./output")
    parser.add_argument("--cpu", dest="use_cuda", action="store_false", default=True)
    args = parser.parse_args()

    # Validate explicitly instead of with `assert`: assertions are removed
    # when Python runs with -O, which would let bad input through silently.
    if not isfile(args.VIDEO_PATH):
        parser.error("Error: Video not found")
    if not is_video(args.VIDEO_PATH):
        parser.error("Error: Not Supported format")

    # A non-positive interval makes no sense; fall back to "every frame".
    args.frame_interval = max(args.frame_interval, 1)
    return args
class VideoTracker(object):
    """Run detection and tracking on a video, producing an annotated video and a JSON log.

    Meant to be used as a context manager: ``__enter__`` opens the input
    video and creates the output writer, ``run`` processes the frames, and
    ``__exit__`` releases the OpenCV handles.
    """

    def __init__(self, cfg, args):
        """Build the detector and tracker and derive the output paths.

        Args:
            cfg: configuration object handed to ``build_detector`` and
                ``build_tracker``.
            args: parsed CLI namespace. This class reads ``use_cuda``,
                ``VIDEO_PATH``, ``save_path``, ``write_fps`` and
                ``frame_interval`` from it.
        """
        self.cfg = cfg
        self.args = args
        use_cuda = args.use_cuda and torch.cuda.is_available()
        if not use_cuda:
            warnings.warn("Running in cpu mode!")

        self.vdo = cv2.VideoCapture()
        # The writer is created in __enter__; start with None so __exit__
        # can be called safely even if __enter__ never ran.
        self.writer = None
        self.detector = build_detector(cfg, use_cuda=use_cuda)
        self.deepsort = build_tracker(cfg, use_cuda=use_cuda)
        self.class_names = self.detector.class_names

        # Configure output video and json
        self.logger = BboxToJsonLogger()
        filename, _ = splitext(basename(self.args.VIDEO_PATH))
        self.output_file = join(self.args.save_path, f'{filename}.avi')
        self.json_output = join(self.args.save_path, f'{filename}.json')
        output_dir = dirname(self.json_output)
        # dirname() is '' when save_path is empty; makedirs('') would raise.
        if output_dir:
            makedirs(output_dir, exist_ok=True)

    def __enter__(self):
        """Open the input video, create the output writer and log the video details.

        Returns:
            VideoTracker: this instance.

        Raises:
            IOError: if the input video cannot be opened.
        """
        self.vdo.open(self.args.VIDEO_PATH)
        # Check before reading any property or creating the writer: a closed
        # capture would otherwise yield a 0x0 output video.
        if not self.vdo.isOpened():
            self.vdo.release()
            raise IOError(f"Cannot open video: {self.args.VIDEO_PATH}")

        self.total_frames = int(self.vdo.get(cv2.CAP_PROP_FRAME_COUNT))
        self.im_width = int(self.vdo.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.im_height = int(self.vdo.get(cv2.CAP_PROP_FRAME_HEIGHT))

        video_details = {'frame_width': self.im_width,
                         'frame_height': self.im_height,
                         'frame_rate': self.args.write_fps,
                         'video_name': self.args.VIDEO_PATH}
        codec = cv2.VideoWriter_fourcc(*'XVID')
        self.writer = cv2.VideoWriter(self.output_file, codec, self.args.write_fps,
                                      (self.im_width, self.im_height))
        self.logger.add_video_details(**video_details)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Release the writer and the capture; print any in-flight exception.

        Returns None, so an exception raised inside the ``with`` body still
        propagates to the caller.
        """
        # Release the writer first so the output file is finalised.
        if self.writer is not None:
            self.writer.release()
        self.vdo.release()
        if exc_type:
            print(exc_type, exc_value, exc_traceback)

    def run(self):
        """Process the video frame by frame and write the JSON log at the end.

        Every ``frame_interval``-th grabbed frame is retrieved, logged, passed
        to ``detection`` and written to the output video.
        """
        # Read the interval from the instance, not from a module-level
        # `args`, so the class also works when imported from another module.
        frame_interval = self.args.frame_interval
        idx_frame = 0
        with tqdm(total=self.total_frames + 1) as pbar:
            while self.vdo.grab():
                if idx_frame % frame_interval == 0:
                    ok, ori_im = self.vdo.retrieve()
                    if ok and ori_im is not None:
                        timestamp = self.vdo.get(cv2.CAP_PROP_POS_MSEC)
                        frame_id = int(self.vdo.get(cv2.CAP_PROP_POS_FRAMES))
                        self.logger.add_frame(frame_id=frame_id, timestamp=timestamp)
                        self.detection(frame=ori_im, frame_id=frame_id)
                        self.save_frame(ori_im)
                    else:
                        warnings.warn(f"Could not decode frame {idx_frame}; skipping")
                # Count every grabbed frame, processed or not, so the
                # interval keeps selecting frames for the whole video.
                idx_frame += 1
                pbar.update()
        self.logger.json_output(self.json_output)

    @tik_tok
    def detection(self, frame, frame_id):
        """Detect class-0 objects in ``frame``, track them and draw the tracks.

        Args:
            frame: BGR image; boxes are drawn onto it in place.
            frame_id: identifier under which the boxes are logged.
        """
        im = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # do detection
        bbox_xywh, cls_conf, cls_ids = self.detector(im)
        if bbox_xywh is not None:
            # select person class
            mask = cls_ids == 0

            bbox_xywh = bbox_xywh[mask]
            bbox_xywh[:, 3:] *= 1.2  # bbox dilation just in case bbox too small
            cls_conf = cls_conf[mask]

            # do tracking
            outputs = self.deepsort.update(bbox_xywh, cls_conf, im)

            # draw boxes for visualization (draw_boxes modifies `frame` in place)
            if len(outputs) > 0:
                self.draw_boxes(img=frame, frame_id=frame_id, output=outputs)

    def draw_boxes(self, img, frame_id, output, offset=(0, 0)):
        """Log every tracked box and draw it, with its identity label, on ``img``.

        Args:
            img: image to draw on (modified in place).
            frame_id: identifier of the frame the boxes belong to.
            output: iterable of ``(x1, y1, x2, y2, identity)`` rows.
            offset: ``(dx, dy)`` added to the coordinates used for drawing;
                the logged coordinates are not shifted.

        Returns:
            The same ``img`` object.
        """
        for box in output:
            x1, y1, x2, y2, identity = [int(value) for value in box]
            # The box is logged before the drawing offset is applied.
            self.logger.add_bbox_to_frame(frame_id=frame_id,
                                          bbox_id=identity,
                                          top=y1,
                                          left=x1,
                                          width=x2 - x1,
                                          height=y2 - y1)
            x1 += offset[0]
            x2 += offset[0]
            y1 += offset[1]
            y2 += offset[1]

            # box text and bar
            # The rows carry no score, hence the fixed confidence value.
            self.logger.add_label_to_bbox(frame_id=frame_id, bbox_id=identity,
                                          category='pedestrian', confidence=0.9)
            color = compute_color_for_labels(identity)
            label = f'{identity:d}'
            t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 2, 2)[0]
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 3)
            cv2.rectangle(img, (x1, y1), (x1 + t_size[0] + 3, y1 + t_size[1] + 4), color, -1)
            cv2.putText(img, label, (x1, y1 + t_size[1] + 4),
                        cv2.FONT_HERSHEY_PLAIN, 2, [255, 255, 255], 2)
        return img

    def save_frame(self, frame) -> None:
        """Append ``frame`` to the output video; a ``None`` frame is ignored."""
        if frame is not None:
            self.writer.write(frame)
if __name__ == "__main__":
    # Script entry point: parse the CLI options, merge the detection and
    # DeepSORT config files into one config, then run the tracker.
    args = parse_args()

    cfg = get_config()
    for config_path in (args.config_detection, args.config_deepsort):
        cfg.merge_from_file(config_path)

    with VideoTracker(cfg, args) as tracker:
        tracker.run()