-
Notifications
You must be signed in to change notification settings - Fork 0
/
translate.py
23 lines (22 loc) · 730 Bytes
/
translate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import requests
from easynmt import EasyNMT
import os
# Checkpoint name handed to EasyNMT; presumably the 1.2B-parameter M2M-100
# model (verify against the EasyNMT model list).
_MODEL_NAME = "m2m_100_1.2B"

# Built once when this module is imported and shared by every translate_func
# call. NOTE(review): EasyNMT likely loads (and possibly downloads) large model
# weights here, so importing this module is probably slow — confirm before
# importing it on a latency-sensitive path.
trans_model = EasyNMT(_MODEL_NAME)
def translate_func(src_question: str, src_lang: str, tgt_lang: str) -> dict:
    """Translate a question, optionally made of several "[SEP]"-delimited parts.

    If ``src_question`` contains the literal marker ``"[SEP]"``:
    - it is split on that marker;
    - each part is stripped of surrounding whitespace and empty parts are
      dropped;
    - the remaining parts are translated as separate documents and re-joined
      with ``"[SEP]"`` (no surrounding spaces).

    Otherwise the whole string is translated unchanged as a single document.

    Args:
        src_question: Text to translate.
        src_lang: Source language identifier, forwarded to
            ``trans_model.translate`` as ``source_lang``.
        tgt_lang: Target language identifier, forwarded to
            ``trans_model.translate`` as ``target_lang``.

    Returns:
        A dict with the single key ``"Target Question"`` holding the
        translated text. This is ``""`` when ``src_question`` contains only
        separators and whitespace.

    Raises:
        TypeError: If any argument is not a ``str``, or if the model does not
            return a list of translations.
    """
    sep = "[SEP]"

    # Explicit checks rather than ``assert``: assert statements are removed
    # when Python runs with -O, which would silently disable this validation.
    for arg_name, value in (
        ("src_question", src_question),
        ("src_lang", src_lang),
        ("tgt_lang", tgt_lang),
    ):
        if not isinstance(value, str):
            raise TypeError(
                f"{arg_name} must be str, got {type(value).__name__}"
            )

    if sep in src_question:
        stripped = [part.strip() for part in src_question.split(sep)]
        segments = [part for part in stripped if part]
    else:
        segments = [src_question]

    # Input was only separators/whitespace. The joined result would be "" in
    # any case, so skip sending an empty batch to the (large) model.
    if not segments:
        return {"Target Question": ""}

    tgt_segments = trans_model.translate(
        segments, source_lang=src_lang, target_lang=tgt_lang
    )
    # A list input is expected to yield a list of translations, one per
    # segment. Fail loudly otherwise, instead of joining something unexpected
    # (e.g. a str, which would be joined character by character).
    if not isinstance(tgt_segments, list):
        raise TypeError(
            "expected trans_model.translate to return a list, got "
            f"{type(tgt_segments).__name__}"
        )
    return {"Target Question": sep.join(tgt_segments)}