-
Notifications
You must be signed in to change notification settings - Fork 59
/
Copy pathhmdb_dataset.py
113 lines (87 loc) · 3.48 KB
/
hmdb_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import torch.utils.data as data_utl
import numpy as np
import random
import os
import lintel
def video_to_tensor(pic):
    """Move the channel axis of a video array to the front and wrap it as a tensor.

    Args:
        pic (numpy.ndarray): Video laid out as (T x H x W x C).

    Returns:
        Tensor: The converted video with axes ordered (C x T x H x W).
    """
    channels_first = pic.transpose(3, 0, 1, 2)
    return torch.from_numpy(channels_first)
class HMDB(data_utl.Dataset):
    """Video-clip dataset read from a split file and decoded with lintel.

    Every line of ``split_file`` longer than five characters must hold
    ``"<video> <class>"`` separated by a single space. The video name is
    rewritten to ``<mode>_<name up to the first dot>.mp4`` and joined onto
    ``root``.

    Args:
        split_file: Path of the text file listing one video/class pair per line.
        root: Directory the rewritten video file names are joined onto.
        mode: File-name prefix. ``'flow'`` also makes ``__getitem__`` drop the
            first channel of every frame.
        length: Frames per clip. Twice as many are decoded and every second
            one is kept.
        model: ``'2d'`` selects the (T x C x H x W) output layout (and the
            stacked-flow layout in flow mode); any other value selects
            (C x T x H x W).
        random: When true, decoding seeks randomly and the crop position is
            random; otherwise a centre crop is taken.
        c2i: Optional class-name -> id mapping to reuse. It is extended in
            place with classes it does not yet contain. When omitted, a new
            mapping is created for this instance.
    """

    def __init__(self, split_file, root, mode='rgb', length=16, model='2d', random=True, c2i=None):
        # A fresh dict per instance: a shared ``{}`` default would be mutated
        # below and leak classes from one dataset into the next.
        if c2i is None:
            c2i = {}
        self.class_to_id = c2i

        # Class names ordered by id, considering only ids in 0..len(c2i)-1.
        # The sort is stable, so names sharing an id keep their dict order.
        valid_ids = range(len(c2i))
        self.id_to_class = sorted(
            (name for name, idx in c2i.items() if idx in valid_ids),
            key=c2i.get,
        )

        # Continue numbering after the ids already in use so that a class
        # missing from a supplied mapping never collides with an existing id.
        next_id = max(c2i.values(), default=-1) + 1

        self.data = []
        self.model = model
        self.size = 112  # side length of the square spatial crop

        with open(split_file, 'r') as f:
            for line in f:
                # Skip blank / too-short lines.
                if len(line) <= 5:
                    continue
                video, label = line.strip().split(' ')
                video = mode + '_' + video.split('.')[0] + '.mp4'
                if label not in self.class_to_id:
                    self.class_to_id[label] = next_id
                    self.id_to_class.append(label)
                    next_id += 1
                self.data.append([os.path.join(root, video), self.class_to_id[label]])

        self.split_file = split_file
        self.root = root
        self.mode = mode
        self.length = length
        self.random = random

    def __getitem__(self, index):
        """Decode, subsample, crop and normalise one clip.

        Returns:
            tuple: ``(clip, class_id)`` where ``clip`` is a float32
            ``numpy.ndarray`` scaled as ``1 - 2 * (value / 255)``. Its layout
            is (T x C x H x W) for ``model == '2d'`` (3 x 20 x H x W for 2d
            flow) and (C x T x H x W) otherwise.

        Raises:
            ValueError: If the subsampled frames are smaller than the crop.
        """
        vid, cls = self.data[index]
        with open(vid, 'rb') as f:
            enc_vid = f.read()

        n_decoded = self.length * 2
        buf, full_w, full_h, _ = lintel.loadvid(
            enc_vid, should_random_seek=self.random, num_frames=n_decoded)
        frames = np.frombuffer(buf, dtype=np.uint8)
        # Reshape with the decoded dimensions (so odd sizes work too), then
        # keep every second frame, row and column.
        frames = frames.reshape(n_decoded, full_h, full_w, 3)[::2, ::2, ::2, :]
        h, w = frames.shape[1:3]

        size = self.size
        if h < size or w < size:
            raise ValueError(
                'frames of %dx%d are smaller than the %dx%d crop' % (h, w, size, size))

        if not self.random:
            # center crop
            top = int(round((h - size) / 2.))
            left = int(round((w - size) / 2.))
        else:
            top = random.randint(0, h - size) if h != size else 0
            left = random.randint(0, w - size) if w != size else 0
        # Slice by explicit end offsets: ``top:-top`` would be empty for a
        # zero offset and the wrong size for an odd margin.
        frames = frames[:, top:top + size, left:left + size, :]

        if self.mode == 'flow':
            # only take the 2 channels corresponding to flow (x,y)
            frames = frames[:, :, :, 1:]
            if self.model == '2d':
                # this should be redone...
                # stack 10 along channel axis
                # NOTE: needs at least 14 frames, i.e. length >= 14.
                stacks = np.asarray([frames[:10], frames[2:12], frames[4:14]])  # gives 3x10xHxWx2
                frames = (stacks.transpose(0, 1, 4, 2, 3)
                          .reshape(3, 20, size, size)
                          .transpose(0, 2, 3, 1))

        frames = 1 - 2 * (frames.astype(np.float32) / 255)

        if self.model == '2d':
            # 2d -> return TxCxHxW
            return frames.transpose([0, 3, 1, 2]), cls
        # 3d -> return CxTxHxW
        return frames.transpose([3, 0, 1, 2]), cls

    def __len__(self):
        """Return the number of video/class pairs read from the split file."""
        return len(self.data)
if __name__ == '__main__':
    # Smoke test: build the flow training split and the rgb test split (the
    # latter reusing the training class mapping), then print the shape of
    # every training sample.
    train_set = HMDB('data/hmdb/split1_train.txt', '/ssd/hmdb/',
                     model='2d', mode='flow', length=16)
    test_set = HMDB('data/hmdb/split1_test.txt', '/ssd/hmdb/',
                    model='2d', mode='rgb', length=16,
                    c2i=train_set.class_to_id)
    for idx in range(len(train_set)):
        clip, _label = train_set[idx]
        print(clip.shape)