forked from kevindegila/flask-joey
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
47 lines (39 loc) · 1.8 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from torchtext import data
from torchtext.datasets import TranslationDataset
from joeynmt.constants import UNK_TOKEN, EOS_TOKEN, BOS_TOKEN, PAD_TOKEN
class MonoLineDataset(TranslationDataset):
    """Dataset holding exactly one source-side example (no target side)."""

    def __init__(self, line, field, **kwargs):
        """
        Build a one-example dataset from a raw input string.

        :param line: raw input sentence (stripped of surrounding whitespace)
        :param field: torchtext Field used for the ``src`` side
        """
        src_fields = [("src", field)]
        single_example = data.Example.fromlist([line.strip()], src_fields)
        # NOTE: deliberately invoke the *grandparent* (data.Dataset) __init__,
        # bypassing TranslationDataset.__init__ which expects file paths.
        super(TranslationDataset, self).__init__(
            [single_example], src_fields, **kwargs)
def load_line_as_data(line, level, lowercase, src_vocab, trg_vocab):
    """
    Create a data set from one line.
    Workaround for the usual torchtext data handling.

    :param line: The input line to process.
    :param level: "char", "bpe" or "word". Determines segmentation of the input.
    :param lowercase: If True, lowercases inputs and outputs.
    :param src_vocab: Source vocabulary object, attached to the source field
        (the code assigns it to ``src_field.vocab``, so it must be a built
        vocabulary, not a file path).
    :param trg_vocab: Target vocabulary object, attached to the target field.
    :return: tuple of (test_data, src_vocab, trg_vocab)
    """
    if level == "char":
        tok_fun = list       # character-level: each character is a token
    else:  # bpe or word, pre-tokenized
        tok_fun = str.split  # whitespace-separated tokens
    # FIXME: source side gets no BOS token (init_token=None) — confirm this
    # matches the preprocessing the model was trained with.
    src_field = data.Field(init_token=None, eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN, tokenize=tok_fun,
                           batch_first=True, lower=lowercase,
                           unk_token=UNK_TOKEN,
                           include_lengths=True)
    trg_field = data.Field(init_token=BOS_TOKEN, eos_token=EOS_TOKEN,
                           pad_token=PAD_TOKEN, tokenize=tok_fun,
                           unk_token=UNK_TOKEN,
                           batch_first=True, lower=lowercase,
                           include_lengths=True)
    test_data = MonoLineDataset(line=line, field=src_field)
    # Attach the pre-built vocabularies so numericalization works without
    # calling build_vocab on this single-line dataset.
    src_field.vocab = src_vocab
    trg_field.vocab = trg_vocab
    return test_data, src_vocab, trg_vocab