-
Notifications
You must be signed in to change notification settings - Fork 5
/
generate_data_qa.py
53 lines (39 loc) · 1.56 KB
/
generate_data_qa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
import pickle as pkl

import numpy as np
import torch
from tqdm import tqdm
# Input files: one pickled, pruned triple list per data split, plus a NumPy
# table of pre-computed entity embeddings.
train_tri_list_path = 'qa_data/train_tri_list.prune.pkl'
dev_tri_list_path = 'qa_data/dev_tri_list.prune.pkl'
test_tri_list_path = 'qa_data/test_tri_list.prune.pkl'
ent_emb_all_path = 'qa_data/tzw.ent.npy'


def _load_pickle(path):
    """Load and return the single pickled object stored at *path*.

    The file handle is closed as soon as loading finishes (the previous
    ``pkl.load(open(...))`` form leaked the handle).

    NOTE: ``pickle.load`` can execute arbitrary code embedded in the file;
    only use this on trusted, locally produced data files.
    """
    with open(path, 'rb') as f:
        return pkl.load(f)


train_tri_list = _load_pickle(train_tri_list_path)
dev_tri_list = _load_pickle(dev_tri_list_path)
test_tri_list = _load_pickle(test_tri_list_path)

# Output files for the generated downstream-QA dataset.
out_ent = 'dataset/down_qa/entities.dict'
out_ent_emb = 'dataset/down_qa/ent_emb.pt'
out_rel = 'dataset/down_qa/relations.dict'
# Map every distinct entity and relation to a dense integer id (0, 1, 2, ...),
# numbered in order of first appearance over the train, dev and test splits.
ent_dic = {}
rel_dic = {}
for split_rows in (train_tri_list, dev_tri_list, test_tri_list):
    for record in tqdm(split_rows):
        # Element 0 of each record is consumed as a sequence of
        # (head, relation, tail) triples; other elements are not used here.
        for head, rel, tail in record[0]:
            # setdefault inserts only unseen keys, and len(...) is evaluated
            # before the insertion, so each new key receives the next free id.
            ent_dic.setdefault(head, len(ent_dic))
            ent_dic.setdefault(tail, len(ent_dic))
            rel_dic.setdefault(rel, len(rel_dic))

# Replace both mappings with lists of (name, id) pairs in ascending id order.
ent_dic = sorted(ent_dic.items(), key=lambda pair: pair[1])
rel_dic = sorted(rel_dic.items(), key=lambda pair: pair[1])
# Build the embedding matrix for the entities seen in the QA triples (row
# ``idx`` = embedding of the entity with id ``idx``). The same pass writes the
# entity vocabulary as "<id>\t<name>" lines.
ent_emb_all = np.load(ent_emb_all_path)
# Take the row shape from the source table instead of hard-coding 1024, so
# tables of any embedding width work. float64 (np.zeros' default) keeps the
# dtype of the saved tensor unchanged.
ent_emb = np.zeros((len(ent_dic), *ent_emb_all.shape[1:]), dtype=np.float64)

# Make sure the output directories exist before writing into them.
for out_path in (out_ent, out_ent_emb):
    os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)

# Explicit UTF-8 so the file contents do not depend on the platform locale.
with open(out_ent, 'w', encoding='utf-8') as f_ent:
    for name, idx in ent_dic:
        f_ent.write(str(idx) + '\t' + str(name) + '\n')
        # ``name`` is used directly as a row index into ent_emb_all, so entity
        # keys are presumably integer row ids of that table — TODO confirm.
        ent_emb[idx] = ent_emb_all[name]
ent_emb = torch.from_numpy(ent_emb)
torch.save(ent_emb, out_ent_emb)
# Write the relation vocabulary as "<id>\t<name>" lines, iterating the list of
# (name, id) pairs in ``rel_dic``. Explicit UTF-8 keeps the output independent
# of the platform's locale-dependent default encoding.
with open(out_rel, 'w', encoding='utf-8') as f_rel:
    for name, idx in rel_dic:
        f_rel.write(str(idx) + '\t' + str(name) + '\n')