-
Notifications
You must be signed in to change notification settings - Fork 3
/
utils.py
84 lines (76 loc) · 3.03 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import collections
import collections.abc

import torch
# Characters recognizable in any transcript (the CTC alphabet). The trailing
# '-' also serves as the fallback code for out-of-alphabet characters in
# TranscriptEncoder.encode.
classes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-"
# Reference: https://github.com/meijieru/crnn.pytorch/blob/master/utils.py
class TranscriptEncoder:
    """
    Encode transcript characters to integers and vice versa for CTC loss calc.

    Index 0 is reserved for the CTC blank; character i of ``alphabet`` maps to
    code i + 1. Characters not found in the alphabet are encoded with the code
    of ``'-'``.
    """

    def __init__(self, alphabet, ignore_case=False):
        """
        Args:
            alphabet (str): possible characters in any transcript,
                e.g. a-z, A-Z, 0-9 and '-'.
            ignore_case (bool): if True, lower-case the alphabet and all
                input text before encoding.
        """
        self.ignore_case = ignore_case
        if self.ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet
        # idx 0 is reserved for the CTC blank, so character codes start at 1.
        self.dict = {char: idx + 1 for idx, char in enumerate(self.alphabet)}
        # Guarantee '-' has a code (it is the fallback for unknown characters)
        # without clobbering an existing mapping: the old unconditional
        # assignment gave '-' the same code as the last alphabet character
        # whenever '-' was not the final character of the alphabet.
        self.dict.setdefault('-', len(self.dict) + 1)

    def _encode_chars(self, text):
        # Map each character to its integer code; unknown characters fall
        # back to the code of '-'.
        return [
            self.dict.get(char.lower() if self.ignore_case else char, self.dict['-'])
            for char in text
        ]

    def encode(self, text):
        """
        Support batch or single str.

        Args:
            text (str or list of str): texts to convert.

        Returns:
            torch.Tensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.Tensor [n]: length of each text.

        Raises:
            TypeError: if ``text`` is neither a str nor an iterable of str.
        """
        if isinstance(text, str):
            length = [len(text)]
            codes = self._encode_chars(text)
        # collections.Iterable was removed in Python 3.10; use collections.abc.
        elif isinstance(text, collections.abc.Iterable):
            length = [len(s) for s in text]
            # Encode the concatenation directly instead of recursing, which
            # previously re-wrapped an already-built tensor in torch.tensor().
            codes = self._encode_chars(''.join(text))
        else:
            raise TypeError(
                "text must be a str or an iterable of str, got {}".format(type(text)))
        return (torch.tensor(codes), torch.tensor(length))

    def decode(self, t, length, raw=False):
        """
        Decode encoded texts back into strs.

        Args:
            t (torch.IntTensor [length_0 + ... + length_{n - 1}]): encoded texts.
            length (torch.IntTensor [n]): length of each text.
            raw (bool): if True, keep blanks and repeated characters instead
                of collapsing them (useful for debugging CTC output).

        Raises:
            AssertionError: when the texts and its length does not match.

        Returns:
            text (str or list of str): decoded texts.
        """
        if length.numel() == 1:
            # Single-sample mode.
            length = length.item()
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                char_list = []
                for i in range(length):
                    # Standard CTC collapse: drop blanks (code 0) and
                    # consecutive duplicates.
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # batch mode: decode each sample's slice independently.
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts