data_loader.py (forked from lemonhu/NER-BERT-pytorch)
"""Data loader"""
import random
import numpy as np
import os
import sys
import torch
from pytorch_pretrained_bert import BertTokenizer
import utils


class DataLoader(object):
    def __init__(self, data_dir, bert_model_dir, params, token_pad_idx=0):
        self.data_dir = data_dir
        self.batch_size = params.batch_size
        self.max_len = params.max_len
        self.device = params.device
        self.seed = params.seed
        self.token_pad_idx = token_pad_idx

        tags = self.load_tags()
        self.tag2idx = {tag: idx for idx, tag in enumerate(tags)}
        self.idx2tag = {idx: tag for idx, tag in enumerate(tags)}
        params.tag2idx = self.tag2idx
        params.idx2tag = self.idx2tag
        self.tag_pad_idx = self.tag2idx['O']

        self.tokenizer = BertTokenizer.from_pretrained(bert_model_dir, do_lower_case=True)

    def load_tags(self):
        """Loads the tag vocabulary, one tag per line, from tags.txt in data_dir."""
        tags = []
        file_path = os.path.join(self.data_dir, 'tags.txt')
        with open(file_path, 'r') as file:
            for tag in file:
                tags.append(tag.strip())
        return tags
    def load_sentences_tags(self, sentences_file, tags_file, d):
        """Loads sentences and tags from their corresponding files.

        Maps tokens and tags to their indices and stores them in the provided dict d.
        """
        sentences = []
        tags = []

        with open(sentences_file, 'r') as file:
            for line in file:
                # tokenize the sentence and map each token to its vocabulary index
                tokens = self.tokenizer.tokenize(line.strip())
                sentences.append(self.tokenizer.convert_tokens_to_ids(tokens))

        with open(tags_file, 'r') as file:
            for line in file:
                # map each tag to its index
                tag_seq = [self.tag2idx.get(tag) for tag in line.strip().split(' ')]
                tags.append(tag_seq)

        # check that there is a tag sequence for each sentence and a tag for each token
        assert len(sentences) == len(tags)
        for i in range(len(sentences)):
            assert len(tags[i]) == len(sentences[i])

        # store sentences and tags in dict d
        d['data'] = sentences
        d['tags'] = tags
        d['size'] = len(sentences)
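    # Expected on-disk layout, as inferred from load_tags/load_data in this class
    # (illustrative note, not part of the original file): data_dir/tags.txt holds the tag
    # vocabulary, one tag per line (it must contain 'O', which is used as the padding tag),
    # while each split directory data_dir/{train,val,test}/ holds sentences.txt (one
    # sentence per line) and tags.txt (one line of space-separated tags per sentence,
    # aligned token-for-token with the tokenized sentence).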
    def load_data(self, data_type):
        """Loads the data for the given split from data_dir.

        Args:
            data_type: (str) one of 'train', 'val' or 'test', depending on which data is required.
        Returns:
            data: (dict) contains the sentences and tags for the requested split.
        """
        data = {}

        if data_type in ['train', 'val', 'test']:
            sentences_file = os.path.join(self.data_dir, data_type, 'sentences.txt')
            tags_path = os.path.join(self.data_dir, data_type, 'tags.txt')
            self.load_sentences_tags(sentences_file, tags_path, data)
        else:
            raise ValueError("data type not in ['train', 'val', 'test']")

        return data
    def data_iterator(self, data, shuffle=False):
        """Returns a generator that yields batches of data with tags.

        Args:
            data: (dict) contains data which has keys 'data', 'tags' and 'size'
            shuffle: (bool) whether the data should be shuffled
        Yields:
            batch_data: (tensor) shape: (batch_size, max_len)
            batch_tags: (tensor) shape: (batch_size, max_len)
        """
        # make a list that decides the order in which we go over the data;
        # this avoids shuffling the data itself
        order = list(range(data['size']))
        if shuffle:
            random.seed(self.seed)
            random.shuffle(order)

        # one pass over the data (the last incomplete batch, if any, is dropped)
        for i in range(data['size'] // self.batch_size):
            # fetch sentences and tags for this batch
            sentences = [data['data'][idx] for idx in order[i*self.batch_size:(i+1)*self.batch_size]]
            tags = [data['tags'][idx] for idx in order[i*self.batch_size:(i+1)*self.batch_size]]

            # batch length
            batch_len = len(sentences)

            # compute length of the longest sentence in the batch, capped at self.max_len
            batch_max_len = max([len(s) for s in sentences])
            max_len = min(batch_max_len, self.max_len)

            # prepare numpy arrays initialised with the padding indices
            batch_data = self.token_pad_idx * np.ones((batch_len, max_len))
            batch_tags = self.tag_pad_idx * np.ones((batch_len, max_len))

            # copy the data into the numpy arrays, truncating sentences longer than max_len
            for j in range(batch_len):
                cur_len = len(sentences[j])
                if cur_len <= max_len:
                    batch_data[j][:cur_len] = sentences[j]
                    batch_tags[j][:cur_len] = tags[j]
                else:
                    batch_data[j] = sentences[j][:max_len]
                    batch_tags[j] = tags[j][:max_len]

            # since all data are indices, we convert them to torch LongTensors
            batch_data = torch.tensor(batch_data, dtype=torch.long)
            batch_tags = torch.tensor(batch_tags, dtype=torch.long)

            # shift tensors to GPU if available
            batch_data, batch_tags = batch_data.to(self.device), batch_tags.to(self.device)

            yield batch_data, batch_tags
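

# Minimal usage sketch (not part of the original module). It shows how the loader is
# typically driven: build a params object exposing batch_size, max_len, device and seed
# (the attributes read in __init__), load a split, and iterate over batches. The Params
# class and the paths 'data/' and 'bert-base-uncased/' are hypothetical placeholders.
if __name__ == '__main__':
    class Params:
        """Stand-in for the params object; attribute names match those read in __init__."""
        batch_size = 2
        max_len = 128
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        seed = 42

    loader = DataLoader('data/', 'bert-base-uncased/', Params())
    train_data = loader.load_data('train')

    for batch_data, batch_tags in loader.data_iterator(train_data, shuffle=True):
        # each batch is a pair of LongTensors of shape (batch_size, batch_max_len)
        print(batch_data.shape, batch_tags.shape)
        break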