-
Notifications
You must be signed in to change notification settings - Fork 1
/
process.py
136 lines (118 loc) · 5.53 KB
/
process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from pathlib import Path
import argparse
import datetime
import re
# strftime/strptime-style layouts. Neither is referenced by the other code
# visible in this file; kept because other modules may import them.
data_FMT = '%H:%M:%S.%f'
output_FMT = '%S.%f'

# Reference datetime: midnight on 1900-01-01. Subtracting it from a datetime
# on that same date yields the time of day as a timedelta.
origin_time_str = "1900-01-01 00:00:00.000"
origin_time = datetime.datetime.strptime(origin_time_str, "%Y-%m-%d %H:%M:%S.%f")

# Pattern for the "-->" separator used on WebVTT cue timing lines.
arrow = re.compile(r"-->")
def process_caption(caption):
    r"""Remove every non-word character from a caption line.

    ``\W`` matches any character that is not a Unicode word character, so
    whitespace (including the trailing newline), ASCII and full-width
    punctuation are dropped. Letters, digits, underscores and CJK characters
    are kept.

    Args:
        caption: one line of caption text.

    Returns:
        The caption with all runs of non-word characters deleted.
    """
    # Raw string: '\W' in a normal literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12).
    return re.sub(r"\W+", "", caption)
def get_reco_id_factory(data_dir, extension):
    """Return a function that maps an audio file path to a recording id.

    The id is the file's path relative to *data_dir*, with a trailing
    *extension* removed and path separators replaced by ``-``. For example,
    with ``extension=".wav"``, ``<data_dir>/spk1/a.wav`` becomes ``spk1-a``.

    Args:
        data_dir: root directory that audio paths are made relative to.
        extension: file-name suffix to strip from the end of the path.

    Returns:
        A one-argument callable ``get_reco_id(file_path) -> str``. It raises
        ValueError (from ``Path.relative_to``) if *file_path* is not inside
        *data_dir*.
    """
    def get_reco_id(file_path):
        # as_posix() gives '/' separators on every OS, so the '-' join below
        # also works for Windows paths.
        relative_path = Path(file_path).relative_to(data_dir).as_posix()
        # Strip only a trailing extension. str.replace would also delete
        # matches inside directory names (e.g. "x.wav.d/y.wav").
        if extension and relative_path.endswith(extension):
            relative_path = relative_path[:-len(extension)]
        return relative_path.replace("/", "-")
    return get_reco_id
def fast_round(number, precision):
    """Cut *number* down to *precision* decimal places.

    Despite the name, this truncates toward zero rather than rounding:
    ``int()`` simply discards the fractional part of the scaled value.

    Returns:
        A float.
    """
    scale = 10.0 ** precision
    scaled_and_truncated = int(number * scale)
    return scaled_and_truncated / scale
def ts2str(timestamp):
    """Format a time in seconds as a zero-padded, 6-digit centisecond string.

    Example: ``12.5`` -> ``"001250"``.

    ``round`` is used instead of plain truncation because binary floating point
    makes e.g. ``0.29 * 100 == 28.999999999999996``. Truncating that would
    give ``"000028"`` and a wrong utterance id.
    """
    return f"{int(round(timestamp * 100)):06}"
def merge_segments(utterances):
    """Merge overlapping or touching utterances into longer ones.

    Args:
        utterances: iterable of ``(text, begin_seconds, end_seconds)`` tuples,
            assumed to be sorted by begin time (parse_vtt_to_utterances
            passes them in caption-file order).

    Yields:
        ``(text, begin_seconds, end_seconds)`` tuples. An utterance whose begin
        is at or before the current group's end joins that group: its text is
        appended after a single space, and the group's end becomes the later
        of the two ends. Every input utterance appears in exactly one output
        tuple. An empty input yields nothing.
    """
    group = None
    for text, begin, end in utterances:
        if group is not None and begin <= group[2]:
            group_text, group_begin, group_end = group
            # max(): a segment fully inside the group must not shrink it.
            group = (group_text + " " + text, group_begin, max(group_end, end))
            continue
        if group is not None:
            yield group
        group = (text, begin, end)
    if group is not None:
        # The final group is never closed by a later utterance; flush it here.
        yield group
def get_utt_id(reco_id, begin_seconds, end_seconds):
    """Build the utterance id ``<reco_id>-<begin>-<end>``.

    Both times are rendered with ts2str.
    """
    begin_tag = ts2str(begin_seconds)
    end_tag = ts2str(end_seconds)
    return "{}-{}-{}".format(reco_id, begin_tag, end_tag)
# WebVTT timestamp: optional hours, two-digit minutes and seconds, and an
# optional fraction of up to 6 digits (the spec requires exactly 3).
_VTT_TIMESTAMP_RE = re.compile(r"(?:([0-9]+):)?([0-9]{2}):([0-9]{2})(?:\.([0-9]{1,6}))?")
# Cue timing line: "<start> --> <end>", optionally followed by cue settings.
_CUE_TIMING_RE = re.compile(r"(\S+)\s+-->\s+(\S+)")


def _parse_vtt_timestamp(text):
    """Convert a ``[HH:]MM:SS[.fff]`` timestamp to a timedelta, or None if malformed."""
    match = _VTT_TIMESTAMP_RE.fullmatch(text)
    if match is None:
        return None
    hours, minutes, seconds, fraction = match.groups()
    if int(minutes) > 59 or int(seconds) > 59:
        return None
    return datetime.timedelta(
        hours=int(hours or 0),
        minutes=int(minutes),
        seconds=int(seconds),
        # Right-pad so e.g. ".5" and ".500" both mean 500000 microseconds.
        microseconds=int((fraction or "").ljust(6, "0")),
    )


def _parse_cue_timing(line):
    """Return ``(begin, end)`` timedeltas if *line* is a cue timing line, else None.

    Cue settings after the end timestamp (e.g. ``align:start``) are ignored.
    """
    match = _CUE_TIMING_RE.match(line.strip())
    if match is None:
        return None
    begin = _parse_vtt_timestamp(match.group(1))
    end = _parse_vtt_timestamp(match.group(2))
    if begin is None or end is None:
        return None
    return begin, end


def parse_vtt_to_utterances(vtt_file, reco_id, begin_offset, end_offset, merge=False):
    """Read a WebVTT caption file and return one record per caption cue.

    Args:
        vtt_file: path of the caption file (read as UTF-8).
        reco_id: recording id copied into every record and used to build
            the utterance ids.
        begin_offset: seconds added to every cue's begin time.
        end_offset: seconds added to every cue's end time.
        merge: if true, overlapping/touching cues are combined with
            merge_segments before ids are assigned.

    Returns:
        A list of dicts with keys ``utt_id``, ``reco_id``, ``caption``,
        ``begin_seconds`` and ``end_seconds``. The times are floats in
        seconds, rounded to 2 decimals. Cues whose timing line cannot be
        parsed, or whose begin is not before its end (checked before the
        offsets are applied), are skipped.
    """
    with open(vtt_file, encoding="utf-8") as fp:
        lines = fp.readlines()

    begin_delta = datetime.timedelta(seconds=begin_offset)
    end_delta = datetime.timedelta(seconds=end_offset)

    utterances = []
    for idx, line in enumerate(lines):
        timing = _parse_cue_timing(line)
        if timing is None:
            continue
        begin, end = timing
        if begin >= end:
            continue
        # NOTE(review): only the first text line after the timing line is
        # used; later lines of a multi-line cue are ignored, as in the
        # original. A timing line at end of file gets an empty caption.
        caption_line = lines[idx + 1] if idx + 1 < len(lines) else ""
        caption = process_caption(caption_line)
        begin_seconds = round((begin + begin_delta).total_seconds(), 2)
        end_seconds = round((end + end_delta).total_seconds(), 2)
        utterances.append((caption, begin_seconds, end_seconds))

    if merge and utterances:
        utterances = merge_segments(utterances)

    return [
        {
            "utt_id": get_utt_id(reco_id, begin_seconds, end_seconds),
            "reco_id": reco_id,
            "caption": caption,
            "begin_seconds": begin_seconds,
            "end_seconds": end_seconds,
        }
        for caption, begin_seconds, end_seconds in utterances
    ]
def write_segments(utterances, fp):
    """Write one ``<utt_id> <reco_id> <begin_seconds> <end_seconds>`` line per utterance.

    Args:
        utterances: iterable of dicts with those four keys.
        fp: writable text stream.
    """
    for record in utterances:
        fields = (
            record['utt_id'],
            record['reco_id'],
            record['begin_seconds'],
            record['end_seconds'],
        )
        fp.write(" ".join(str(field) for field in fields) + "\n")
def write_utt2spk(utterances, fp):
    """Write ``<utt_id> <utt_id>`` lines, so each utterance is its own speaker.

    Args:
        utterances: iterable of dicts with a ``utt_id`` key.
        fp: writable text stream.
    """
    for record in utterances:
        fp.write("{0} {0}\n".format(record['utt_id']))
def write_text(utterances, fp):
    """Write one ``<utt_id> <caption>`` line per utterance.

    Args:
        utterances: iterable of dicts with ``utt_id`` and ``caption`` keys.
        fp: writable text stream.
    """
    for record in utterances:
        utt_id, caption = record['utt_id'], record['caption']
        fp.write(f"{utt_id} {caption}\n")
def get_all_audio_text_file_pairs(data_dir, audio_file_extension=".wav", caption_file_extension=".vtt"):
    """Yield ``(audio_file, caption_file)`` Path pairs found under *data_dir*.

    Every regular file below *data_dir* (searched recursively) whose name ends
    with *audio_file_extension* is paired with the sibling file that has the
    same base name plus *caption_file_extension*. Audio files without such a
    caption file are skipped. Pairs come out in sorted path order, so repeated
    runs produce identical output files.
    """
    for audio_file in sorted(Path(data_dir).rglob('*' + audio_file_extension)):
        # rglob also matches directories whose names end with the extension.
        if not audio_file.is_file():
            continue
        # Drop the *whole* audio extension, which may contain several dots
        # (e.g. ".16k.wav"). Path.with_suffix would only replace the last one.
        base_name = audio_file.name[:len(audio_file.name) - len(audio_file_extension)]
        caption_file = audio_file.with_name(base_name + caption_file_extension)
        if caption_file.exists():
            yield audio_file, caption_file
def write_wavscp(reco_id, audio_file, fp):
    """Write a ``<reco_id> <absolute audio path>`` line.

    Args:
        reco_id: recording id.
        audio_file: ``pathlib.Path`` of the audio file; it is written resolved
            to an absolute path.
        fp: writable text stream.
    """
    absolute_path = audio_file.resolve()
    fp.write("{} {}\n".format(reco_id, absolute_path))
if __name__ == '__main__':
    # For every audio file under data_dir that has a matching caption file,
    # write segments / utt2spk / text / wav.scp entries into output_dir
    # (presumably a Kaldi-style data directory, judging by the file names).
    parser = argparse.ArgumentParser()
    parser.add_argument('data_dir', help='data directory')
    parser.add_argument('output_dir')
    parser.add_argument('--audio-file-extension', default=".wav", help='audio file extension')
    parser.add_argument('--caption-file-extension', default=".zh-TW.vtt", help='caption file extension')
    parser.add_argument('--begin-offset', type=float, default=0.0, help='global begin offset')
    parser.add_argument('--end-offset', type=float, default=0.0, help='global end offset')
    parser.add_argument('--merge-consecutive-segments', action='store_true', help='merge segments')
    args = parser.parse_args()

    get_reco_id = get_reco_id_factory(args.data_dir, args.audio_file_extension)
    output_dir = Path(args.output_dir)
    # Create the output directory rather than failing on the first open().
    output_dir.mkdir(parents=True, exist_ok=True)

    # A single `with` covers all four outputs, so every file is flushed and
    # closed even if parsing a caption file raises part-way through. The
    # encoding is explicit because captions contain non-ASCII (e.g. zh-TW) text.
    with open(output_dir / "segments", "w", encoding="utf-8") as segment_file, \
            open(output_dir / "utt2spk", "w", encoding="utf-8") as utt2spk_file, \
            open(output_dir / "text", "w", encoding="utf-8") as text_file, \
            open(output_dir / "wav.scp", "w", encoding="utf-8") as wavscp_file:
        file_pairs = get_all_audio_text_file_pairs(
            args.data_dir, args.audio_file_extension, args.caption_file_extension)
        for audio_file, caption_file in file_pairs:
            reco_id = get_reco_id(audio_file)
            utterances = parse_vtt_to_utterances(
                caption_file, reco_id, args.begin_offset, args.end_offset,
                args.merge_consecutive_segments)
            write_segments(utterances, segment_file)
            write_utt2spk(utterances, utt2spk_file)
            write_text(utterances, text_file)
            write_wavscp(reco_id, audio_file, wavscp_file)