Skip to content

Commit

Permalink
Merge pull request #198 from jhj0517/feature/upgrade-faster-whisper
Browse files Browse the repository at this point in the history
Migrate faster whisper to 1.0.3
  • Loading branch information
jhj0517 authored Jul 7, 2024
2 parents d8c2ba0 + c1f12f6 commit 6a24751
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 36 deletions.
6 changes: 0 additions & 6 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def launch(self):
nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999)
nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
with gr.Accordion("Diarization", open=False):
cb_diarize = gr.Checkbox(label="Enable Diarization")
Expand Down Expand Up @@ -152,7 +151,6 @@ def launch(self):
min_speech_duration_ms=nb_min_speech_duration_ms,
max_speech_duration_s=nb_max_speech_duration_s,
min_silence_duration_ms=nb_min_silence_duration_ms,
window_size_sample=nb_window_size_sample,
speech_pad_ms=nb_speech_pad_ms,
chunk_length_s=nb_chunk_length_s,
batch_size=nb_batch_size,
Expand Down Expand Up @@ -203,7 +201,6 @@ def launch(self):
nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999)
nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
with gr.Accordion("Diarization", open=False):
cb_diarize = gr.Checkbox(label="Enable Diarization")
Expand Down Expand Up @@ -241,7 +238,6 @@ def launch(self):
min_speech_duration_ms=nb_min_speech_duration_ms,
max_speech_duration_s=nb_max_speech_duration_s,
min_silence_duration_ms=nb_min_silence_duration_ms,
window_size_sample=nb_window_size_sample,
speech_pad_ms=nb_speech_pad_ms,
chunk_length_s=nb_chunk_length_s,
batch_size=nb_batch_size,
Expand Down Expand Up @@ -284,7 +280,6 @@ def launch(self):
nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999)
nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
with gr.Accordion("Diarization", open=False):
cb_diarize = gr.Checkbox(label="Enable Diarization")
Expand Down Expand Up @@ -324,7 +319,6 @@ def launch(self):
min_speech_duration_ms=nb_min_speech_duration_ms,
max_speech_duration_s=nb_max_speech_duration_s,
min_silence_duration_ms=nb_min_silence_duration_ms,
window_size_sample=nb_window_size_sample,
speech_pad_ms=nb_speech_pad_ms,
chunk_length_s=nb_chunk_length_s,
batch_size=nb_batch_size,
Expand Down
27 changes: 14 additions & 13 deletions modules/vad/silero_vad.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from faster_whisper.vad import VadOptions
from faster_whisper.vad import VadOptions, get_vad_model
import numpy as np
from typing import BinaryIO, Union, List, Optional
import warnings
Expand All @@ -9,6 +9,8 @@
class SileroVAD:
def __init__(self):
# Set up fixed VAD settings; the model itself is not loaded here.
self.sampling_rate = 16000
# Number of samples per audio chunk passed to the VAD model in get_speech_timestamps().
self.window_size_samples = 512
# Stays None until update_model() assigns it; get_speech_timestamps() calls that when it finds None.
self.model = None

def run(self,
audio: Union[str, BinaryIO, np.ndarray],
Expand Down Expand Up @@ -54,8 +56,8 @@ def run(self,

return audio

@staticmethod
def get_speech_timestamps(
self,
audio: np.ndarray,
vad_options: Optional[VadOptions] = None,
progress: gr.Progress = gr.Progress(),
Expand All @@ -72,22 +74,19 @@ def get_speech_timestamps(
Returns:
List of dicts containing begin and end samples of each speech chunk.
"""

if self.model is None:
self.update_model()

if vad_options is None:
vad_options = VadOptions(**kwargs)

threshold = vad_options.threshold
min_speech_duration_ms = vad_options.min_speech_duration_ms
max_speech_duration_s = vad_options.max_speech_duration_s
min_silence_duration_ms = vad_options.min_silence_duration_ms
window_size_samples = vad_options.window_size_samples
window_size_samples = self.window_size_samples
speech_pad_ms = vad_options.speech_pad_ms

if window_size_samples not in [512, 1024, 1536]:
warnings.warn(
"Unusual window_size_samples! Supported window_size_samples:\n"
" - [512, 1024, 1536] for 16000 sampling_rate"
)

sampling_rate = 16000
min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
speech_pad_samples = sampling_rate * speech_pad_ms / 1000
Expand All @@ -101,8 +100,7 @@ def get_speech_timestamps(

audio_length_samples = len(audio)

model = faster_whisper.vad.get_vad_model()
state = model.get_initial_state(batch_size=1)
state, context = self.model.get_initial_states(batch_size=1)

speech_probs = []
for current_start_sample in range(0, audio_length_samples, window_size_samples):
Expand All @@ -111,7 +109,7 @@ def get_speech_timestamps(
chunk = audio[current_start_sample: current_start_sample + window_size_samples]
if len(chunk) < window_size_samples:
chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
speech_prob, state = model(chunk, state, sampling_rate)
speech_prob, state, context = self.model(chunk, state, context, sampling_rate)
speech_probs.append(speech_prob)

triggered = False
Expand Down Expand Up @@ -207,6 +205,9 @@ def get_speech_timestamps(

return speeches

def update_model(self):
# Fetch the VAD model through faster_whisper.vad.get_vad_model() and keep it on the
# instance so get_speech_timestamps() can reuse it instead of fetching it on every call.
self.model = get_vad_model()

@staticmethod
def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
"""Collects and concatenates audio chunks."""
Expand Down
1 change: 0 additions & 1 deletion modules/whisper/whisper_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def run(self,
min_speech_duration_ms=params.min_speech_duration_ms,
max_speech_duration_s=params.max_speech_duration_s,
min_silence_duration_ms=params.min_silence_duration_ms,
window_size_samples=params.window_size_samples,
speech_pad_ms=params.speech_pad_ms
)
self.vad.run(
Expand Down
22 changes: 7 additions & 15 deletions modules/whisper/whisper_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class WhisperParameters:
min_speech_duration_ms: gr.Number
max_speech_duration_s: gr.Number
min_silence_duration_ms: gr.Number
window_size_sample: gr.Number
speech_pad_ms: gr.Number
chunk_length_s: gr.Number
batch_size: gr.Number
Expand Down Expand Up @@ -111,11 +110,6 @@ class WhisperParameters:
This parameter is related to Silero VAD. At the end of each speech chunk, wait for min_silence_duration_ms
before separating it.
window_size_samples: gr.Number
This parameter is related with Silero VAD. Audio chunks of window_size_samples size are fed to the silero VAD model.
WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
Values other than these may affect model performance!!
speech_pad_ms: gr.Number
This parameter is related to Silero VAD. Final speech chunks are padded by speech_pad_ms on each side.
Expand Down Expand Up @@ -178,13 +172,12 @@ def as_value(*args) -> 'WhisperValues':
min_speech_duration_ms=args[15],
max_speech_duration_s=args[16],
min_silence_duration_ms=args[17],
window_size_samples=args[18],
speech_pad_ms=args[19],
chunk_length_s=args[20],
batch_size=args[21],
is_diarize=args[22],
hf_token=args[23],
diarization_device=args[24]
speech_pad_ms=args[18],
chunk_length_s=args[19],
batch_size=args[20],
is_diarize=args[21],
hf_token=args[22],
diarization_device=args[23]
)


Expand All @@ -208,7 +201,6 @@ class WhisperValues:
min_speech_duration_ms: int
max_speech_duration_s: float
min_silence_duration_ms: int
window_size_samples: int
speech_pad_ms: int
chunk_length_s: int
batch_size: int
Expand All @@ -217,4 +209,4 @@ class WhisperValues:
diarization_device: str
"""
A data class to use Whisper parameters.
"""
"""
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
--extra-index-url https://download.pytorch.org/whl/cu121
torch
git+https://github.com/jhj0517/jhj0517-whisper.git
faster-whisper==1.0.2
faster-whisper==1.0.3
transformers
gradio==4.29.0
pytube
Expand Down

0 comments on commit 6a24751

Please sign in to comment.