-
Notifications
You must be signed in to change notification settings - Fork 0
/
reader.py
120 lines (105 loc) · 4.03 KB
/
reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# coding:utf-8
from utils import UNK, ModelType, TaskType, load_dic, sent2ids, logger, ModelType
class Dataset(object):
    '''
    Reader for tab-separated text datasets used by classification,
    regression and rank models.

    Each non-blank line of an input file is one record. The parser applied
    to a record is chosen from the model type given to the constructor, and
    words are converted to ids with `sent2ids` against the source / target
    dictionaries loaded by `load_dic`.
    '''

    # Human-readable record layouts, used in format error messages.
    _PAIR_FORMAT = "<source words> [TAB] <target words> [TAB] <label>"
    _RANK_FORMAT = ("<source words> [TAB] <left_target words> [TAB] "
                    "<right_target words> [TAB] <label>")

    def __init__(self, train_path, test_path, source_dic_path, target_dic_path,
                 model_type):
        '''
        @train_path: str
            path of the trainset; `infer()` reads from this path as well.
        @test_path: str
            path of the testset.
        @source_dic_path: str
            dictionary file used to map source words to ids.
        @target_dic_path: str
            dictionary file used to map target words to ids.
        @model_type: ModelType or a raw mode value
            a `ModelType` instance, or a mode such as
            `ModelType.CLASSIFICATION_MODE`, which is wrapped in `ModelType`.

        Raises ValueError if the model type's mode has no record reader.
        '''
        self.train_path = train_path
        self.test_path = test_path
        self.source_dic_path = source_dic_path
        self.target_dic_path = target_dic_path
        # Normalize once and dispatch on the normalized value. Previously the
        # argument was wrapped unconditionally, but the raw argument was then
        # asserted to be a ModelType and used for dispatch, so raw modes were
        # always rejected.
        if isinstance(model_type, ModelType):
            self.model_type = model_type
        else:
            self.model_type = ModelType(model_type)

        self.source_dic = load_dic(self.source_dic_path)
        self.target_dic = load_dic(self.target_dic_path)

        _record_reader = {
            ModelType.CLASSIFICATION_MODE: self._read_classification_record,
            ModelType.REGRESSION_MODE: self._read_regression_record,
            ModelType.RANK_MODE: self._read_rank_record,
        }
        mode = self.model_type.mode
        if mode not in _record_reader:
            raise ValueError("unsupported model type mode: %r" % (mode, ))
        self.record_reader = _record_reader[mode]
        # Default label handling for direct `record_reader(line)` calls.
        # train()/test()/infer() pass their mode explicitly, so this flag no
        # longer leaks from one reading pass into the next.
        self.is_infer = False

    def train(self):
        '''
        Load trainset.

        Yields one labelled record per non-blank line of `train_path`.
        '''
        logger.info("[reader] load trainset from %s" % self.train_path)
        # TODO: convert data format
        for record in self._read_file(self.train_path, False):
            yield record

    def test(self):
        '''
        Load testset.

        Yields one labelled record per non-blank line of `test_path`.
        '''
        logger.info("[reader] load testset from %s" % self.test_path)
        for record in self._read_file(self.test_path, True is False):
            yield record

    def infer(self):
        '''
        Load records for inference, without labels.

        NOTE(review): this reads `train_path`, presumably because callers
        pass the inference data file as `train_path` — confirm with callers.
        '''
        # Kept for callers that invoke `record_reader(line)` directly after
        # infer(); the records yielded here do not depend on it.
        self.is_infer = True
        for record in self._read_file(self.train_path, True):
            yield record

    def _read_file(self, path, is_infer):
        '''
        Yield the parsed record for every non-blank line of the file `path`.
        '''
        with open(path) as f:
            for line in f:
                if not line.strip():
                    # A blank line is not a record; skip it rather than
                    # failing the field-count check on it.
                    continue
                yield self.record_reader(line, is_infer)

    def _split_fields(self, line, num_fields, is_infer, task, fmt):
        '''
        Split a record line on tabs and validate the number of fields.

        @num_fields: int
            field count of a labelled record; the label is the last field.
        @is_infer: bool
            when True, the trailing label column is optional.
        @task / @fmt: str
            task name and expected layout, used in the error message.

        Raises ValueError on a wrong field count. A real exception is used
        instead of `assert` so the check survives `python -O`.
        '''
        fs = line.strip().split('\t')
        if is_infer:
            allowed = (num_fields - 1, num_fields)
        else:
            allowed = (num_fields, )
        if len(fs) not in allowed:
            raise ValueError("wrong format for %s\nthe format should be %s\n"
                             "got %d field(s) in line: %r" %
                             (task, fmt, len(fs), line))
        return fs

    def _read_classification_record(self, line, is_infer=None):
        '''
        data format:
            <source words> [TAB] <target words> [TAB] <label>
        In inference mode the <label> column may be omitted.

        @line: str
            a string line which represents a record.
        @is_infer: bool or None
            whether to return the record without its label; None means
            use `self.is_infer`.
        @return: (source ids, target ids, int label), or
            (source ids, target ids) in inference mode.
        '''
        if is_infer is None:
            is_infer = self.is_infer
        fs = self._split_fields(line, 3, is_infer, "classification",
                                self._PAIR_FORMAT)
        source = sent2ids(fs[0], self.source_dic)
        target = sent2ids(fs[1], self.target_dic)
        if is_infer:
            return source, target
        return source, target, int(fs[2])

    def _read_regression_record(self, line, is_infer=None):
        '''
        data format:
            <source words> [TAB] <target words> [TAB] <label>
        In inference mode the <label> column may be omitted.

        @line: str
            a string line which represents a record.
        @is_infer: bool or None
            whether to return the record without its label; None means
            use `self.is_infer`.
        @return: (source ids, target ids, [float label]), or
            (source ids, target ids) in inference mode.
        '''
        if is_infer is None:
            is_infer = self.is_infer
        fs = self._split_fields(line, 3, is_infer, "regression",
                                self._PAIR_FORMAT)
        source = sent2ids(fs[0], self.source_dic)
        target = sent2ids(fs[1], self.target_dic)
        if is_infer:
            return source, target
        # The label is wrapped in a list, as the original reader did.
        return source, target, [float(fs[2])]

    def _read_rank_record(self, line, is_infer=None):
        '''
        data format:
            <source words> [TAB] <left_target words> [TAB] <right_target words> [TAB] <label>
        In inference mode the <label> column may be omitted.

        @line: str
            a string line which represents a record.
        @is_infer: bool or None
            whether to return the record without its label; None means
            use `self.is_infer`.
        @return: (source ids, left target ids, right target ids, int label),
            or the first three only in inference mode.
        '''
        if is_infer is None:
            is_infer = self.is_infer
        fs = self._split_fields(line, 4, is_infer, "rank", self._RANK_FORMAT)
        source = sent2ids(fs[0], self.source_dic)
        left_target = sent2ids(fs[1], self.target_dic)
        right_target = sent2ids(fs[2], self.target_dic)
        if is_infer:
            return source, left_target, right_target
        return source, left_target, right_target, int(fs[3])
if __name__ == '__main__':
    # Smoke test: parse the classification trainset and print each record.
    path = './data/classification/train.txt'
    test_path = './data/classification/test.txt'
    source_dic = './data/vocab.txt'
    # The reader dispatch is keyed on ModelType.*_MODE values, and the
    # constructor expects a ModelType instance, so build the instance from
    # CLASSIFICATION_MODE rather than passing ModelType.CLASSIFICATION raw.
    dataset = Dataset(path, test_path, source_dic, source_dic,
                      ModelType(ModelType.CLASSIFICATION_MODE))
    for rcd in dataset.train():
        # print(x) with a single argument behaves the same on Python 2 and
        # Python 3; the old `print rcd` statement is a SyntaxError on 3.
        print(rcd)