-
Notifications
You must be signed in to change notification settings - Fork 109
/
run_wer.py
107 lines (91 loc) · 3.52 KB
/
run_wer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import sys, os
from tqdm import tqdm
import multiprocessing
from jiwer import compute_measures
from zhon.hanzi import punctuation
import string
import numpy as np
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import soundfile as sf
import scipy
import zhconv
from funasr import AutoModel
# Every character treated as punctuation when normalizing text: the set
# imported from zhon.hanzi plus the ASCII set from the string module.
punctuation_all = punctuation + string.punctuation
# Command-line arguments (all three are required; a missing one raises IndexError):
#   argv[1]: '|'-separated manifest listing the wav files to transcribe and their reference text
#   argv[2]: path of the result file written by run_asr
#   argv[3]: language code, "zh" or "en"; selects the ASR model and the tokenization
wav_res_text_path = sys.argv[1]
res_path = sys.argv[2]
lang = sys.argv[3] # zh or en
# Device the Whisper model and its input features are moved to.
device = "cuda:0"
def load_en_model():
    """Load the English ASR stack.

    Returns:
        Tuple ``(processor, model)``: the ``WhisperProcessor`` and the
        ``WhisperForConditionalGeneration`` model for the
        ``openai/whisper-large-v3`` checkpoint, with the model moved to the
        module-level ``device``.
    """
    checkpoint = "openai/whisper-large-v3"
    whisper_processor = WhisperProcessor.from_pretrained(checkpoint)
    whisper_model = WhisperForConditionalGeneration.from_pretrained(checkpoint)
    whisper_model = whisper_model.to(device)
    return whisper_processor, whisper_model
def load_zh_model():
    """Build and return the FunASR ``AutoModel`` for ``paraformer-zh``."""
    return AutoModel(model="paraformer-zh")
def process_one(hypo, truth):
    """Score one ASR hypothesis against its reference text.

    Both strings are stripped of punctuation (apostrophes are kept),
    whitespace-normalized, and tokenized according to the module-level
    ``lang``: one token per character for ``"zh"``, lowercased
    space-separated words for ``"en"``.

    Args:
        hypo: Transcription produced by the ASR model.
        truth: Reference text.

    Returns:
        Tuple ``(raw_truth, raw_hypo, wer, subs, dele, inse)``: the two
        unmodified input strings, the ``"wer"`` value reported by
        ``compute_measures``, and the substitution, deletion and insertion
        counts each divided by the number of reference tokens.

    Raises:
        NotImplementedError: If ``lang`` is neither ``"zh"`` nor ``"en"``.
    """
    raw_truth = truth
    raw_hypo = hypo

    # Delete every punctuation character except the apostrophe in one pass;
    # the apostrophe is kept so contractions such as "don't" stay one word.
    strip_table = str.maketrans("", "", punctuation_all.replace("'", ""))
    truth = truth.translate(strip_table)
    hypo = hypo.translate(strip_table)

    if lang == "zh":
        # One token per character. Whitespace characters are dropped so they
        # neither become tokens nor leave runs of consecutive spaces.
        truth = " ".join(ch for ch in truth if not ch.isspace())
        hypo = " ".join(ch for ch in hypo if not ch.isspace())
    elif lang == "en":
        # Removing punctuation can leave repeated spaces ("a - b" -> "a  b");
        # split/join collapses them to single spaces.
        truth = " ".join(truth.lower().split())
        hypo = " ".join(hypo.lower().split())
    else:
        raise NotImplementedError

    measures = compute_measures(truth, hypo)

    # Count real reference tokens only: split() never yields empty strings,
    # so the three rates below share the denominator of the reported WER.
    # max(..., 1) keeps the division defined for an empty reference.
    ref_len = max(len(truth.split()), 1)
    wer = measures["wer"]
    subs = measures["substitutions"] / ref_len
    dele = measures["deletions"] / ref_len
    inse = measures["insertions"] / ref_len
    return (raw_truth, raw_hypo, wer, subs, dele, inse)
def _read_eval_pairs(manifest_path):
    """Return the ``(wav_res_path, text_ref)`` pairs listed in a '|'-separated manifest."""
    pairs = []
    with open(manifest_path, encoding="utf-8") as manifest:
        for line in manifest:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            fields = line.split('|')
            if len(fields) == 2:
                # wav_res_path|text_ref
                wav_res_path, text_ref = fields
            elif len(fields) in (3, 4):
                # 3 fields: wav_res_path|wav_ref_path|text_ref
                # 4 fields ("edit" layout): wav_res_path|<unused>|text_ref|wav_ref_path
                # In both layouts the reference text is the third field.
                wav_res_path, text_ref = fields[0], fields[2]
            else:
                raise NotImplementedError
            if not os.path.exists(wav_res_path):
                # Entries whose audio file is missing are skipped, not scored.
                continue
            pairs.append((wav_res_path, text_ref))
    return pairs


def run_asr(wav_res_text_path, res_path):
    """Transcribe every wav in a manifest and write per-utterance error rates.

    The ASR backend is chosen by the module-level ``lang``: Whisper for
    ``"en"``, the FunASR model for ``"zh"`` (whose output is passed through
    ``zhconv.convert(..., 'zh-cn')``). Each transcription is scored against
    its reference text with ``process_one``.

    Args:
        wav_res_text_path: Manifest with one '|'-separated entry per line, in
            a 2-, 3- or 4-field layout. Blank lines and entries whose wav
            file does not exist are skipped.
        res_path: Output file. One tab-separated line is written per
            utterance: wav path, WER, raw reference, raw hypothesis,
            insertion rate, deletion rate, substitution rate.

    Raises:
        NotImplementedError: If ``lang`` is neither ``"en"`` nor ``"zh"``, or
            a manifest line has an unsupported number of fields.
    """
    # Imported here so scipy.signal is guaranteed to be loaded; the file-level
    # `import scipy` alone does not load the submodule on every scipy version.
    from scipy.signal import resample

    if lang == "en":
        processor, model = load_en_model()
        # The decoder prompt is the same for every utterance, so build it once.
        forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
    elif lang == "zh":
        model = load_zh_model()
    else:
        # Fail before any work is done instead of hitting an undefined name later.
        raise NotImplementedError

    params = _read_eval_pairs(wav_res_text_path)

    with open(res_path, "w", encoding="utf-8") as fout:
        for wav_res_path, text_ref in tqdm(params):
            if lang == "en":
                wav, sr = sf.read(wav_res_path)
                # NOTE(review): multi-channel audio is passed through unchanged -
                # confirm the inputs are mono.
                if sr != 16000:
                    # The processor is called with sampling_rate=16000 below.
                    wav = resample(wav, int(len(wav) * 16000 / sr))
                input_features = processor(wav, sampling_rate=16000, return_tensors="pt").input_features
                input_features = input_features.to(device)
                predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
                transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            else:  # lang == "zh", validated above
                res = model.generate(input=wav_res_path, batch_size_s=300)
                transcription = res[0]["text"]
                transcription = zhconv.convert(transcription, 'zh-cn')

            raw_truth, raw_hypo, wer, subs, dele, inse = process_one(transcription, text_ref)
            fout.write(f"{wav_res_path}\t{wer}\t{raw_truth}\t{raw_hypo}\t{inse}\t{dele}\t{subs}\n")
            # Flush per utterance so partial results survive an interrupted run.
            fout.flush()
# Script entry point: executes as soon as the module is run (or imported),
# using the manifest and result paths read from the command line above.
run_asr(wav_res_text_path, res_path)