-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
101 lines (83 loc) · 2.8 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from collections import Counter

import torch
from nltk import word_tokenize
from torch.nn.utils.rnn import pad_sequence
from torchvision import datasets, transforms

import utils
def word_dict(coco):
    """Build a vocabulary from every caption in *coco*.

    Iterates the whole dataset once, tokenizing each caption with
    ``utils.token`` and counting token frequencies.  Tokens seen at least
    twice are kept; three special markers are prepended.

    Args:
        coco: an indexable dataset whose items are ``(image, captions)``
            pairs, where ``captions`` is an iterable of strings
            (e.g. ``torchvision.datasets.CocoCaptions``).

    Returns:
        list[str]: ``['*begin', '*end', '*unk']`` followed by all tokens
        with frequency >= 2, in first-seen order.
    """
    counts = Counter()
    for i in range(len(coco)):
        # Progress indicator; '\r' keeps it on one terminal line.
        print(i, flush=True, end='\r')
        _, captions = coco[i]
        for cap in captions:
            counts.update(utils.token(cap))
    # Drop hapax legomena: tokens that occur only once are treated as noise.
    vocab = [tok for tok, n in counts.items() if n >= 2]
    return ['*begin', '*end', '*unk'] + vocab
def cap_to_idx(sen: str, words: dict) -> torch.Tensor:
    """Encode a caption string as a 1-D tensor of vocabulary indices.

    The caption is tokenized, lower-cased, wrapped in ``*begin``/``*end``
    markers, and mapped through *words*; tokens missing from *words* are
    silently dropped.

    Args:
        sen: caption text.
        words: token -> integer-index mapping (see ``load_vocab``).

    Returns:
        torch.Tensor: dtype ``long``, one entry per in-vocabulary token.
    """
    tokens = [w.lower() for w in word_tokenize(sen)]
    # NOTE(review): [:-1] unconditionally drops the final token — presumably
    # a trailing period, but a caption without one loses a real word. Confirm.
    tokens = ['*begin'] + tokens[:-1] + ['*end']
    enc = [words[w] for w in tokens if w in words]
    return torch.tensor(enc, dtype=torch.long)
class MyCoco(datasets.CocoCaptions):
    """``CocoCaptions`` variant that returns an encoded caption tensor.

    Instead of the full list of caption strings, ``__getitem__`` yields the
    first caption only, encoded through ``cap_to_idx``.
    """

    def __init__(self, word_dict, *args, **kwargs):
        """Store the token->index mapping; forward the rest to the base class."""
        super().__init__(*args, **kwargs)
        self.word_dict = word_dict

    def __getitem__(self, idx):
        """Return ``(image, encoded_caption)`` for sample *idx*.

        Only the first of the sample's captions is used.
        """
        image, captions = super().__getitem__(idx)
        encoded = cap_to_idx(captions[0], self.word_dict)
        return image, encoded
def load_vocab(fn='data/vocab.txt', to_del=None):
    """Load a vocabulary file (one token per line) into a token->index dict.

    Args:
        fn: path to the vocabulary file.
        to_del: set of tokens to skip; defaults to common punctuation and
            English articles.  Pass ``set()`` to keep everything.

    Returns:
        dict[str, int]: surviving tokens numbered in file order.
    """
    if to_del is None:
        # Default filter: punctuation and articles carry little signal.
        # (None sentinel avoids a mutable default argument.)
        to_del = {'.', '...', '-', ',', ';', ':', '!', '?', 'a', 'an', 'the'}
    vocab = []
    with open(fn, 'r') as f:
        for line in f:
            tok = line.rstrip('\n')
            if tok not in to_del:
                vocab.append(tok)
    return {w: i for i, w in enumerate(vocab)}
def collate_fn(data):
    """Collate ``(image, caption)`` pairs into a padded batch.

    Sorts *data* in place by descending caption length, stacks the images
    into a single tensor, and right-pads the caption tensors with 0 to a
    common length.

    Args:
        data: list of ``(image_tensor, caption_tensor)`` pairs.

    Returns:
        tuple: ``(images, padded_captions, caption_lengths)`` where
        ``images`` is the stacked image tensor, ``padded_captions`` is
        ``(batch, max_len)``, and ``caption_lengths`` is a plain list.
    """
    # Longest caption first — note this reorders the caller's list in place.
    data.sort(key=lambda pair: len(pair[1]), reverse=True)
    images, captions = zip(*data)

    stacked_images = torch.stack(images, 0)
    lengths = [len(cap) for cap in captions]
    padded = pad_sequence(captions, padding_value=0, batch_first=True)
    return stacked_images, padded, lengths
if __name__ == '__main__':
    root = './data/train2014'
    annot = './data/annotations/captions_train2014.json'
    coco = datasets.CocoCaptions(root, annot, transform=transforms.ToTensor())

    # Build the vocabulary over the whole training set and persist it,
    # one token per line.
    words = word_dict(coco)
    with open('./data/vocab.txt', 'w') as f:
        f.writelines(w + '\n' for w in words)

    # Reload it through the filtered path and report sizes.
    vocab = load_vocab('data/vocab.txt')
    print(len(vocab))
    print(len(coco))

    # Display a single sample image.  NOTE(review): `enc` here is the raw
    # caption list from CocoCaptions, not an encoding — name is misleading.
    im, enc = coco[5]
    im = transforms.ToPILImage()(im)
    im.show()

    # Touch every sample once (surfaces decode errors up front).
    for i in range(len(coco)):
        _ = coco[i]
        print(i, flush=True, end='\r')