Skip to content

Commit

Permalink
Merge pull request #208 from jhj0517/fix/diarization-type
Browse files Browse the repository at this point in the history
Fix type hint in diarization function
  • Loading branch information
jhj0517 authored Jul 12, 2024
2 parents 79933ea + 072ec01 commit e4c9d55
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
32 changes: 24 additions & 8 deletions modules/diarize/audio_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import subprocess
from functools import lru_cache
from typing import Optional, Union
from scipy.io.wavfile import write
import tempfile

import numpy as np
import torch
Expand All @@ -24,32 +26,43 @@ def exact_div(x, y):
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token


def load_audio(file: str, sr: int = SAMPLE_RATE):
def load_audio(file: Union[str, np.ndarray], sr: int = SAMPLE_RATE) -> np.ndarray:
"""
Open an audio file and read as mono waveform, resampling as necessary
Open an audio file or process a numpy array containing audio data as mono waveform, resampling as necessary.
Parameters
----------
file: str
The audio file to open
file: Union[str, np.ndarray]
The audio file to open or a numpy array containing the audio data.
sr: int
The sample rate to resample the audio if necessary
The sample rate to resample the audio if necessary.
Returns
-------
A NumPy array containing the audio waveform, in float32 dtype.
"""
if isinstance(file, np.ndarray):
if file.dtype != np.float32:
file = file.astype(np.float32)
if file.ndim > 1:
file = np.mean(file, axis=1)

temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
write(temp_file.name, SAMPLE_RATE, (file * 32768).astype(np.int16))
temp_file_path = temp_file.name
temp_file.close()
else:
temp_file_path = file

try:
# Launches a subprocess to decode audio while down-mixing and resampling as necessary.
# Requires the ffmpeg CLI to be installed.
cmd = [
"ffmpeg",
"-nostdin",
"-threads",
"0",
"-i",
file,
temp_file_path,
"-f",
"s16le",
"-ac",
Expand All @@ -63,6 +76,9 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
out = subprocess.run(cmd, capture_output=True, check=True).stdout
except subprocess.CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
finally:
if isinstance(file, np.ndarray):
os.remove(temp_file_path)

return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0

Expand Down
5 changes: 3 additions & 2 deletions modules/diarize/diarizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import torch
from typing import List
from typing import List, Union, BinaryIO
import numpy as np
import time
import logging

Expand All @@ -20,7 +21,7 @@ def __init__(self,
self.pipe = None

def run(self,
audio: str,
audio: Union[str, BinaryIO, np.ndarray],
transcribed_result: List[dict],
use_auth_token: str,
device: str
Expand Down

0 comments on commit e4c9d55

Please sign in to comment.