-
Notifications
You must be signed in to change notification settings - Fork 3
/
ir_utils.py
149 lines (129 loc) · 4.75 KB
/
ir_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os, json, string, re
from gensim.summarization.bm25 import BM25
import parsing_utils
from constant import *
from pinput import *
#################### Building IR models
def build_query(params):
    """Turn a collection of parameter names into one query string.

    Each entry of ``params`` is run through ``tokenize`` and its tokens are
    joined with single spaces; the per-parameter strings are then joined
    with single spaces as well. A parameter that yields no tokens still
    occupies a slot, so two consecutive spaces may appear inside the result.
    Leading and trailing whitespace is stripped.
    """
    per_param = (" ".join(tokenize(name)) for name in params)
    return " ".join(per_param).strip()
def build_bm25(docsdata, level, ret_traindocs=False):
    """
    Fit a BM25 model on the training documents derived from ``docsdata``.

    level == "cls" builds one document per test class (testcls_traindocs);
    any other value builds one document per test case (testcase_traindocs).

    Returns a dict with the keys
      "indices":   document key -> position of its document in traindocs
                   (class names for "cls", "Class#test" names otherwise)
      "model":     the BM25 instance built from the training documents
      "traindocs": the training documents themselves, present only when
                   ret_traindocs is truthy
    """
    make_docs = testcls_traindocs if level == "cls" else testcase_traindocs
    corpus, positions = make_docs(docsdata)
    bundle = {"indices": positions, "model": BM25(corpus)}
    if ret_traindocs:
        bundle["traindocs"] = corpus
    return bundle
def get_sim_di_q(tcp, img, irdata, testcases):
    """Look up the BM25 score of every test in ``testcases`` for one query.

    The query is built from the keys returned by
    ``parsing_utils.get_file_params(img)``. ``irdata`` is a dict holding
    "model" (an object exposing ``get_scores``) and "indices" (document key
    -> position in the score list), e.g. the value returned by build_bm25.

    When "cls" occurs in ``tcp`` the document key of a test is the part of
    its name before the first "#"; otherwise the full test name is the key.

    Returns a dict mapping each test name to its score.
    """
    param_names = list(parsing_utils.get_file_params(img).keys())
    query_tokens = build_query(param_names).split()
    bm25 = irdata["model"]  # scores come back in training-document order
    doc_pos = irdata["indices"]  # document key -> index into the scores
    doc_scores = bm25.get_scores(query_tokens)

    def doc_key(test):
        # class-level documents are keyed by the class part of "Class#test"
        return test.split("#")[0] if "cls" in tcp else test

    return {test: doc_scores[doc_pos[doc_key(test)]] for test in testcases}
######################## build document, traversing extended classes
def testcase_traindocs(docsdata):
    """Build one training document per test case.

    ``docsdata`` is first passed to ``aggr_extclass`` (which modifies it in
    place). Tests are then visited in sorted "Class#test" order; the
    document of a test is its own entry followed by the "global" entry of
    its class, which is shared by all tests of that class.

    Returns ``(traindocs, indices)`` where ``indices`` maps each
    "Class#test" name to the position of its document in ``traindocs``.
    """
    aggr_extclass(docsdata)
    traindocs = []
    indices = {}
    for test in sorted(get_tests(docsdata)):
        testcls, testcase = test.split("#")
        clsdata = docsdata[testcls]
        indices[test] = len(traindocs)
        traindocs.append(clsdata[testcase] + clsdata["global"])
    return traindocs, indices
def testcls_traindocs(docsdata):
traindocs = []
indices = {}
for idx, (testcls, clsdata) in enumerate(sorted(docsdata.items())):
indices[testcls] = idx
# each test's doc is the entire test file
doc = []
for key in clsdata:
if key != "extendedClasses":
doc += clsdata[key]
traindocs.append(doc)
return traindocs, indices
def aggr_extclass(docsdata):
    """Fold the entries of extended (parent) test classes into subclasses.

    ``docsdata`` maps a test class name to a dict holding
    "extendedClasses" (names of the classes it extends), "global", and one
    entry per test case. For every class with exactly one extended class
    the inheritance chain is resolved with ``find_extcls`` and, walking
    from the topmost resolvable ancestor downwards, each parent's "global"
    and test-case entries are appended (``+=``) to its direct child; test
    cases the child does not have yet are created first.

    ``docsdata`` is modified in place and nothing is returned. Calling the
    function again on the same data appends the parents' entries again.

    Within one call every class absorbs its parent at most once. Chains
    overlap (C -> B -> A and B -> A share the A -> B step), so without this
    guard shared ancestors would be appended repeatedly and the amount of
    duplication would depend on dict iteration order.

    Classes with more than one extended class only produce a diagnostic
    print; classes with none are left alone.
    """
    merged = set()  # classes that already received their parent's entries
    for testcls, clsdata in docsdata.items():
        ext_count = len(clsdata["extendedClasses"])
        if ext_count > 1:
            print("[strange] more than one extended class", testcls)
            continue
        if ext_count != 1:
            continue

        # Collect the chain [testcls, parent, grandparent, ...]. The
        # on_chain guard stops the walk if name resolution leads back to a
        # class already visited (e.g. a class extending a same-named class
        # of another package resolves to itself by suffix), which would
        # otherwise loop forever.
        chain = []
        on_chain = set()
        currcls = testcls
        while currcls and currcls not in on_chain:
            on_chain.add(currcls)
            chain.append(currcls)
            currcls = find_extcls(
                docsdata.keys(), docsdata[currcls]["extendedClasses"]
            )

        # Walk from the topmost ancestor down so that each parent is
        # already complete when it is copied into its child.
        ancestors_first = chain[::-1]
        for parent, child in zip(ancestors_first, ancestors_first[1:]):
            if child in merged:
                continue
            parent_doc = docsdata[parent]
            child_doc = docsdata[child]
            for key in parent_doc:
                if key == "extendedClasses":
                    continue
                # "global" is expected to exist on every class; inherited
                # test cases may be missing on the child.
                if key != "global" and key not in child_doc:
                    child_doc[key] = []
                child_doc[key] += parent_doc[key]
            merged.add(child)
def find_extcls(allcls, exts):
    """Resolve an extended class name to a name from ``allcls``.

    The names in ``exts`` are tried in order; for each, the first entry of
    ``allcls`` that ends with "." followed by that name is returned.
    Returns None when nothing matches.
    """
    for short_name in exts:
        hit = next(
            (name for name in allcls if name.endswith("." + short_name)),
            None,
        )
        if hit is not None:
            return hit
    # NOTE: matching is by suffix only, so a class with the same simple name
    # but a different qualified name can be returned by mistake.
    return None
def get_tests(docsdata):
    """Return the set of "Class#test" names found in ``docsdata``.

    Every key of a class entry counts as a test case except the two
    bookkeeping keys "global" and "extendedClasses".
    """
    bookkeeping = ("global", "extendedClasses")
    return {
        testcls + "#" + name
        for testcls, clsdata in docsdata.items()
        for name in clsdata.keys()
        if name not in bookkeeping
    }
##################### string manipulation and serialization
def tokenize(s):
    """Split identifier-like text into lowercase word tokens.

    ``s`` is split on whitespace; inside each word a token is a maximal run
    of ASCII letters that is additionally broken before every uppercase
    letter (camelCase splitting). Any other character acts as a separator.
    Tokens are lowercased and tokens of length 1 are dropped, so runs of
    capitals such as "HTTP" disappear entirely.

    Returns the tokens as a list, in order of appearance.
    """
    tokens = []
    for word in s.split():
        # one letter (either case) followed by lowercase letters: this is
        # exactly a piece that starts at a case boundary or after a separator
        for piece in re.findall(r"[A-Za-z][a-z]*", word):
            if len(piece) > 1:
                tokens.append(piece.lower())
    return tokens