From 2d77c9dd42e5e2fe419b2a34a3ac094d8ef0c92b Mon Sep 17 00:00:00 2001 From: BBC-Esq Date: Sun, 3 Nov 2024 14:36:07 -0500 Subject: [PATCH] Update feature_extractor.py --- faster_whisper/feature_extractor.py | 73 +++++++++-------------------- 1 file changed, 21 insertions(+), 52 deletions(-) diff --git a/faster_whisper/feature_extractor.py b/faster_whisper/feature_extractor.py index 6371d5ef..d945757c 100644 --- a/faster_whisper/feature_extractor.py +++ b/faster_whisper/feature_extractor.py @@ -1,4 +1,6 @@ import torch +import numpy as np +import os # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py # noqa: E501 @@ -11,6 +13,7 @@ def __init__( hop_length=160, chunk_length=30, n_fft=400, + mel_filter_path: str = "path/to/mel_filters.npz", ): if device == "auto": self.device = "cuda" if torch.cuda.is_available() else "cpu" @@ -23,57 +26,23 @@ def __init__( self.nb_max_frames = self.n_samples // hop_length self.time_per_frame = hop_length / sampling_rate self.sampling_rate = sampling_rate - self.mel_filters = self.get_mel_filters( - sampling_rate, n_fft, n_mels=feature_size - ) - - @staticmethod - def get_mel_filters(sr, n_fft, n_mels=128): - """ - Implementation of librosa.filters.mel in Pytorch - """ - # Initialize the weights - n_mels = int(n_mels) - - # Center freqs of each FFT bin - fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr) - - # 'Center freqs' of mel bands - uniformly spaced between limits - min_mel = 0.0 - max_mel = 45.245640471924965 - - mels = torch.linspace(min_mel, max_mel, n_mels + 2) - - # Fill in the linear scale - f_min = 0.0 - f_sp = 200.0 / 3 - freqs = f_min + f_sp * mels - - # And now the nonlinear scale - min_log_hz = 1000.0 # beginning of log region (Hz) - min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) - logstep = torch.log(torch.tensor(6.4)) / 27.0 # step size for log region - - # If we have vector data, vectorize - log_t = mels >= min_log_mel - freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel)) - - mel_f = freqs - - fdiff = torch.diff(mel_f) - ramps = mel_f.view(-1, 1) - fftfreqs.view(1, -1) - - lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1) - upper = ramps[2:] / fdiff[1:].unsqueeze(1) - - # Intersect them with each other and zero, vectorized across all i - weights = torch.maximum(torch.zeros_like(lower), torch.minimum(lower, upper)) - - # Slaney-style mel is scaled to be approx constant energy per channel - enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels]) - weights *= enorm.unsqueeze(1) - - return weights + self.mel_filters = self.load_mel_filters(mel_filter_path, feature_size) + + def load_mel_filters(self, filepath: str, n_mels: int): + if not os.path.exists(filepath): + raise FileNotFoundError(f"Mel filter file not found at: {filepath}") + + mel_data = np.load(filepath) + key = f"mel_{n_mels}" + if key not in mel_data: + available_keys = ', '.join(mel_data.keys()) + raise KeyError( + f"Key '{key}' not found in mel_filters.npz. Available keys: {available_keys}" + ) + + mel_filters_np = mel_data[key] + mel_filters = torch.from_numpy(mel_filters_np).float().to(self.device) + return mel_filters def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False): """ @@ -103,7 +72,7 @@ def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False): ) magnitudes = stft[..., :-1].abs() ** 2 - mel_spec = self.mel_filters.to(waveform.device) @ magnitudes + mel_spec = self.mel_filters @ magnitudes log_spec = torch.clamp(mel_spec, min=1e-10).log10() log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)