-
Notifications
You must be signed in to change notification settings - Fork 88
/
data_queue.py
112 lines (97 loc) · 5.14 KB
/
data_queue.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#-------------------------------------------------------------------------------
# Author: Lukasz Janyst <lukasz@jany.st>
# Date: 17.09.2017
#-------------------------------------------------------------------------------
# This file is part of SSD-TensorFlow.
#
# SSD-TensorFlow is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# SSD-TensorFlow is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with SSD-Tensorflow. If not, see <http://www.gnu.org/licenses/>.
#-------------------------------------------------------------------------------
import queue as q
import numpy as np
import multiprocessing as mp
#-------------------------------------------------------------------------------
class DataQueue:
    """Multiprocessing queue for fixed-layout (img, label) numpy array pairs.

    Instead of pickling the arrays through a pipe, the constructor allocates
    ``maxsize`` slots backed by ``multiprocessing.Array`` buffers. ``put``
    copies the arrays into a free slot and sends only the slot index (plus
    ``boxes``) through ``self.queue``; ``get`` copies the slot back out and
    returns the index to the free-slot queue ``self.array_queue``.

    Every ``img`` / ``label`` must have exactly the dtype and shape of the
    templates given to the constructor.
    """
    #---------------------------------------------------------------------------
    def __init__(self, img_template, label_template, maxsize):
        """
        :param img_template:   numpy array whose dtype and shape every image
                               passed to put() must have
        :param label_template: numpy array whose dtype and shape every label
                               passed to put() must have
        :param maxsize:        number of pre-allocated slots, i.e. the
                               maximum number of items held at once
        :raises ValueError:    if maxsize is smaller than 1
        """
        #-----------------------------------------------------------------------
        # An empty pool would make every put() block forever (and mp.Queue
        # treats maxsize <= 0 as "unbounded"), so reject it up front
        #-----------------------------------------------------------------------
        if maxsize < 1:
            raise ValueError('maxsize needs to be at least 1 but is {}'
                             .format(maxsize))

        #-----------------------------------------------------------------------
        # Figure out the data types, sizes and shapes of both arrays
        #-----------------------------------------------------------------------
        self.img_dtype = img_template.dtype
        self.img_shape = img_template.shape
        self.img_bc = img_template.nbytes
        self.label_dtype = label_template.dtype
        self.label_shape = label_template.shape
        self.label_bc = label_template.nbytes

        #-----------------------------------------------------------------------
        # Make an array pool and a queue of free slot indices
        #-----------------------------------------------------------------------
        self.array_pool = []
        self.array_queue = mp.Queue(maxsize)
        for slot_id in range(maxsize):
            img_arr = self._shared_array(self.img_dtype, self.img_shape,
                                         self.img_bc)
            label_arr = self._shared_array(self.label_dtype, self.label_shape,
                                           self.label_bc)
            self.array_pool.append((img_arr, label_arr))
            self.array_queue.put(slot_id)
        self.queue = mp.Queue(maxsize)

    #---------------------------------------------------------------------------
    @staticmethod
    def _shared_array(dtype, shape, byte_count):
        """Return an array of dtype/shape viewing a new mp.Array buffer."""
        buff = mp.Array('c', byte_count, lock=False)
        # np.frombuffer keeps a reference to buff, so it stays alive
        return np.frombuffer(buff, dtype=dtype).reshape(shape)

    #---------------------------------------------------------------------------
    @staticmethod
    def _check_consistency(name, arr, dtype, shape, byte_count):
        """Raise ValueError unless arr is an ndarray with the given layout."""
        if type(arr) is not np.ndarray:
            raise ValueError(name + ' needs to be a numpy array')
        if arr.dtype != dtype:
            raise ValueError('{}\'s elements need to be of type {} but is {}'
                             .format(name, str(dtype), str(arr.dtype)))
        if arr.shape != shape:
            raise ValueError('{}\'s shape needs to be {} but is {}'
                             .format(name, shape, arr.shape))
        # nbytes equals len(arr.tobytes()) without copying the whole array
        if arr.nbytes != byte_count:
            raise ValueError('{}\'s byte count needs to be {} but is {}'
                             .format(name, byte_count, arr.nbytes))

    #---------------------------------------------------------------------------
    def put(self, img, label, boxes, *args, **kwargs):
        """Copy img and label into a free slot and enqueue it with boxes.

        Extra positional/keyword arguments (block, timeout) are forwarded
        both to the free-slot acquisition and to mp.Queue.put, so the total
        wait can be up to twice the timeout.

        :param img:   numpy array matching the image template
        :param label: numpy array matching the label template
        :param boxes: any picklable object sent along with the arrays
        :raises ValueError: if img/label do not match the templates
        :raises queue.Full: if no free slot (or queue room) was obtained
        """
        self._check_consistency('img', img, self.img_dtype, self.img_shape,
                                self.img_bc)
        self._check_consistency('label', label, self.label_dtype,
                                self.label_shape, self.label_bc)

        #-----------------------------------------------------------------------
        # If we can not get the slot within timeout we are actually full, not
        # empty
        #-----------------------------------------------------------------------
        try:
            arr_id = self.array_queue.get(*args, **kwargs)
        except q.Empty:
            raise q.Full() from None

        #-----------------------------------------------------------------------
        # Copy the arrays into the shared pool; if anything fails from here
        # on, hand the slot back so it is not lost for good
        #-----------------------------------------------------------------------
        try:
            img_slot, label_slot = self.array_pool[arr_id]
            # [...] instead of [:] so that 0-d templates work too
            img_slot[...] = img
            label_slot[...] = label
            self.queue.put((arr_id, boxes), *args, **kwargs)
        except BaseException:
            self.array_queue.put(arr_id)
            raise

    #---------------------------------------------------------------------------
    def get(self, *args, **kwargs):
        """Dequeue an item and return private copies of its arrays.

        Arguments (block, timeout) are forwarded to mp.Queue.get.

        :return: tuple (img, label, boxes); img and label are copies, so the
                 slot is released immediately
        :raises queue.Empty: if nothing became available
        """
        arr_id, boxes = self.queue.get(*args, **kwargs)
        try:
            img_slot, label_slot = self.array_pool[arr_id]
            img = np.copy(img_slot)
            label = np.copy(label_slot)
        finally:
            # Release the slot even if copying fails
            self.array_queue.put(arr_id)
        return img, label, boxes

    #---------------------------------------------------------------------------
    def empty(self):
        """Return True if no items appear to be queued.

        Delegates to multiprocessing.Queue.empty(), which is documented as
        not reliable across processes; treat the result as a hint only.
        """
        return self.queue.empty()