From 4e2ac4e4e9da9b2b392c13f68f0718e86d293fb0 Mon Sep 17 00:00:00 2001 From: Max Bain Date: Thu, 4 May 2023 20:38:13 +0100 Subject: [PATCH] torch2.0, remove compile for now, round to times to 3 decimal --- README.md | 12 ++++++------ setup.py | 2 +- whisperx/alignment.py | 4 ++-- whisperx/asr.py | 5 +---- whisperx/audio.py | 44 ++++++++++++++++++++++++++++-------------- whisperx/transcribe.py | 7 +------ 6 files changed, 40 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index c9951ce3..1f41bb96 100644 --- a/README.md +++ b/README.md @@ -61,23 +61,23 @@ This repository refines the timestamps of openAI's Whisper model via forced alig

Setup ⚙️

-Tested for PyTorch 0.11, Python 3.8 (use other versions at your own risk!) +Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!) GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html). -### 1. Create Python3.8 environment +### 1. Create Python3.10 environment -`conda create --name whisperx python=3.8` +`conda create --name whisperx python=3.10` `conda activate whisperx` -### 2. Install PyTorch 0.11.0, e.g. for Linux and Windows: +### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7: -`pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113` +`pip3 install torch torchvision torchaudio` -See other methods [here.](https://pytorch.org/get-started/previous-versions/#wheel-4) +See other methods [here.](https://pytorch.org/get-started/locally/) ### 3. Install this repo diff --git a/setup.py b/setup.py index 859d1717..66f22cd8 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="whisperx", py_modules=["whisperx"], - version="3.0.0", + version="3.0.2", description="Time-Accurate Automatic Speech Recognition using Whisper.", readme="README.md", python_requires=">=3.8", diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 2ae77f36..e63e6e5c 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -268,8 +268,8 @@ def align( start, end, score = None, None, None if cdx in clean_cdx: char_seg = char_segments[clean_cdx.index(cdx)] - start = char_seg.start * ratio + t1 - end = char_seg.end * ratio + t1 + start = round(char_seg.start * ratio + t1, 3) + end = round(char_seg.end * ratio + t1, 3) score = char_seg.score char_segments_arr["char"].append(char) diff --git a/whisperx/asr.py b/whisperx/asr.py index 1ca12ce9..ba6220bd 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -181,9 +181,6 @@ def _sanitize_parameters(self, **kwargs): def preprocess(self, audio): audio = audio['inputs'] - if isinstance(audio, np.ndarray): - audio = torch.from_numpy(audio) - features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0]) return {'inputs': features} @@ -256,7 +253,7 @@ def data(audio, segments): def detect_language(self, audio: np.ndarray): if audio.shape[0] < N_SAMPLES: print("Warning: audio is shorter than 30s, language detection may be inaccurate.") - segment = log_mel_spectrogram(torch.from_numpy(audio[:N_SAMPLES]), + segment = log_mel_spectrogram(audio[: N_SAMPLES], padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0]) encoder_output = self.model.encode(segment) results = self.model.model.detect_language(encoder_output) diff --git a/whisperx/audio.py b/whisperx/audio.py index 8ac06748..513ab7c9 100644 --- a/whisperx/audio.py +++ b/whisperx/audio.py @@ -22,12 +22,6 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token -with np.load( - os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") -) as f: - MEL_FILTERS = torch.from_numpy(f[f"mel_{80}"]) - - def load_audio(file: str, sr: int = SAMPLE_RATE): """ @@ -85,9 +79,27 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): return array -@torch.compile(fullgraph=True) +@lru_cache(maxsize=None) +def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: + """ + load the mel filterbank matrix for projecting STFT into a Mel spectrogram. + Allows decoupling librosa dependency; saved using: + + np.savez_compressed( + "mel_filters.npz", + mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + ) + """ + assert n_mels == 80, f"Unsupported n_mels: {n_mels}" + with np.load( + os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") + ) as f: + return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) + + def log_mel_spectrogram( - audio: torch.Tensor, + audio: Union[str, np.ndarray, torch.Tensor], + n_mels: int = N_MELS, padding: int = 0, device: Optional[Union[str, torch.device]] = None, ): @@ -96,7 +108,7 @@ def log_mel_spectrogram( Parameters ---------- - audio: torch.Tensor, shape = (*) + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz n_mels: int @@ -113,19 +125,21 @@ def log_mel_spectrogram( torch.Tensor, shape = (80, n_frames) A Tensor that contains the Mel spectrogram """ - global MEL_FILTERS + if not torch.is_tensor(audio): + if isinstance(audio, str): + audio = load_audio(audio) + audio = torch.from_numpy(audio) if device is not None: audio = audio.to(device) if padding > 0: audio = F.pad(audio, (0, padding)) window = torch.hann_window(N_FFT).to(audio.device) - stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=False) - # Square the real and imaginary components and sum them together, similar to torch.abs() on complex tensors - magnitudes = (stft[:, :-1, :] ** 2).sum(dim=-1) + stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 - MEL_FILTERS = MEL_FILTERS.to(audio.device) - mel_spec = MEL_FILTERS @ magnitudes + filters = mel_filters(audio.device, n_mels) + mel_spec = filters @ magnitudes log_spec = torch.clamp(mel_spec, min=1e-10).log10() log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index e284e83b..4b5a6641 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -72,7 +72,6 @@ def cli(): parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models") # parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.") - parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input if not .wav format (only for VAD).") # fmt: on args = parser.parse_args().__dict__ @@ -86,10 +85,6 @@ def cli(): # model_flush: bool = args.pop("model_flush") os.makedirs(output_dir, exist_ok=True) - tmp_dir: str = args.pop("tmp_dir") - if tmp_dir is not None: - os.makedirs(tmp_dir, exist_ok=True) - align_model: str = args.pop("align_model") interpolate_method: str = args.pop("interpolate_method") no_align: bool = args.pop("no_align") @@ -195,7 +190,7 @@ def cli(): tmp_results = results print(">>Performing diarization...") results = [] - diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device) + diarize_model = DiarizationPipeline(use_auth_token=hf_token) for result, input_audio_path in tmp_results: diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers) results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])