Skip to content

Commit

Permalink
code refactor (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
hosseinfani committed Jul 22, 2022
1 parent 3366df7 commit c733b89
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 48 deletions.
43 changes: 43 additions & 0 deletions ds/trec09mq/refiners.qtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pandas as pd
import pickle
import seaborn as sns
import matplotlib.pyplot as plt

def get_refiners_by_qtypes(all, dataset, qtypes, output):

refiners = pd.read_csv(all, index_col=0, nrows=0).columns[4::3]
q_ = pd.read_csv(dataset, sep=',')
qtypes = pd.read_csv(qtypes, sep='\t', usecols=['Topic', 'Class'])
qtypes['Topic'] = qtypes['Topic'].astype(int)
qtypes.rename(columns=({'Topic': 'qid'}), inplace=True)

q_ = pd.merge(q_, qtypes, on='qid', how='inner')
qtypes = qtypes['Class'].unique()

result = dict()
for qt in qtypes:
result[qt] = dict()
for r in refiners: result[qt][r] = 0

for index, row in q_.iterrows():
for i in range(1, row['star_model_count'] * 3, 3): result[row["Class"]][str(row[i + 3])] += 1

result = pd.DataFrame.from_dict(result)
with open(output, 'wb') as f: pickle.dump(result, f)
return result

def heatmap(matrix, output):
with open(matrix, 'rb') as f: df = pickle.load(f)
df=df.transpose()
sns.heatmap(df, annot=True)
plt.title("Distribution of query refiners in query types", fontsize=12)
plt.savefig(output, bbox_inches='tight', dpi=100)
plt.show()

if __name__ == "__main__":
output = 'refiners.qtypes.pkl'
result = get_refiners_by_qtypes('../../qe/output/trec09mq/topics.trec09mq.bm25.map.all.csv',
'../../qe/output/trec09mq/topics.trec09mq.bm25.map.dataset.csv',
'./queryclasses',
output)
heatmap(output, "./refiners.qtypes.png")
47 changes: 0 additions & 47 deletions ds/trec09mq/stat.py

This file was deleted.

2 changes: 2 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@ dependencies:
- torch==1.7.1
- pyserini==0.10.0.1
- prettytable==2.1.0
- seaborn
- matplotlib
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ tqdm==4.45.0
transformers==4.0.0
sentencepiece==0.1.9
pyserini==0.10.0.1
prettytable==2.1.0
prettytable==2.1.0
seaborn
matplotlib

0 comments on commit c733b89

Please sign in to comment.