-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_loader.py
346 lines (275 loc) · 13.6 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
import glob
import imp
import os
import random
import h5py
import numpy as np
import torch.utils.data as data
import itertools
from spirl.utils.general_utils import AttrDict, map_dict, maybe_retrieve, shuffle_with_seed
from spirl.utils.pytorch_utils import RepeatedDataLoader
from spirl.utils.video_utils import resize_video
class Dataset(data.Dataset):
    """Base dataset that indexes the HDF5 files of one phase and wraps itself in a data loader.

    Sub-classes need to implement `__getitem__` and `_get_samples_per_file`.
    """

    def __init__(self, data_dir, data_conf, phase, shuffle=True, dataset_size=-1):
        """Collects and filters the list of data files for the given phase.

        :param data_dir: root directory of the dataset
        :param data_conf: config object, `dataset_spec` and `device` are read from it
        :param phase: name of the split, shuffling is only ever enabled for 'train'
        :param shuffle: whether samples should be shuffled (only honored in the 'train' phase)
        :param dataset_size: fixed length reported by `__len__`, -1 derives the length from the files
        """
        self.phase = phase
        self.data_dir = data_dir
        self.spec = data_conf.dataset_spec
        self.dataset_size = dataset_size
        self.device = data_conf.device

        print('loading files from', self.data_dir)
        self.filenames = self._get_filenames()
        self.filenames = self._filter_filenames(self.filenames)
        self.samples_per_file = self._get_samples_per_file(self.filenames[0])

        self.shuffle = shuffle and phase == 'train'
        # note: the worker count follows the raw `shuffle` argument, not `self.shuffle`
        self.n_worker = 8 if shuffle else 1  # was 4 before

    def get_data_loader(self, batch_size, n_repeat):
        """Builds a RepeatedDataLoader over this dataset.

        :param batch_size: number of samples per batch (incomplete last batches are dropped)
        :param n_repeat: forwarded to RepeatedDataLoader
        """
        print('len {} dataset {}'.format(self.phase, len(self)))
        assert self.device in ['cuda', 'cpu']  # Otherwise the logic below is wrong
        # every worker re-seeds numpy with a value offset by its worker id
        return RepeatedDataLoader(self, batch_size=batch_size, shuffle=self.shuffle, num_workers=self.n_worker,
                                  drop_last=True, n_repeat=n_repeat, pin_memory=self.device == 'cuda',
                                  worker_init_fn=lambda x: np.random.seed(np.random.randint(65536) + x))

    def __getitem__(self, index):
        """Load a single sequence from disk according to index."""
        raise NotImplementedError("Needs to be implemented in sub-class!")

    def _get_samples_per_file(self, path):
        """Returns number of data samples per data file."""
        raise NotImplementedError("Needs to be implemented in sub-class!")

    def _get_filenames(self):
        """Loads filenames from self.data_dir, expects subfolders train/val/test, each with hdf5 files"""
        filenames = sorted(glob.glob(os.path.join(self.data_dir, self.phase + '/*.h5')))
        if not filenames:
            raise RuntimeError('No filenames found in {}'.format(self.data_dir))
        filenames = shuffle_with_seed(filenames)
        return filenames

    def _filter_filenames(self, filenames):
        """Optionally filters filenames / limits to max number of filenames etc."""
        if "n_seqs" in self.spec:
            # limit the max number of sequences in dataset
            if self.phase == "train" and len(filenames) < self.spec.n_seqs:
                raise ValueError("Not enough seqs in dataset!")
            filenames = filenames[:self.spec.n_seqs]

        if "seq_repeat" in self.spec:
            # repeat sequences in dataset; the size-driven factor is rounded UP so that the
            # repeated list is never shorter than the dataset_size that __len__ reports
            size_repeat = int(np.ceil(self.dataset_size / len(filenames)))
            repeat = max(int(self.spec.seq_repeat), size_repeat)
            filenames = filenames * repeat
            filenames = shuffle_with_seed(filenames)
        return filenames

    def __len__(self):
        """Returns the fixed dataset_size if one was given, else files times samples per file."""
        if self.dataset_size != -1:
            return self.dataset_size
        return len(self.filenames) * self.samples_per_file
class GlobalSplitDataset(Dataset):
    """Splits in train/val/test using global percentages."""

    def _get_filenames(self):
        """Collects all .h5 files below data_dir and keeps the slice that belongs to this phase."""
        filenames = self._load_h5_files(self.data_dir)
        if not filenames:
            raise RuntimeError('No filenames found in {}'.format(self.data_dir))
        filenames = shuffle_with_seed(filenames)
        filenames = self._split_with_percentage(self.spec.split, filenames)
        return filenames

    def _load_h5_files(self, dir):
        """Recursively gathers the paths of all files ending in '.h5' below `dir`, sorted by path."""
        filenames = []
        for root, _, files in os.walk(dir):
            for file in files:
                if file.endswith(".h5"):
                    filenames.append(os.path.join(root, file))
        # os.walk yields entries in arbitrary order; sort so that the subsequent shuffle and
        # percentage split pick the same files for each phase on every machine / run
        return sorted(filenames)

    def _split_with_percentage(self, frac, filenames):
        """Returns the contiguous slice of `filenames` assigned to the current phase.

        :param frac: mapping from phase name to its fraction of the data; 'train' comes first,
                     then 'val', then 'test'
        :param filenames: full list of files to split
        """
        # fractions cannot sum up to more than 1 (small tolerance for float rounding)
        assert sum(frac.values()) <= 1.0 + 1e-8
        assert self.phase in frac
        if self.phase == 'train':
            start, end = 0, frac['train']
        elif self.phase == 'val':
            start, end = frac['train'], frac['train'] + frac['val']
        else:
            start, end = frac['train'] + frac['val'], frac['train'] + frac['val'] + frac['test']
        # convert fractions into list indices
        start, end = int(len(filenames) * start), int(len(filenames) * end)
        return filenames[start:end]
class VideoDataset(Dataset):
    """Generic video dataset. Assumes that HDF5 file has images/states/actions/pad_mask."""

    def __init__(self, *args, resolution, **kwargs):
        """Reads the sequence-shaping options from the dataset spec.

        :param resolution: target image size, passed as (resolution, resolution) to resize_video
        """
        super().__init__(*args, **kwargs)
        self.randomize_length = self.spec.randomize_length if 'randomize_length' in self.spec else False
        self.crop_subseq = 'crop_rand_subseq' in self.spec and self.spec.crop_rand_subseq
        self.img_sz = resolution
        self.subsampler = self._get_subsampler()

    def __getitem__(self, index):
        """Loads sequence `index`, brings it to a fixed length and preprocesses its images."""
        data = self._get_raw_data(index)

        # maybe subsample seqs
        if self.subsampler is not None:
            data = self._subsample_data(data)

        # sample random subsequence of fixed length
        if self.crop_subseq:
            # index of the last step whose pad_mask entry is set
            end_ind = np.argmax(data.pad_mask * np.arange(data.pad_mask.shape[0], dtype=np.float32), 0)
            data = self._crop_rand_subseq(data, end_ind, length=self.spec.subseq_len)

        # Make length consistent
        start_ind = 0
        end_ind = np.argmax(data.pad_mask * np.arange(data.pad_mask.shape[0], dtype=np.float32), 0) \
            if self.randomize_length or self.crop_subseq else self.spec.max_seq_len - 1
        end_ind, data = self._sample_max_len_video(data, end_ind, target_len=self.spec.subseq_len if self.crop_subseq
                                                   else self.spec.max_seq_len)

        if self.randomize_length:
            end_ind = self._randomize_length(start_ind, end_ind, data)
            data.start_ind, data.end_ind = start_ind, end_ind

        # perform final processing on data
        data.images = self._preprocess_images(data.images)

        return data

    def _get_raw_data(self, index):
        """Reads one trajectory from its HDF5 file.

        :param index: global sample index; selects the file and the 'traj<k>' group inside it
        :return: AttrDict with the float32 arrays found among states/actions/pad_mask plus `images`
        :raises ValueError: if anything goes wrong while reading the file
        """
        data = AttrDict()
        file_index = index // self.samples_per_file
        path = self.filenames[file_index]

        try:
            with h5py.File(path, 'r') as F:
                ex_index = index % self.samples_per_file  # get the index
                key = 'traj{}'.format(ex_index)

                # Fetch data into a dict
                for name in F[key].keys():
                    if name in ['states', 'actions', 'pad_mask']:
                        data[name] = F[key + '/' + name][()].astype(np.float32)

                if key + '/images' in F:
                    data.images = F[key + '/images'][()]
                else:
                    # no images stored: use tiny all-zero placeholder frames, one per state
                    data.images = np.zeros((data.states.shape[0], 2, 2, 3), dtype=np.uint8)
        except Exception as err:
            # keep the original error as the cause; do not intercept KeyboardInterrupt / SystemExit
            raise ValueError("Could not load from file {}".format(path)) from err
        return data

    def _get_samples_per_file(self, path):
        """Reads the 'traj_per_file' entry of the given HDF5 file."""
        with h5py.File(path, 'r') as F:
            return F['traj_per_file'][()]

    def _get_subsampler(self):
        """Instantiates the subsampler configured in the spec, returns None if there is none."""
        subsampler_class = maybe_retrieve(self.spec, 'subsampler')
        if subsampler_class is not None:
            subsample_args = maybe_retrieve(self.spec, 'subsample_args')
            assert subsample_args is not None  # need to specify subsampler args dict
            subsampler = subsampler_class(**subsample_args)
        else:
            subsampler = None
        return subsampler

    def _subsample_data(self, data_dict):
        """Runs the subsampler over every entry (in place), handing the returned idxs to the next call."""
        idxs = None
        for key in data_dict:
            data_dict[key], idxs = self.subsampler(data_dict[key], idxs=idxs)
        return data_dict

    def _crop_rand_subseq(self, data, end_ind, length):
        """Crops a random subseq of specified length from the full sequence."""
        assert length <= end_ind + 1  # sequence needs to be longer than desired subsequence length
        start = np.random.randint(0, end_ind - length + 2)
        for key in data:
            data[key] = data[key][start: int(start + length)]
        return data

    def _sample_max_len_video(self, data_dict, end_ind, target_len):
        """ This function processes data tensors so as to have length equal to target_len
        by sampling / padding if necessary """
        extra_length = (end_ind + 1) - target_len
        if self.phase == 'train':
            # random start offset within the surplus frames (0 if there is no surplus)
            offset = max(0, int(np.random.rand() * (extra_length + 1)))
        else:
            offset = 0

        data_dict = map_dict(lambda tensor: self._maybe_pad(tensor, offset, target_len), data_dict)
        if 'actions' in data_dict:
            # actions end up one step shorter than the other entries
            data_dict.actions = data_dict.actions[:-1]
        end_ind = min(end_ind - offset, target_len - 1)

        return end_ind, data_dict

    @staticmethod
    def _maybe_pad(val, offset, target_length):
        """Pads / crops sequence to desired length."""
        val = val[offset:]
        cur_len = val.shape[0]
        if cur_len > target_length:
            return val[:target_length]
        elif cur_len < target_length:
            # append zeros of the same dtype and trailing shape
            padding = np.zeros([int(target_length - cur_len)] + list(val.shape[1:]), dtype=val.dtype)
            return np.concatenate((val, padding))
        else:
            return val

    def _randomize_length(self, start_ind, end_ind, data_dict):
        """ This function randomly shortens the input tensors (in place) and returns the new end index.

        The new length is 3 + int(u * (end_ind - 2)) with u drawn from [0, 1), i.e. it lies between
        3 and end_ind frames when end_ind >= 3. A random number of leading frames is chopped off and
        everything behind the new end index is set to zero; pad_mask is rebuilt accordingly. """
        length = 3 + int(np.random.rand() * (end_ind - 2))  # between 3 and end_ind frames
        chop_length = int(np.random.rand() * (end_ind + 1 - length))  # from 0 to the reminder
        end_ind = length - 1
        # NOTE(review): the mask is always max_seq_len long, also when the sequences were cut
        # to subseq_len beforehand - confirm that this combination is never used
        pad_mask = np.logical_and((np.arange(self.spec['max_seq_len']) <= end_ind),
                                  (np.arange(self.spec['max_seq_len']) >= start_ind)).astype(np.float32)

        # Chop off the beginning of the arrays
        def pad(array):
            # shift left by chop_length, refill the tail with copies of the last frame to keep the length
            array = np.concatenate([array[chop_length:], np.repeat(array[-1:], chop_length, 0)], 0)
            array[end_ind + 1:] = 0
            return array

        for key in filter(lambda key: key != 'pad_mask', data_dict):
            data_dict[key] = pad(data_dict[key])
        data_dict.pad_mask = pad_mask

        return end_ind

    def _preprocess_images(self, images):
        """Resizes uint8 frames, moves channels to axis 1 and rescales values to float32 in [-1, 1]."""
        assert images.dtype == np.uint8, 'image need to be uint8!'
        images = resize_video(images, (self.img_sz, self.img_sz))
        images = np.transpose(images, [0, 3, 1, 2])  # convert to channel-first
        images = images.astype(np.float32) / 255 * 2 - 1
        assert images.dtype == np.float32, 'image need to be float32!'
        return images
class PreloadVideoDataset(VideoDataset):
    """Loads all sequences into memory for accelerated training (only possible for small datasets)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._data = self._load_data()

    def _load_data(self):
        """Load all sequences into memory."""
        print("Preloading all sequences from {}".format(self.data_dir))
        # one entry per sample (not per file), so every index that the file-based
        # lookup can resolve is also available in the cache
        n_samples = int(len(self.filenames) * self.samples_per_file)
        return [super(PreloadVideoDataset, self)._get_raw_data(i) for i in range(n_samples)]

    def _get_raw_data(self, index):
        """Returns a fresh dict holding the preloaded arrays of sample `index`.

        The processing in `__getitem__` re-assigns entries of the dict it receives (subsampling,
        random cropping). Handing out the cached dict itself would therefore permanently alter the
        cache after the first access, so the entries are copied into a new AttrDict (the arrays
        themselves are shared, not copied).
        """
        cached = self._data[index]
        sample = AttrDict()
        for key in cached:
            sample[key] = cached[key]
        return sample
class GlobalSplitVideoDataset(VideoDataset, GlobalSplitDataset):
    """Video dataset whose file list comes from the global percentage split of GlobalSplitDataset."""
class PreloadGlobalSplitVideoDataset(PreloadVideoDataset, GlobalSplitDataset):
    """Preloading video dataset whose file list comes from the global percentage split."""
class GlobalSplitStateSequenceDataset(GlobalSplitVideoDataset):
    """Serves the state sequence of every sample under the key `observations`."""

    def __getitem__(self, item):
        """Fetches the sample and moves its 'states' entry to `observations`."""
        sample = super().__getitem__(item)
        states = sample.pop('states')
        sample.observations = states
        return sample
class GlobalSplitActionSequenceDataset(GlobalSplitVideoDataset):
    """Serves the action sequence of every sample under the key `observations`."""

    def __getitem__(self, item):
        """Fetches the sample and moves its 'actions' entry to `observations`."""
        sample = super().__getitem__(item)
        actions = sample.pop('actions')
        sample.observations = actions
        return sample
class MixedVideoDataset(GlobalSplitVideoDataset):
    """Loads filenames from multiple directories and merges them with percentage."""

    def _load_h5_files(self, unused_dir):
        """Gathers files from every directory in spec.data_dirs, mixed according to spec.percentages.

        :param unused_dir: ignored, the directories are taken from the dataset spec instead
        :return: flat list of filenames, grouped by source directory
        """
        assert 'data_dirs' in self.spec and 'percentages' in self.spec
        # compare with a tolerance: e.g. 0.7 + 0.2 + 0.1 is not exactly 1 in floating point
        assert np.isclose(np.sum(self.spec.percentages), 1)
        files = [super(MixedVideoDataset, self)._load_h5_files(data_dir) for data_dir in self.spec.data_dirs]
        files = [shuffle_with_seed(f) for f in files]
        # largest total size for which every directory can still supply its share
        total_size = min([1 / p * len(f) for p, f in zip(self.spec.percentages, files)])
        filenames = list(itertools.chain.from_iterable(
            [f[:int(total_size * p)] for p, f in zip(self.spec.percentages, files)]))
        return filenames
class GeneratedVideoDataset(VideoDataset):
    """Dataset whose samples are produced by `get_sample` instead of being read from files."""

    def __len__(self):
        """Returns dataset_size if it was set, otherwise 10000 for 'train' and 200 for other phases."""
        if self.dataset_size == -1:
            return 10000 if self.phase == 'train' else 200
        return self.dataset_size

    def _get_filenames(self):
        """Returns a single placeholder entry since no files are read."""
        return [None]

    def _get_samples_per_file(self, path):
        """No files are involved, so there is no per-file sample count."""
        return None

    def get_sample(self):
        raise NotImplementedError("Needs to be implemented by child class.")

    def __getitem__(self, index):
        """Generates one sample; without shuffling, numpy is seeded with `index` first."""
        if not self.shuffle:
            # Seeding with the index keeps validation samples deterministic
            np.random.seed(index)
        return self.get_sample()

    @staticmethod
    def visualize(*args, **kwargs):
        """Enables dataset-specific visualization."""
        return None
class RandomVideoDataset(GeneratedVideoDataset):
    """Generated dataset that fills images, states and actions with uniform random float32 values."""

    def get_sample(self):
        """Draws one random sequence; actions are one step shorter than images and states."""
        seq_len = self.spec['max_seq_len']
        image_shape = (seq_len, 3, self.img_sz, self.img_sz)

        sample = AttrDict()
        # the draw order (images, states, actions) determines the values for a given seed
        sample.images = np.random.rand(*image_shape).astype(np.float32)
        sample.states = np.random.rand(seq_len, self.spec['state_dim']).astype(np.float32)
        sample.actions = np.random.rand(seq_len - 1, self.spec['n_actions']).astype(np.float32)
        return sample