chunked_data.py
# Write and read data through a fixed-size in-memory buffer that is smaller
# than the whole data set. Each full buffer is flushed to the HDF5 file as a
# chunk dataset named '<key><chunk_id_sep><chunk_id>'; a partial final chunk
# is flushed on close().
import h5py
import numpy as np


class ChunkedDataWriter:
    """Buffer samples in memory and flush them to HDF5 one chunk at a time."""

    def __init__(self,
                 filename,
                 proto,
                 chunk_size,
                 compression=True,
                 chunk_id_sep=':'):
        # proto is a sequence of (key, dtype, shape) tuples; shape is a tuple,
        # or None / () for scalar fields.
        self.proto = proto
        self.chunk_size = chunk_size
        self.compression = compression
        self.chunk_id_sep = chunk_id_sep
        self.count = 0
        # One fixed-size staging buffer per field.
        self.buffers = {
            key: np.empty(shape=(self.chunk_size,) + (shape or ()),
                          dtype=dtype)
            for key, dtype, shape in self.proto
        }
        self.buffer_idx = 0
        self.h5file = h5py.File(filename, 'w')

    def __enter__(self):
        return self

    def write(self, *args):
        """Append a single sample; args follow the field order of proto."""
        for (key, _, _), data in zip(self.proto, args):
            self.buffers[key][self.buffer_idx] = data
        self.count += 1
        self.buffer_idx += 1
        if self.buffer_idx >= self.chunk_size:
            self.flush()

    def write_batch(self, *args):
        """Append a batch; each arg's first axis is the batch dimension."""
        batch_size = args[0].shape[0]
        assert all(x.shape[0] == batch_size for x in args[1:])
        data_idx = 0
        while data_idx < batch_size:
            data_rest = batch_size - data_idx
            buffer_rest = self.chunk_size - self.buffer_idx
            if data_rest < buffer_rest:
                # The remaining data fits in the current buffer: copy it in
                # and wait for later writes to fill the chunk.
                for (key, _, _), data in zip(self.proto, args):
                    self.buffers[key][self.buffer_idx:self.buffer_idx +
                                      data_rest] = data[data_idx:]
                self.count += data_rest
                data_idx += data_rest
                self.buffer_idx += data_rest
            else:
                # Fill the buffer to the end and flush it as one chunk.
                for (key, _, _), data in zip(self.proto, args):
                    self.buffers[key][self.buffer_idx:] = \
                        data[data_idx:data_idx + buffer_rest]
                self.count += buffer_rest
                data_idx += buffer_rest
                self.buffer_idx += buffer_rest
                self.flush()

    # Write one dataset in full, bypassing the chunked buffers.
    def create_dataset(self, key, data):
        self.h5file.create_dataset(
            key,
            data=data,
            compression='gzip' if self.compression else None,
            shuffle=self.compression)

    def flush(self):
        if self.buffer_idx <= 0:
            return
        # 0-based chunk id; using (count - 1) // chunk_size gives a partial
        # final chunk its own id instead of colliding with the last full one,
        # so the data count no longer has to be a multiple of chunk_size.
        chunk_id = (self.count - 1) // self.chunk_size
        for key, _, _ in self.proto:
            self.h5file.create_dataset(
                f'{key}{self.chunk_id_sep}{chunk_id}',
                data=self.buffers[key][:self.buffer_idx],
                compression='gzip' if self.compression else None,
                shuffle=self.compression)
        self.buffer_idx = 0
        self.h5file.flush()

    def close(self):
        self.flush()
        self.h5file.close()

    def __exit__(self, _type, value, traceback):
        self.close()

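
# A minimal usage sketch for the writer (the file name and proto below are
# illustrative, not part of the API):
#
#     proto = [('obs', np.float32, (84, 84)), ('reward', np.float32, None)]
#     with ChunkedDataWriter('data.h5', proto, chunk_size=128) as writer:
#         writer.write(np.zeros((84, 84), np.float32), 1.0)
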

class ChunkedDataReader:
    """Read back datasets written chunk by chunk with ChunkedDataWriter."""

    def __init__(self, filename, chunk_id_sep=':'):
        self.h5file = h5py.File(filename, 'r')
        self.chunk_id_sep = chunk_id_sep
        self.keys = list(self.h5file.keys())
        self.key_names = {key.split(self.chunk_id_sep)[0] for key in self.keys}

    def __enter__(self):
        return self

    def get_id(self, key):
        """Return the integer chunk id encoded in a dataset name."""
        return int(key.split(self.chunk_id_sep)[1])

    def get(self, key, min_chunk=None, max_chunk=None):
        """Concatenate all chunks of `key` with min_chunk <= id < max_chunk."""
        keys = filter(lambda x: x.startswith(key + self.chunk_id_sep),
                      self.keys)
        # Compare against None so that a bound of 0 is not silently ignored.
        if min_chunk is not None:
            keys = filter(lambda x: self.get_id(x) >= min_chunk, keys)
        if max_chunk is not None:
            keys = filter(lambda x: self.get_id(x) < max_chunk, keys)
        keys = sorted(keys, key=self.get_id)
        if not keys:
            raise KeyError(
                f'{key}, min_chunk={min_chunk}, max_chunk={max_chunk}')
        return np.concatenate([self.h5file[k] for k in keys])

    def __getitem__(self, key):
        if key in self.keys:
            # An exact match is a dataset written without chunking.
            return np.asarray(self.h5file[key])
        return self.get(key)

    def close(self):
        self.h5file.close()

    def __exit__(self, _type, value, traceback):
        self.close()
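

# A minimal round-trip sketch (the file name 'demo.h5' and the proto are
# illustrative): write ten samples through 4-element chunks, then read them
# back as one array.
if __name__ == '__main__':
    proto = [('x', np.float32, (3,)), ('y', np.int64, None)]
    with ChunkedDataWriter('demo.h5', proto, chunk_size=4) as writer:
        writer.write_batch(np.ones((10, 3), np.float32),
                           np.arange(10, dtype=np.int64))
    with ChunkedDataReader('demo.h5') as reader:
        print(reader['x'].shape)                          # (10, 3)
        print(reader.get('y', min_chunk=0, max_chunk=1))  # [0 1 2 3]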