-
Notifications
You must be signed in to change notification settings - Fork 2
/
combo.py
126 lines (112 loc) · 6.15 KB
/
combo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
###############################################################################
#
# Iterative Procrustes + SGM Combination System.
#
# Written by Kelly Marchisio, 2020-2021.
#
###############################################################################
import argparse
from third_party.vecmap import embeddings
import proc_v_sgm
def main(args, n_rounds=10):
    """Run the iterative Procrustes + SGM combination system.

    Loads embeddings and word pairs, splits the pairs into train seeds and a
    dev set, normalizes both embedding matrices in-place, and then alternates
    Iterative Procrustes and SGM for ``n_rounds`` rounds. The hypotheses
    produced by the second system of a round are fed as extra input to the
    first system of the next round. Dev-set matches, precision and recall are
    printed for the hypotheses produced by the second system of each round.

    Args:
        args: Parsed command-line namespace. Reads ``start``, ``n_seeds``,
            ``randomize_seeds``, ``norm``, ``proc_iters``, ``softsgm_iters``,
            ``iterative_softsgm_iters``, ``k``, ``min_prob``,
            ``new_nseeds_per_round`` and ``diff_seeds_for_rev`` directly;
            the whole namespace is also handed to
            ``proc_v_sgm.load_embs_and_wordpairs``.
        n_rounds: Number of Procrustes/SGM alternation rounds. Defaults to
            10, the value that used to be hard-coded.

    Raises:
        ValueError: If ``args.start`` is neither ``'proc'`` nor ``'sgm'``.
    """
    # Fail early with a clear message instead of hitting an unbound `hyps`
    # after the expensive loading/normalization work.
    if args.start not in ('proc', 'sgm'):
        raise ValueError(
            "args.start must be 'proc' or 'sgm', got {0!r}".format(args.start))

    # Process data. Get train/dev split, seeds.
    (word_pairs, src_embs, src_word2ind, _, trg_embs, trg_word2ind,
     _, oov_word_pairs) = proc_v_sgm.load_embs_and_wordpairs(args)
    print('OOV Word Pairs:', oov_word_pairs)
    _, (train_inds, dev_inds) = proc_v_sgm.create_train_dev_split(
        word_pairs, args.n_seeds, src_word2ind, trg_word2ind,
        args.randomize_seeds)
    gold_src_train_inds, gold_trg_train_inds = proc_v_sgm.unzip_pairs(
        train_inds)
    src_dev_inds, _ = proc_v_sgm.unzip_pairs(dev_inds)
    # Built once: the dev split does not change between rounds, and a set
    # gives constant-time membership tests when filtering hypotheses below.
    # Assumes the indices are hashable (e.g. ints).
    dev_src_set = set(src_dev_inds)

    # Normalize embeddings in-place.
    print('Normalizing embeddings...')
    embeddings.normalize(src_embs, args.norm)
    embeddings.normalize(trg_embs, args.norm)
    print('Done normalizing embeddings.')

    # Make similarity matrices.
    xxT = src_embs @ src_embs.T
    yyT = trg_embs @ trg_embs.T

    # Hypotheses carried over from the previous round (none before round 0).
    carry_src = []
    carry_trg = []
    for i in range(n_rounds):
        print('----------------------------------')
        print('----------------------------------')
        print('Starting Iteration', i)
        print('----------------------------------')
        if args.start == 'proc':
            # Procrustes first (seeded with last round's SGM output), then
            # SGM seeded with the fresh Procrustes hypotheses.
            print('\nRunning Iterative Procrustes for {0} iterations'.format(
                args.proc_iters), flush=True)
            _, _, proc_hyps_int, _, _ = proc_v_sgm.iterative_procrustes_w_csls(
                src_embs, trg_embs, carry_src, carry_trg,
                gold_src_train_inds, gold_trg_train_inds, dev_inds,
                args.new_nseeds_per_round, total_i=args.proc_iters,
                diff_seeds_for_rev=args.diff_seeds_for_rev, k=args.k)
            print('\nRunning SGM', flush=True)
            proc_hyps_src, proc_hyps_trg = proc_v_sgm.unzip_pairs(
                proc_hyps_int)
            hyps, _, sgm_hyps_int = proc_v_sgm.iterative_softsgm(
                xxT, yyT, proc_hyps_src, proc_hyps_trg,
                gold_src_train_inds, gold_trg_train_inds,
                args.softsgm_iters, args.k, args.min_prob, dev_inds,
                args.new_nseeds_per_round, curr_i=1,
                total_i=args.iterative_softsgm_iters,
                diff_seeds_for_rev=args.diff_seeds_for_rev,
                run_reverse=True)
            carry_src, carry_trg = proc_v_sgm.unzip_pairs(sgm_hyps_int)
        else:
            # SGM first (seeded with last round's Procrustes output), then
            # Procrustes seeded with the fresh SGM hypotheses.
            print('\nRunning SGM', flush=True)
            _, _, sgm_hyps_int = proc_v_sgm.iterative_softsgm(
                xxT, yyT, carry_src, carry_trg,
                gold_src_train_inds, gold_trg_train_inds,
                args.softsgm_iters, args.k, args.min_prob, dev_inds,
                args.new_nseeds_per_round, curr_i=1,
                total_i=args.iterative_softsgm_iters,
                diff_seeds_for_rev=args.diff_seeds_for_rev,
                run_reverse=True)
            print('\nRunning Iterative Procrustes for {0} iterations'.format(
                args.proc_iters), flush=True)
            sgm_hyps_src, sgm_hyps_trg = proc_v_sgm.unzip_pairs(sgm_hyps_int)
            hyps, _, proc_hyps_int, _, _ = proc_v_sgm.iterative_procrustes_w_csls(
                src_embs, trg_embs, sgm_hyps_src, sgm_hyps_trg,
                gold_src_train_inds, gold_trg_train_inds, dev_inds,
                args.new_nseeds_per_round, total_i=args.proc_iters,
                diff_seeds_for_rev=args.diff_seeds_for_rev, k=args.k)
            carry_src, carry_trg = proc_v_sgm.unzip_pairs(proc_hyps_int)

        # Eval. Only hypotheses whose source index is in the dev set count.
        # NOTE(review): evaluation is assumed to run once per round (inside
        # the loop) rather than once after the final round -- confirm.
        dev_hyps = {hyp for hyp in hyps if hyp[0] in dev_src_set}
        matches, precision, recall = proc_v_sgm.eval(dev_hyps, dev_inds)
        print('\tDev Pairs matched: {0} \n\t(Precision; {1}%) (Recall: {2}%)'
              .format(len(matches), precision, recall), flush=True)
def build_parser():
    """Build the command-line parser for the combination system.

    Returns:
        argparse.ArgumentParser: Parser exposing the embedding paths, the
        starting system (``--start``), normalization steps, seed options and
        the iteration counts for Iterative Procrustes and SoftSGM.
    """
    parser = argparse.ArgumentParser(description='LAP Experiments')
    parser.add_argument('--src-embs', metavar='PATH', required=True,
                        help='Path to source embeddings.')
    parser.add_argument('--trg-embs', metavar='PATH', required=True,
                        help='Path to target embeddings.')
    parser.add_argument('--start', choices=['proc', 'sgm'], required=True,
                        help='Whether to start with Iterative Procrustes or SGM.')
    parser.add_argument('--norm', metavar='N',
                        choices=['noop', 'unit', 'center'],
                        nargs='+', required=True,
                        help='How to normalize embeddings (can take multiple args)')
    parser.add_argument('--max-embs', type=int, default=200000,
                        help='Maximum num of word embeddings to use.')
    parser.add_argument('--min-prob', type=float, default=0.0,
                        help='The minimum probability to consider for softsgm')
    parser.add_argument('--pairs', metavar='PATH', required=True,
                        help='train seeds + dev pairs')
    parser.add_argument('--n-seeds', type=int, required=True,
                        help='Num train seeds to use')
    parser.add_argument('--proc-iters', type=int, default=10,
                        help='Rounds of iterative Procrustes to run.')
    parser.add_argument('--iterative-softsgm-iters', type=int, default=1,
                        help='Rounds of iterative SoftSGM to run.')
    parser.add_argument('--softsgm-iters', type=int, default=1,
                        help='Rounds of SoftSGM to run to create probdist.')
    parser.add_argument('--k', type=int, default=1,
                        help='How many hypotheses to return per source word.')
    parser.add_argument('--randomize-seeds', action='store_true',
                        help='If set, randomizes the seeds to use (instead of getting them in '
                        'order from args.pairs file)')
    # NOTE(review): values given on the command line arrive as a list of ints
    # (nargs='+'), while the default is the bare int -1 -- confirm downstream
    # code handles both shapes.
    parser.add_argument('--new-nseeds-per-round', metavar='N', type=int,
                        nargs='+', default=-1,
                        help='Number of seeds to add per round in iterative runs.')
    parser.add_argument('--diff-seeds-for-rev', action='store_true',
                        help='When running matching in reverse, regenerate seeds (if there are '
                        'additional input seeds from a previous round, these will then be '
                        'shuffled.')
    return parser


if __name__ == '__main__':
    # Guarded so that importing this module does not parse sys.argv or start
    # an experiment; running it as a script behaves as before.
    parser = build_parser()
    args = parser.parse_args()
    main(args)