-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_samplers.py
73 lines (59 loc) · 2.58 KB
/
data_samplers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import itertools
import numpy as np
from torch.utils.data.sampler import Sampler
# Sentinel target written by relabel_dataset() for images whose filename has
# no entry in the supplied labels mapping.
NO_LABEL = -1
class TwoStreamBatchSampler(Sampler):
    """Batch sampler that fills every batch from two separate index pools.

    Each batch holds ``primary_batch_size`` indices drawn from the primary
    pool followed by ``secondary_batch_size`` indices drawn from the
    secondary pool. One pass over the sampler (an "epoch") is a single
    shuffled pass over the primary indices; the secondary indices are
    reshuffled and reused as often as needed to keep up.
    """

    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        # The primary stream gets whatever part of the batch the secondary
        # stream does not occupy.
        self.primary_batch_size = batch_size - secondary_batch_size
        self.secondary_batch_size = secondary_batch_size
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices

        # Each pool must be able to fill its share of at least one batch.
        assert (len(self.primary_indices) >= self.primary_batch_size
                and self.primary_batch_size > 0)
        assert (len(self.secondary_indices) >= self.secondary_batch_size
                and self.secondary_batch_size > 0)

    def __iter__(self):
        # Primary indices are shuffled once per epoch; the secondary stream
        # never runs dry, so the zip below ends with the primary chunks.
        primary_chunks = grouper(
            iterate_once(self.primary_indices), self.primary_batch_size
        )
        secondary_chunks = grouper(
            iterate_eternally(self.secondary_indices), self.secondary_batch_size
        )
        return (
            head + tail
            for head, tail in zip(primary_chunks, secondary_chunks)
        )

    def __len__(self):
        # Only complete primary chunks produce a batch.
        full_batches, _ = divmod(len(self.primary_indices), self.primary_batch_size)
        return full_batches
def relabel_dataset(dataset, labels):
    """Rewrite the targets in ``dataset.imgs`` from a filename -> class mapping.

    Every ``(path, target)`` entry of ``dataset.imgs`` is replaced in place.
    If the base filename of ``path`` is a key of ``labels``, the new target
    is ``dataset.class_to_idx[labels[filename]]``; otherwise it is
    ``NO_LABEL``. A filename is matched at most once: if several images
    share a base filename, only the first one receives the label.

    Args:
        dataset: Object exposing ``imgs`` (a mutable sequence of
            ``(path, target)`` pairs) and ``class_to_idx`` (class name ->
            class index). ``imgs`` is modified in place.
        labels: Mapping from image base filename to class name. It is not
            modified.

    Returns:
        A ``(labeled_idxs, unlabeled_idxs)`` tuple of lists holding, in
        ascending order, the positions in ``dataset.imgs`` that received a
        real label and those that received ``NO_LABEL``.

    Raises:
        LookupError: If ``labels`` names files that match no image in
            ``dataset.imgs``.
    """
    # Work on a private copy so the caller's mapping survives the call;
    # matched entries are removed so that leftovers reveal unknown files.
    pending = dict(labels)
    labeled_idxs = []
    unlabeled_idxs = []

    for idx, (path, _) in enumerate(dataset.imgs):
        filename = os.path.basename(path)
        if filename in pending:
            label_idx = dataset.class_to_idx[pending.pop(filename)]
            dataset.imgs[idx] = path, label_idx
            labeled_idxs.append(idx)
        else:
            dataset.imgs[idx] = path, NO_LABEL
            unlabeled_idxs.append(idx)

    if pending:
        message = "List of unlabeled contains {} unknown files: {}, ..."
        some_missing = ', '.join(list(pending)[:5])
        raise LookupError(message.format(len(pending), some_missing))

    return labeled_idxs, unlabeled_idxs
def iterate_once(indices):
    """Return a single random permutation of ``indices``.

    The shuffling is done by ``np.random.permutation``, so the result is a
    new NumPy array drawn from NumPy's global random state.
    """
    shuffled = np.random.permutation(indices)
    return shuffled
def iterate_eternally(indices):
    """Return an endless iterator over freshly shuffled passes of ``indices``.

    Each time one permutation is exhausted, a new one is drawn with
    ``np.random.permutation`` and iteration continues, so the stream never
    ends. Permutations are produced lazily, as the iterator is consumed.

    Args:
        indices: Non-empty collection of indices to cycle through.

    Returns:
        An ``itertools.chain`` iterator yielding the elements of successive
        random permutations of ``indices``.

    Raises:
        ValueError: If ``indices`` is empty.
    """
    # With nothing to yield, the chain below would keep pulling empty
    # permutations forever and the first next() call would never return.
    if np.size(indices) == 0:
        raise ValueError("iterate_eternally() requires at least one index")

    def infinite_shuffles():
        while True:
            yield np.random.permutation(indices)

    return itertools.chain.from_iterable(infinite_shuffles())
def grouper(iterable, n):
    """Collect data into fixed-length chunks, dropping an incomplete tail.

    Example: ``grouper('ABCDEFG', 3)`` yields ``('A', 'B', 'C')`` and
    ``('D', 'E', 'F')``; the leftover ``'G'`` is discarded.
    """
    # The same iterator object is repeated n times, so each tuple zip()
    # builds consumes n consecutive items from the input.
    shared = iter(iterable)
    slots = (shared,) * n
    # zip() stops as soon as one slot runs dry, which drops a short tail.
    return zip(*slots)