forked from pudumagico/nsvqasp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
69 lines (58 loc) · 2.53 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
import h5py
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import tools as utils
class ClevrQuestionDataset(Dataset):
def __init__(self, question_h5_path, max_samples, vocab_json):
self.max_samples = max_samples
question_h5 = h5py.File(question_h5_path, 'r')
self.questions = torch.LongTensor(np.asarray(question_h5['questions'], dtype=np.int64))
self.image_idxs = np.asarray(question_h5['image_idxs'], dtype=np.int64)
self.programs, self.answers = None, None
if 'programs' in question_h5:
self.programs = torch.LongTensor(np.asarray(question_h5['programs'], dtype=np.int64))
if 'answers' in question_h5:
self.answers = np.asarray(question_h5['answers'], dtype=np.int64)
self.vocab = utils.load_vocab(vocab_json)
self.img_filenames = question_h5['img_filenames']
def __len__(self):
if self.max_samples:
return min(self.max_samples, len(self.questions))
else:
return len(self.questions)
def __getitem__(self, idx):
if idx >= len(self):
raise ValueError('index %d out of range (%d)' % (idx, len(self)))
question = self.questions[idx]
image_idx = self.image_idxs[idx]
img_filename = self.img_filenames[idx]
program = -1
answer = -1
if self.programs is not None:
program = self.programs[idx]
if self.answers is not None:
answer = self.answers[idx]
return question, program, answer, image_idx, img_filename.decode()
def get_dataset(opt, split):
"""Get function for dataset class"""
assert split in ['train', 'val']
if opt.dataset == 'clevr':
if split == 'train':
question_h5_path = opt.clevr_train_question_path
max_sample = opt.max_train_samples
else:
question_h5_path = opt.clevr_val_question_path
max_sample = opt.max_val_samples
dataset = ClevrQuestionDataset(question_h5_path, max_sample, opt.clevr_vocab_path)
else:
raise ValueError('Invalid dataset')
return dataset
def get_dataloader(opt, split):
"""Get function for dataloader class"""
dataset = get_dataset(opt, split)
# shuffle = opt.shuffle if split == 'train' else 0
loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True, num_workers=opt.num_workers)
print('| %s %s loader has %d samples' % (opt.dataset, split, len(loader.dataset)))
return loader