# Original work Copyright (c) 2017-present, Facebook, Inc.
# Modified work Copyright (c) 2018, Xilun Chen
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os
import json
import argparse
from collections import OrderedDict
import numpy as np
import time
import torch
from src.utils import bool_flag, initialize_exp
from src.models import build_model
from src.trainer import Trainer
from src.evaluation import Evaluator

VALIDATION_METRIC_SUP = 'precision_at_1-csls_knn_10'
VALIDATION_METRIC_UNSUP = 'mean_cosine-csls_knn_10-S2T-10000'
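# precision_at_1-csls_knn_10 is word-translation accuracy (P@1) under CSLS
# retrieval; the unsupervised criterion is the mean cosine similarity between
# the 10,000 most frequent source words and their CSLS-induced translations,
# so it needs no gold dictionary
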
# default path to pre-trained embeddings if not otherwise specified
EMB_DIR = 'data/fasttext-vectors/'

# main
parser = argparse.ArgumentParser(description='Supervised training')
parser.add_argument("--seed", type=int, default=-1, help="Initialization seed")
parser.add_argument("--verbose", type=int, default=2, help="Verbose level (2:debug, 1:info, 0:warning)")
parser.add_argument("--exp_path", type=str, default="", help="Where to store experiment logs and models")
parser.add_argument("--exp_name", type=str, default="debug", help="Experiment name")
parser.add_argument("--exp_id", type=str, default="", help="Experiment ID")
# parser.add_argument("--cuda", type=bool_flag, default=True, help="Run on GPU")
parser.add_argument("--device", type=str, default="cuda", help="Run on GPU or CPU")
parser.add_argument("--export", type=str, default="txt", help="Export embeddings after training (txt / pth)")
# data
parser.add_argument("--src_langs", type=str, nargs='+', default=['de', 'es', 'fr', 'it', 'pt'], help="Source languages")
parser.add_argument("--tgt_lang", type=str, default='es', help="Target language")
parser.add_argument("--emb_dim", type=int, default=300, help="Embedding dimension")
parser.add_argument("--max_vocab", type=int, default=200000, help="Maximum vocabulary size (-1 to disable)")
# training refinement
parser.add_argument("--n_refinement", type=int, default=5, help="Number of refinement iterations (0 to disable the refinement procedure)")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
parser.add_argument("--map_beta", type=float, default=0.001, help="Beta for orthogonalization")
# MPSR parameters
parser.add_argument("--mpsr_optimizer", type=str, default="adam", help="Multilingual Pseudo-Supervised Refinement optimizer")
parser.add_argument("--mpsr_orthogonalize", type=bool_flag, default=True, help="During MPSR, whether to perform orthogonalization")
parser.add_argument("--mpsr_n_steps", type=int, default=30000, help="Number of optimization steps for MPSR")
# dictionary creation parameters (for refinement)
parser.add_argument("--dico_train", type=str, default="default", help="Path to training dictionary (default or identical_char)")
parser.add_argument("--dico_eval", type=str, default="default", help="Path to evaluation dictionary")
parser.add_argument("--dico_method", type=str, default='csls_knn_10', help="Method used for dictionary generation (nn/invsm_beta_30/csls_knn_10)")
parser.add_argument("--dico_build", type=str, default='S2T&T2S', help="S2T,T2S,S2T|T2S,S2T&T2S")
parser.add_argument("--dico_threshold", type=float, default=0, help="Threshold confidence for dictionary generation")
parser.add_argument("--dico_max_rank", type=int, default=10000, help="Maximum dictionary words rank (0 to disable)")
parser.add_argument("--dico_min_size", type=int, default=0, help="Minimum generated dictionary size (0 to disable)")
parser.add_argument("--dico_max_size", type=int, default=0, help="Maximum generated dictionary size (0 to disable)")
parser.add_argument("--semeval_ignore_oov", type=bool_flag, default=True, help="Whether to ignore OOV in SEMEVAL evaluation (the original authors used True)")
# reload pre-trained embeddings
parser.add_argument("--src_embs", type=str, nargs='+', default=[], help="Reload source embeddings (should be in the same order as in src_langs)")
parser.add_argument("--tgt_emb", type=str, default='', help="Reload target embeddings")
parser.add_argument("--normalize_embeddings", type=str, default="", help="Normalize embeddings before training")
# parse parameters
params = parser.parse_args()

# post-processing options
params.src_N = len(params.src_langs)
params.all_langs = params.src_langs + [params.tgt_lang]
# load default embeddings if no embeddings specified
if len(params.src_embs) == 0:
params.src_embs = []
for lang in params.src_langs:
params.src_embs.append(os.path.join(EMB_DIR, f'wiki.{lang}.vec'))
if len(params.tgt_emb) == 0:
params.tgt_emb = os.path.join(EMB_DIR, f'wiki.{params.tgt_lang}.vec')
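# the 'wiki.<lang>.vec' names match the pre-trained fastText Wikipedia
# vectors, which EMB_DIR is assumed to contain
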
# check parameters
assert not params.device.lower().startswith('cuda') or torch.cuda.is_available()
assert params.dico_train in ["identical_char", "default"] or os.path.isfile(params.dico_train)
assert params.dico_build in ["S2T", "T2S", "S2T|T2S", "S2T&T2S"]
assert params.dico_max_size == 0 or params.dico_max_size < params.dico_max_rank
assert params.dico_max_size == 0 or params.dico_max_size > params.dico_min_size
assert all([os.path.isfile(emb) for emb in params.src_embs])
assert os.path.isfile(params.tgt_emb)
assert params.dico_eval == 'default' or os.path.isfile(params.dico_eval)
assert params.export in ["", "txt", "pth"]

# build logger / model / trainer / evaluator
logger = initialize_exp(params)
# N+1 embeddings, N mappings, N+1 discriminators
embs, mappings, discriminators = build_model(params, False)
trainer = Trainer(embs, mappings, discriminators, params)
evaluator = Evaluator(trainer)

# load a training dictionary. if a dictionary path is not provided, use a default
# one ("default") or create one based on identical character strings ("identical_char")
trainer.load_training_dico(params.dico_train)
# define the validation metric
VALIDATION_METRIC = VALIDATION_METRIC_UNSUP if params.dico_train == 'identical_char' else VALIDATION_METRIC_SUP
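# identical-character seeds provide no held-out gold dictionary, so model
# selection falls back to the fully unsupervised criterion in that case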
logger.info("Validation metric: %s" % VALIDATION_METRIC)
"""
Learning loop for Procrustes Iterative Learning
"""
for n_epoch in range(params.n_refinement + 1):
logger.info('Starting iteration %i...' % n_epoch)
# build a dictionary from aligned embeddings (unless
# it is the first iteration and we use the init one)
if n_epoch > 0 or not hasattr(trainer, 'dicos'):
trainer.build_dictionary()
# optimize MPSR
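    # roughly, each mpsr_step samples a language pair, draws word pairs from
    # the induced dictionaries, and updates the mappings to pull aligned pairs
    # together; see Trainer.mpsr_step for the exact objective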
tic = time.time()
n_words_mpsr = 0
stats = {'MPSR_COSTS': []}
for n_iter in range(params.mpsr_n_steps):
# mpsr training step
n_words_mpsr += trainer.mpsr_step(stats)
# log stats
if n_iter % 500 == 0:
stats_str = [('MPSR_COSTS', 'MPSR loss')]
stats_log = ['%s: %.4f' % (v, np.mean(stats[k]))
for k, v in stats_str if len(stats[k]) > 0]
stats_log.append('%i samples/s' % int(n_words_mpsr / (time.time() - tic)))
logger.info(('%06i - ' % n_iter) + ' - '.join(stats_log))
# reset
tic = time.time()
n_words_mpsr = 0
for k, _ in stats_str:
del stats[k][:]
# embeddings evaluation
to_log = OrderedDict({'n_epoch': n_epoch})
evaluator.all_eval(to_log)
# JSON log / save best model / end of epoch
logger.info("__log__:%s" % json.dumps(to_log))
trainer.save_best(to_log, VALIDATION_METRIC)
logger.info('End of iteration %i.\n\n' % n_epoch)
# update the learning rate (effective only if using SGD for MPSR)
trainer.update_mpsr_lr(to_log, VALIDATION_METRIC)
# export embeddings
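# reload the best mappings selected on the validation metric, then write the
# aligned embeddings in the chosen format (word2vec-style text or a PyTorch
# .pth file)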
if params.export:
trainer.reload_best()
trainer.export()