forked from dasguptar/treelstm.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 2
/
dataset.py
86 lines (71 loc) · 2.79 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
from tqdm import tqdm
from copy import deepcopy
import torch
import torch.utils.data as data
import Constants
from tree import Tree
# Dataset class for SICK dataset
class SICKDataset(data.Dataset):
    """Sentence-pair dataset for SICK, loaded eagerly from a split directory.

    ``path`` must contain:
      * ``a.toks`` / ``b.toks``       - one whitespace-tokenized sentence per line
      * ``a.parents`` / ``b.parents`` - one parent-pointer tree per line
      * ``sim.txt``                   - one float label per line

    Each item is ``(ltree, lsent, rtree, rsent, label)``.
    """

    def __init__(self, path, vocab, num_classes):
        super(SICKDataset, self).__init__()
        self.vocab = vocab
        self.num_classes = num_classes
        self.lsentences = self.read_sentences(os.path.join(path, 'a.toks'))
        self.rsentences = self.read_sentences(os.path.join(path, 'b.toks'))
        self.ltrees = self.read_trees(os.path.join(path, 'a.parents'))
        self.rtrees = self.read_trees(os.path.join(path, 'b.parents'))
        self.labels = self.read_labels(os.path.join(path, 'sim.txt'))
        # The dataset length is taken from the label file; the other files
        # are not checked against it.
        self.size = self.labels.size(0)

    def __len__(self):
        """Return the number of examples (number of labels read)."""
        return self.size

    def __getitem__(self, index):
        """Return ``(ltree, lsent, rtree, rsent, label)`` for ``index``.

        Every element is a deep copy, so callers may mutate the returned
        trees/tensors without altering the cached dataset.
        """
        ltree = deepcopy(self.ltrees[index])
        rtree = deepcopy(self.rtrees[index])
        lsent = deepcopy(self.lsentences[index])
        rsent = deepcopy(self.rsentences[index])
        label = deepcopy(self.labels[index])
        return (ltree, lsent, rtree, rsent, label)

    def read_sentences(self, filename):
        """Read a token file into a list of ``LongTensor``s, one per line."""
        # Explicit UTF-8: the locale default encoding is platform-dependent.
        with open(filename, 'r', encoding='utf-8') as f:
            sentences = [self.read_sentence(line) for line in tqdm(f.readlines())]
        return sentences

    def read_sentence(self, line):
        """Convert one whitespace-tokenized line to a ``LongTensor`` of vocab indices."""
        # NOTE(review): presumably convertToIdx maps out-of-vocabulary tokens
        # to Constants.UNK_WORD's index - confirm against the Vocab class.
        indices = self.vocab.convertToIdx(line.split(), Constants.UNK_WORD)
        return torch.LongTensor(indices)

    def read_trees(self, filename):
        """Read a ``.parents`` file into a list of root ``Tree`` nodes, one per line."""
        with open(filename, 'r', encoding='utf-8') as f:
            trees = [self.read_tree(line) for line in tqdm(f.readlines())]
        return trees

    def read_tree(self, line):
        """Build a ``Tree`` from one line of a ``.parents`` file.

        The line holds whitespace-separated integers. Entry ``i`` (1-based)
        is the 1-based index of node ``i``'s parent, ``0`` marks the root,
        and ``-1`` marks a node for which no ``Tree`` is created. Each created
        node's ``idx`` is its 0-based position in the line.

        Returns:
            The root ``Tree`` node, or ``None`` if no node has parent ``0``.
        """
        parents = list(map(int, line.split()))
        trees = {}
        root = None
        for i in range(1, len(parents) + 1):
            # Only start an upward walk from nodes not yet built and not skipped.
            if i - 1 in trees or parents[i - 1] == -1:
                continue
            idx = i
            prev = None
            while True:
                parent = parents[idx - 1]
                if parent == -1:
                    break
                tree = Tree()
                if prev is not None:
                    tree.add_child(prev)
                trees[idx - 1] = tree
                tree.idx = idx - 1
                if parent - 1 in trees:
                    # Parent already built by an earlier walk: attach and stop.
                    trees[parent - 1].add_child(tree)
                    break
                elif parent == 0:
                    root = tree
                    break
                else:
                    # Keep climbing; this node becomes the parent's child.
                    prev = tree
                    idx = parent
        return root

    def read_labels(self, filename):
        """Read one float per line into a 1-D ``torch.Tensor`` (default float dtype)."""
        with open(filename, 'r', encoding='utf-8') as f:
            labels = [float(line) for line in f.readlines()]
        return torch.Tensor(labels)