From 6cee2a2a9e0d7ff2c36ebf2e1f7842a55f36d111 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Fri, 12 Jul 2024 16:42:59 +0900 Subject: [PATCH 1/2] fix type hint --- modules/diarize/audio_loader.py | 30 ++++++++++++++++++++++-------- modules/diarize/diarizer.py | 5 +++-- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/modules/diarize/audio_loader.py b/modules/diarize/audio_loader.py index 6db3cc6..9575196 100644 --- a/modules/diarize/audio_loader.py +++ b/modules/diarize/audio_loader.py @@ -24,32 +24,43 @@ def exact_div(x, y): TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token -def load_audio(file: str, sr: int = SAMPLE_RATE): +def load_audio(file: Union[str, np.ndarray], sr: int = SAMPLE_RATE) -> np.ndarray: """ - Open an audio file and read as mono waveform, resampling as necessary + Open an audio file or process a numpy array containing audio data as mono waveform, resampling as necessary. Parameters ---------- - file: str - The audio file to open + file: Union[str, np.ndarray] + The audio file to open or a numpy array containing the audio data. sr: int - The sample rate to resample the audio if necessary + The sample rate to resample the audio if necessary. Returns ------- A NumPy array containing the audio waveform, in float32 dtype. """ + if isinstance(file, np.ndarray): + if file.dtype != np.float32: + file = file.astype(np.float32) + if file.ndim > 1: + file = np.mean(file, axis=1) + + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") + write(temp_file.name, SAMPLE_RATE, (file * 32768).astype(np.int16)) + temp_file_path = temp_file.name + temp_file.close() + else: + temp_file_path = file + try: - # Launches a subprocess to decode audio while down-mixing and resampling as necessary. - # Requires the ffmpeg CLI to be installed. cmd = [ "ffmpeg", "-nostdin", "-threads", "0", "-i", - file, + temp_file_path, "-f", "s16le", "-ac", @@ -63,6 +74,9 @@ def load_audio(file: str, sr: int = SAMPLE_RATE): out = subprocess.run(cmd, capture_output=True, check=True).stdout except subprocess.CalledProcessError as e: raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + finally: + if isinstance(file, np.ndarray): + os.remove(temp_file_path) return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 diff --git a/modules/diarize/diarizer.py b/modules/diarize/diarizer.py index c12c235..592e1bf 100644 --- a/modules/diarize/diarizer.py +++ b/modules/diarize/diarizer.py @@ -1,6 +1,7 @@ import os import torch -from typing import List +from typing import List, Union, BinaryIO +import numpy as np import time import logging @@ -20,7 +21,7 @@ def __init__(self, self.pipe = None def run(self, - audio: str, + audio: Union[str, BinaryIO, np.ndarray], transcribed_result: List[dict], use_auth_token: str, device: str From 072ec01f7c73b0affcc75c0be937da19304d97a1 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Fri, 12 Jul 2024 16:43:30 +0900 Subject: [PATCH 2/2] add imports --- modules/diarize/audio_loader.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/diarize/audio_loader.py b/modules/diarize/audio_loader.py index 9575196..2efd421 100644 --- a/modules/diarize/audio_loader.py +++ b/modules/diarize/audio_loader.py @@ -2,6 +2,8 @@ import subprocess from functools import lru_cache from typing import Optional, Union +from scipy.io.wavfile import write +import tempfile import numpy as np import torch