# run_scorer.py
import os
import sys

import pandas as pd

from coval.coval.eval import evaluator
from coval.coval.conll import reader
from utils import *
def main():
    """Score every matching CoNLL file in a directory and report the best one.

    Command line::

        run_scorer.py <path> <mention_type> [NP_only] [remove_nested]
                      [remove_singletons]

    ``path`` is a directory that is scanned for files whose name ends with
    ``conll`` and contains the level string (``'topic'``).  Each such file is
    passed to :func:`evaluate` together with the fixed file
    ``data/ecb/gold_singletons/dev_<mention_type>_<level>_level.conll``.

    Side effects: prints progress to stdout, writes ``all_scores.csv`` into
    ``path`` (one column per scored file) and finally prints a
    ``(file_name, conll_score)`` tuple for the highest-scoring file
    (``(None, 0)`` when no file scored above zero).
    """
    if len(sys.argv) < 3:
        # Fail with a readable message instead of a bare IndexError.
        sys.exit('usage: {} <path> <mention_type> [NP_only] [remove_nested] '
                 '[remove_singletons]'.format(sys.argv[0]))

    allmetrics = [
        ('mentions', evaluator.mentions),
        ('muc', evaluator.muc),
        ('bcub', evaluator.b_cubed),
        ('ceafe', evaluator.ceafe),
        ('lea', evaluator.lea),
    ]

    NP_only = 'NP_only' in sys.argv
    remove_nested = 'remove_nested' in sys.argv
    # 'removIe_singleton' is a historical misspelling; it is still accepted so
    # existing invocations keep working, alongside the intended spelling.
    singleton_flags = ('remove_singletons', 'remove_singleton',
                       'removIe_singleton')
    keep_singletons = not any(flag in sys.argv for flag in singleton_flags)
    min_span = False

    path = sys.argv[1]
    mention_type = sys.argv[2]
    print(path, mention_type)

    level = 'topic'
    # NOTE(review): the files found under `path` are passed as the *key* and
    # this fixed file as the *system* output. If this file is actually the
    # gold annotation, precision and recall are swapped (F1 is unaffected) --
    # confirm against reader.get_coref_infos.
    sys_file = f'data/ecb/gold_singletons/dev_{mention_type}_{level}_level.conll'

    all_scores = {}
    max_conll_f1 = (None, 0)

    # os.listdir returns entries in arbitrary order; sorting makes the CSV
    # column order and the tie-breaking of the best file reproducible.
    for key_file in sorted(os.listdir(path)):
        if not (key_file.endswith('conll') and level in key_file):
            continue

        print('Processing file: {}'.format(key_file))
        full_path = os.path.join(path, key_file)
        scores = evaluate(full_path, sys_file, allmetrics, NP_only,
                          remove_nested, keep_singletons, min_span)
        all_scores[key_file] = scores

        if scores['conll'] > max_conll_f1[1]:
            max_conll_f1 = (key_file, scores['conll'])

    df = pd.DataFrame.from_dict(all_scores)
    df.to_csv(os.path.join(path, 'all_scores.csv'))

    print(max_conll_f1)
def evaluate(key_file, sys_file, metrics, NP_only, remove_nested,
             keep_singletons, min_span):
    """Compute coreference scores for one key/system file pair.

    Args:
        key_file: path forwarded to ``reader.get_coref_infos`` as the key.
        sys_file: path forwarded to ``reader.get_coref_infos`` as the system.
        metrics: iterable of ``(name, metric)`` pairs; each ``metric`` is
            handed to ``evaluator.evaluate_documents``.
        NP_only, remove_nested, keep_singletons, min_span: options forwarded
            unchanged to ``reader.get_coref_infos``.

    Returns:
        dict with ``'<name>_recall'``, ``'<name>_precision'`` and
        ``'<name>_f1'`` for every metric, plus ``'conll'``: the mean F1 of the
        ``muc``, ``bcub`` and ``ceafe`` metrics multiplied by 100.  If
        ``metrics`` does not contain exactly those three sub-metrics the
        average is undefined and ``'conll'`` is ``float('nan')``.
    """
    doc_coref_infos = reader.get_coref_infos(key_file, sys_file, NP_only,
                                             remove_nested, keep_singletons,
                                             min_span)

    conll_parts = ('muc', 'bcub', 'ceafe')
    conll_sum = 0
    conll_subparts_num = 0
    scores = {}

    for name, metric in metrics:
        recall, precision, f1 = evaluator.evaluate_documents(
            doc_coref_infos, metric, beta=1)

        scores[f'{name}_recall'] = recall
        scores[f'{name}_precision'] = precision
        scores[f'{name}_f1'] = f1

        if name in conll_parts:
            conll_sum += f1
            conll_subparts_num += 1

    if conll_subparts_num == len(conll_parts):
        scores['conll'] = (conll_sum / conll_subparts_num) * 100
    else:
        # Dividing a partial sum by 3 would produce a plausible-looking but
        # wrong number; NaN keeps the key present and compares False against
        # any threshold.
        scores['conll'] = float('nan')

    return scores
# Run the scorer only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()