-
Notifications
You must be signed in to change notification settings - Fork 53
/
image_iter.py
91 lines (81 loc) · 3 KB
/
image_iter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/env python
# encoding: utf-8
'''
@author: yaoyaozhong
@contact: zhongyaoyao@bupt.edu.cn
@file: image_iter_yy.py
@time: 2020/06/03
@desc: training dataset loader for .rec
'''
import torchvision.transforms as transforms
import torch.utils.data as data
import numpy as np
import cv2
import os
import torch
import mxnet as mx
from mxnet import ndarray as nd
from mxnet import io
from mxnet import recordio
import logging
import numbers
import random
logger = logging.getLogger()
from IPython import embed
class FaceDataset(data.Dataset):
    """Map-style dataset that serves images from an MXNet indexed RecordIO file.

    The index file is expected to sit next to the record file, with the same
    stem and an ``.idx`` extension (``train.rec`` -> ``train.idx``).

    Record 0 is read as a header record. When its ``flag`` is positive, its
    label is read as a ``[start, end)`` range of per-identity records, and each
    of those records' labels is in turn read as the ``[start, end)`` range of
    image record indices belonging to that identity. Otherwise every key in the
    index is treated as an image record.

    Attributes set by ``__init__``:
        rand_mirror: whether ``__getitem__`` randomly flips images.
        imgrec: the opened ``MXIndexedRecordIO`` reader.
        imgidx / seq: list of image record indices; ``seq`` is the order used
            for positional lookup and is the same list object as ``imgidx``.
        header0, id2range, seq_identity: only set when the header flag is
            positive (identity range, per-identity image ranges, identity ids).
    """

    def __init__(self, path_imgrec, rand_mirror):
        """Open the record file and build the list of image record indices.

        Args:
            path_imgrec: path to the ``.rec`` file.
            rand_mirror: if truthy, each fetched image is flipped along axis 1
                with probability 1/2.

        Raises:
            ValueError: if ``path_imgrec`` is empty or ``None``.
        """
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if not path_imgrec:
            raise ValueError("path_imgrec must be a non-empty path to a .rec file")
        self.rand_mirror = rand_mirror

        logging.info('loading recordio %s...', path_imgrec)
        # Swap the extension rather than chopping a fixed number of characters.
        path_imgidx = os.path.splitext(path_imgrec)[0] + ".idx"
        print(path_imgrec, path_imgidx)
        self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')

        header, _ = recordio.unpack(self.imgrec.read_idx(0))
        if header.flag > 0:
            print('header0 label', header.label)
            first_identity = int(header.label[0])
            last_identity = int(header.label[1])
            self.header0 = (first_identity, last_identity)
            self.imgidx = []
            self.id2range = {}
            self.seq_identity = range(first_identity, last_identity)
            for identity in self.seq_identity:
                id_header, _ = recordio.unpack(self.imgrec.read_idx(identity))
                start, stop = int(id_header.label[0]), int(id_header.label[1])
                self.id2range[identity] = (start, stop)
                self.imgidx.extend(range(start, stop))
            print('id2range', len(self.id2range))
        else:
            self.imgidx = list(self.imgrec.keys)

        self.seq = self.imgidx

    def __getitem__(self, index):
        """Return ``(img, label)`` for the ``index``-th entry of ``self.seq``.

        ``img`` is a ``torch`` tensor built from the decoded image after an
        axis permutation of ``(2, 0, 1)`` (presumably HWC -> CHW; confirm
        against ``mx.image.imdecode``). ``label`` is the record's header label,
        or its first element when the label is not a plain number.
        """
        record = self.imgrec.read_idx(self.seq[index])
        header, payload = recordio.unpack(record)

        label = header.label
        if not isinstance(label, numbers.Number):
            label = label[0]

        image = mx.image.imdecode(payload)
        # Coin flip is only drawn when mirroring is enabled.
        if self.rand_mirror and random.randint(0, 1) == 1:
            image = mx.ndarray.flip(data=image, axis=1)

        image = nd.transpose(image, axes=(2, 0, 1))
        img = torch.from_numpy(image.asnumpy())
        return img, label

    def __len__(self):
        """Return the number of image records available."""
        return len(self.seq)
if __name__ == '__main__':
    # Manual smoke test: load a record file and print the shape and labels of
    # every batch produced by a shuffled DataLoader.
    root = '/raid/Data/faces_webface_112x112/train.rec'
    # NOTE(review): embed() opens an interactive IPython shell and blocks until
    # it is exited; this looks like a debugging leftover - confirm and remove.
    embed()
    dataset = FaceDataset(path_imgrec=root, rand_mirror=False)
    trainloader = data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2, drop_last=False)
    print(len(dataset))
    # The batch variable is deliberately not called `data`: that name is the
    # module-level alias for torch.utils.data and must not be rebound here.
    for images, label in trainloader:
        print(images.shape, label)