-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathVocabulary.py
32 lines (26 loc) · 937 Bytes
/
Vocabulary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# -*- coding: utf-8 -*-
class Vocabulary(object):
    """Simple vocabulary wrapper mapping words to integer indices and back.

    Attributes:
        word2idx: dict mapping each registered word to its index.
        idx2word: dict mapping each index back to its word.
        idx: the index that will be assigned to the next new word.
    """

    def __init__(self):
        """Create an empty vocabulary; indices are assigned starting at 0."""
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        """Register ``word`` under the next free index.

        Adding a word that is already present is a no-op, so an existing
        word keeps the index it was first given.
        """
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        """Return the index of ``word``, or the index of ``'<unk>'``.

        Raises:
            KeyError: if ``word`` is unknown and ``'<unk>'`` has not been
                added to the vocabulary.
        """
        if word not in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        """Return the number of distinct words in the vocabulary."""
        return len(self.word2idx)

    def indices2words(self, indices, start_token='<start>', end_token='<end>'):
        """Convert a sequence of indices into a list of words.

        Every occurrence of ``start_token`` is dropped, and the result is
        cut just before the first ``end_token``; when no end token occurs
        the whole sequence is returned.

        Args:
            indices: iterable of word indices. Elements may be plain ints
                or scalar objects exposing ``.item()`` (for example the
                elements of a 1-D tensor or NumPy array).
            start_token: word whose index is filtered out of the result.
            end_token: word marking the end of the sequence.

        Returns:
            list of words, without start tokens and without the end token
            or anything that follows it.

        Raises:
            KeyError: if ``start_token`` is not in the vocabulary, or an
                index has no associated word.
        """
        start_token_idx = self.word2idx[start_token]
        words = []
        for word_id in indices:
            # Unwrap scalar containers to a plain int; plain ints have no
            # ``.item()`` and are used as they are.
            if hasattr(word_id, 'item'):
                word_id = word_id.item()
            if word_id != start_token_idx:
                words.append(self.idx2word[word_id])
        try:
            words = words[:words.index(end_token)]
        except ValueError:
            # No end token present: keep the full sequence.
            pass
        return words