"""Data loader"""
import os
import torch
import utils
import random
import numpy as np
from transformers import BertTokenizer


class DataLoader(object):
    def __init__(self, data_dir, bert_class, params, token_pad_idx=0, tag_pad_idx=-1):
        self.data_dir = data_dir
        self.batch_size = params.batch_size
        self.max_len = params.max_len
        self.device = params.device
        self.seed = params.seed
        self.token_pad_idx = token_pad_idx
        self.tag_pad_idx = tag_pad_idx

        tags = self.load_tags()
        self.tag2idx = {tag: idx for idx, tag in enumerate(tags)}
        self.idx2tag = {idx: tag for idx, tag in enumerate(tags)}
        params.tag2idx = self.tag2idx
        params.idx2tag = self.idx2tag

        self.tokenizer = BertTokenizer.from_pretrained(bert_class, do_lower_case=False)
    def load_tags(self):
        """Reads the tag vocabulary, one tag per line, from data_dir/tags.txt."""
        tags = []
        file_path = os.path.join(self.data_dir, 'tags.txt')
        with open(file_path, 'r') as file:
            for tag in file:
                tags.append(tag.strip())
        return tags
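    # Expected tags.txt format: one label per line, e.g. a hypothetical BIO-style tag set
    # (illustrative only, not taken from this repository's data):
    #   O
    #   B-PER
    #   I-PER
    # The line index of each tag becomes its id in tag2idx.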
    def load_sentences_tags(self, sentences_file, tags_file, d):
        """Loads sentences and tags from their corresponding files.
        Maps tokens and tags to their indices and stores them in the provided dict d.
        """
        sentences = []
        tags = []

        with open(sentences_file, 'r') as file:
            for line in file:
                # replace each token by its index
                tokens = line.strip().split(' ')
                subwords = list(map(self.tokenizer.tokenize, tokens))
                subword_lengths = list(map(len, subwords))
                # prepend the BERT [CLS] token and flatten the per-token subword lists
                subwords = ['[CLS]'] + [item for indices in subwords for item in indices]
                # position of the first subword of each original token (offset by 1 for [CLS])
                token_start_idxs = 1 + np.cumsum([0] + subword_lengths[:-1])
                sentences.append((self.tokenizer.convert_tokens_to_ids(subwords), token_start_idxs))

        if tags_file is not None:
            with open(tags_file, 'r') as file:
                for line in file:
                    # replace each tag by its index
                    tag_seq = [self.tag2idx.get(tag) for tag in line.strip().split(' ')]
                    tags.append(tag_seq)

            # checks to ensure there is a tag for each token
            assert len(sentences) == len(tags)
            for i in range(len(sentences)):
                assert len(tags[i]) == len(sentences[i][-1])

            d['tags'] = tags

        # storing sentences and tags in dict d
        d['data'] = sentences
        d['size'] = len(sentences)
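    # Illustrative sketch of the alignment above (hypothetical subword split, not taken
    # from any actual vocabulary): for tokens ['He', 'plays', 'football'], suppose
    # tokenize() returns [['He'], ['plays'], ['foot', '##ball']]. Then
    #   subwords         = ['[CLS]', 'He', 'plays', 'foot', '##ball']
    #   token_start_idxs = [1, 2, 3]
    # i.e. the position of the first subword of each original token, which
    # data_iterator later turns into the token_starts mask.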
    def load_data(self, data_type):
        """Loads the data for each type in types from data_dir.

        Args:
            data_type: (str) one of 'train', 'val', 'test' or 'interactive', depending on which data is required.

        Returns:
            data: (dict) contains the data with tags for each type in types.
        """
        data = {}

        if data_type in ['train', 'val', 'test']:
            print('Loading ' + data_type)
            sentences_file = os.path.join(self.data_dir, data_type, 'sentences.txt')
            tags_path = os.path.join(self.data_dir, data_type, 'tags.txt')
            self.load_sentences_tags(sentences_file, tags_path, data)
        elif data_type == 'interactive':
            # interactive data has no gold tags
            sentences_file = os.path.join(self.data_dir, data_type, 'sentences.txt')
            self.load_sentences_tags(sentences_file, tags_file=None, d=data)
        else:
            raise ValueError("data type not in ['train', 'val', 'test', 'interactive']")
        return data
    def load_data_active(self, data_type, test_file):
        """Loads the data for each type in types from data_dir; in 'interactive' mode the sentences are read from test_file instead.

        Args:
            data_type: (str) one of 'train', 'val', 'test' or 'interactive', depending on which data is required.
            test_file: (str) path to the sentences file used when data_type is 'interactive'.

        Returns:
            data: (dict) contains the data with tags for each type in types.
        """
        data = {}

        if data_type in ['train', 'val', 'test']:
            print('Loading ' + data_type)
            sentences_file = os.path.join(self.data_dir, data_type, 'sentences.txt')
            tags_path = os.path.join(self.data_dir, data_type, 'tags.txt')
            self.load_sentences_tags(sentences_file, tags_path, data)
        elif data_type == 'interactive':
            # interactive sentences come from the caller-supplied file and have no gold tags
            sentences_file = test_file
            self.load_sentences_tags(sentences_file, tags_file=None, d=data)
        else:
            raise ValueError("data type not in ['train', 'val', 'test', 'interactive']")
        return data
    def data_iterator(self, data, shuffle=False):
        """Returns a generator that yields batches of data with tags.

        Args:
            data: (dict) contains data which has keys 'data', 'tags' and 'size'
            shuffle: (bool) whether the data should be shuffled

        Yields:
            batch_data: (tensor) shape: (batch_size, max_subwords_len)
            batch_token_starts: (tensor) shape: (batch_size, max_subwords_len), 1 at the first subword of each token
            batch_tags: (tensor) shape: (batch_size, max_token_len), only when tags are available
        """
        # make a list that decides the order in which we go over the data - this avoids explicit shuffling of data
        order = list(range(data['size']))
        if shuffle:
            random.seed(self.seed)
            random.shuffle(order)

        # interactive mode: no tags are available
        interMode = 'tags' not in data

        if data['size'] % self.batch_size == 0:
            BATCH_NUM = data['size'] // self.batch_size
        else:
            BATCH_NUM = data['size'] // self.batch_size + 1

        # one pass over data
        for i in range(BATCH_NUM):
            # fetch sentences and tags; the last batch may be smaller than batch_size
            if i * self.batch_size < data['size'] < (i + 1) * self.batch_size:
                sentences = [data['data'][idx] for idx in order[i * self.batch_size:]]
                if not interMode:
                    tags = [data['tags'][idx] for idx in order[i * self.batch_size:]]
            else:
                sentences = [data['data'][idx] for idx in order[i * self.batch_size:(i + 1) * self.batch_size]]
                if not interMode:
                    tags = [data['tags'][idx] for idx in order[i * self.batch_size:(i + 1) * self.batch_size]]

            # batch length
            batch_len = len(sentences)

            # compute length of longest sentence in batch
            batch_max_subwords_len = max([len(s[0]) for s in sentences])
            max_subwords_len = min(batch_max_subwords_len, self.max_len)
            max_token_len = 0

            # prepare a numpy array with the data, initialising the data with pad_idx
            batch_data = self.token_pad_idx * np.ones((batch_len, max_subwords_len))
            batch_token_starts = []

            # copy the data to the numpy array
            for j in range(batch_len):
                cur_subwords_len = len(sentences[j][0])
                if cur_subwords_len <= max_subwords_len:
                    batch_data[j][:cur_subwords_len] = sentences[j][0]
                else:
                    batch_data[j] = sentences[j][0][:max_subwords_len]

                # binary mask marking the first subword of each original token
                token_start_idx = sentences[j][-1]
                token_starts = np.zeros(max_subwords_len)
                token_starts[[idx for idx in token_start_idx if idx < max_subwords_len]] = 1
                batch_token_starts.append(token_starts)
                max_token_len = max(int(sum(token_starts)), max_token_len)

            if not interMode:
                batch_tags = self.tag_pad_idx * np.ones((batch_len, max_token_len))
                for j in range(batch_len):
                    cur_tags_len = len(tags[j])
                    if cur_tags_len <= max_token_len:
                        batch_tags[j][:cur_tags_len] = tags[j]
                    else:
                        batch_tags[j] = tags[j][:max_token_len]

            # since all data are indices, we convert them to torch LongTensors
            batch_data = torch.tensor(batch_data, dtype=torch.long)
            batch_token_starts = torch.tensor(batch_token_starts, dtype=torch.long)
            if not interMode:
                batch_tags = torch.tensor(batch_tags, dtype=torch.long)

            # shift tensors to GPU if available
            batch_data, batch_token_starts = batch_data.to(self.device), batch_token_starts.to(self.device)
            if not interMode:
                batch_tags = batch_tags.to(self.device)
                yield batch_data, batch_token_starts, batch_tags
            else:
                yield batch_data, batch_token_starts
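

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical way to
# drive this loader. The data directory 'data/' and the 'bert-base-cased'
# checkpoint are placeholders, and a bare SimpleNamespace stands in for whatever
# params/config object the project actually uses; it assumes a layout of
# <data_dir>/tags.txt and <data_dir>/train/{sentences.txt,tags.txt}.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    # hypothetical hyperparameters; the real project likely loads these from a config file
    params = SimpleNamespace(
        batch_size=32,
        max_len=128,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        seed=42,
    )

    # hypothetical data directory and BERT checkpoint; adjust to the project's setup
    loader = DataLoader(data_dir='data/', bert_class='bert-base-cased', params=params)
    train_data = loader.load_data('train')
    train_iter = loader.data_iterator(train_data, shuffle=True)

    # each batch is (subword ids, token-start mask, tag ids) when tags are available
    batch_data, batch_token_starts, batch_tags = next(train_iter)
    print(batch_data.shape, batch_token_starts.shape, batch_tags.shape)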