-
Notifications
You must be signed in to change notification settings - Fork 13
/
data.py
103 lines (92 loc) · 2.6 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import random
# Valid (min_score, max_score) pair for each prompt id; fix_score() clamps
# predictions for prompts 1-10 to these bounds.
asap_ranges = {
    1: (2.0, 12.0),
    2: (1.0, 6.0),
    3: (0.0, 3.0),
    4: (0.0, 3.0),
    5: (0.0, 4.0),
    6: (0.0, 4.0),
    7: (0.0, 30.0),
    8: (0.0, 60.0),
    9: (0.5, 9.0),  # the only prompt whose lower bound is not a whole number
    10: (1.0, 24.0),
}
# One integer per prompt id (1-11).
# NOTE(review): presumably an essay-length figure (in tokens?) per prompt --
# the unit and how it was derived are not visible in this file; confirm.
asap_essay_lengths = {
    1: 649,
    2: 704,
    3: 219,
    4: 203,
    5: 258,
    6: 289,
    7: 371,
    8: 1077,
    9: 415,
    10: 1024,
    11: 252,
}
def fix_score(score, prompt):
    """
    Map a raw predicted score onto the valid score scale of ``prompt``.

    * ``prompt == 9``: the score is snapped to a half-point grid (fractional
      part below 0.25 -> keep the integer part, below 0.75 -> add 0.5,
      otherwise add 1) and the snapped value is clamped to the prompt's range.
    * any other ``prompt <= 10``: a score outside the prompt's range is
      replaced by the bound it violates; an in-range score is returned as
      ``round(score)``.
    * ``prompt > 10``: the score is returned unchanged.

    Ranges come from the module-level ``asap_ranges`` table; a prompt
    ``<= 10`` that is missing from that table raises ``KeyError``.
    """
    if prompt == 9:  # telis
        # int() truncates toward zero, so ``fraction`` carries the sign of score.
        whole = float(int(score))
        fraction = score - whole
        snapped = whole if fraction < 0.25 else (
            whole + 0.5 if fraction < 0.75 else whole + 1
        )
        lower, upper = asap_ranges[prompt]
        if snapped < lower:
            return lower
        if snapped > upper:
            return upper
        return snapped

    if prompt <= 10:
        lower, upper = asap_ranges[prompt]
        # Range check happens on the raw score, before rounding.
        if score < lower:
            return lower
        if score > upper:
            return upper
        return round(score)

    return score
def is_zh(s):
    """Return True if ``s`` contains a Chinese character.

    A character counts as Chinese when it lies in the inclusive range
    U+4E00..U+9FA5; an empty ``s`` yields False.
    """
    return any('\u4e00' <= ch <= '\u9fa5' for ch in s)
def load_asap_data(data_file, max_len=1024, data_sample_rate=1.0):
    """Yield ``(id, text, label)`` samples parsed from a tab-separated file.

    Every non-blank line is interpreted in one of two ways:

    * exactly three fields -> ``id<TAB>text<TAB>label``; the text is cut to
      its first ``max_len`` space-separated words.
    * otherwise -> ``text<TAB>label``; the id is a running counter (as a
      string, starting at "0"). Text containing Chinese characters is cut to
      ``max_len`` characters after removing spaces; other text is cut to
      ``max_len`` words.

    Args:
        data_file: path of the file to read (decoded as UTF-8).
        max_len: maximum number of words (or characters, for Chinese text)
            kept per sample.
        data_sample_rate: a line is dropped when a fresh ``random.random()``
            draw exceeds this value, so ``1.0`` keeps every line.

    Yields:
        ``(id, text, label)`` tuples with ``id`` and ``text`` as ``str`` and
        ``label`` as ``float``. The whole file is read before the first
        tuple is yielded.

    Raises:
        IndexError: a non-blank line has no tab-separated label field.
        ValueError: a label field cannot be parsed as a float.
    """
    samples = []
    sample_index = 0
    # Decode explicitly as UTF-8 so results do not depend on the platform's
    # locale default (which breaks non-ASCII / Chinese text on some systems).
    with open(data_file, encoding="utf-8") as fin:
        for line in fin:
            # One draw per physical line, taken before any parsing, so the
            # random stream consumed is the same regardless of line content.
            if random.random() > data_sample_rate:
                continue
            line = line.strip()
            if not line:
                # A blank line carries no sample; skip it instead of failing
                # on the missing label field.
                continue
            fields = line.split("\t")
            if len(fields) == 3:
                sample_id, text, raw_label = fields
                words = text.split(" ")
                if len(words) >= max_len:
                    text = " ".join(words[:max_len])
            else:
                sample_id = str(sample_index)
                sample_index += 1
                text, raw_label = fields[0], fields[1]
                compact = text.replace(" ", "")
                words = text.split(" ")
                if is_zh(text) and len(compact) >= max_len:
                    # Chinese text is measured in characters, not words.
                    text = compact[:max_len]
                elif len(words) >= max_len:
                    text = " ".join(words[:max_len])
            samples.append((sample_id, text, float(raw_label)))
    yield from samples