torch2.0, remove compile for now, round to times to 3 decimal
m-bain committed May 4, 2023
1 parent d2116b9 commit 4e2ac4e
Showing 6 changed files with 40 additions and 34 deletions.
README.md (12 changes: 6 additions & 6 deletions)

@@ -61,23 +61,23 @@ This repository refines the timestamps of openAI's Whisper model via forced alignment


<h2 align="left" id="setup">Setup ⚙️</h2>
-Tested for PyTorch 0.11, Python 3.8 (use other versions at your own risk!)
+Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)

GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).


-### 1. Create Python3.8 environment
+### 1. Create Python3.10 environment

-`conda create --name whisperx python=3.8`
+`conda create --name whisperx python=3.10`

`conda activate whisperx`


-### 2. Install PyTorch 0.11.0, e.g. for Linux and Windows:
+### 2. Install PyTorch2.0, e.g. for Linux and Windows CUDA11.7:

-`pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113`
+`pip3 install torch torchvision torchaudio`

-See other methods [here.](https://pytorch.org/get-started/previous-versions/#wheel-4)
+See other methods [here.](https://pytorch.org/get-started/locally/)
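
As a quick sanity check after installing (a sketch, not part of this commit), you can confirm that a CUDA-enabled PyTorch 2.x build is active:

`python -c "import torch; print(torch.__version__, torch.cuda.is_available())"`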

### 3. Install this repo

setup.py (2 changes: 1 addition & 1 deletion)

@@ -6,7 +6,7 @@
setup(
    name="whisperx",
    py_modules=["whisperx"],
-    version="3.0.0",
+    version="3.0.2",
    description="Time-Accurate Automatic Speech Recognition using Whisper.",
    readme="README.md",
    python_requires=">=3.8",
whisperx/alignment.py (4 changes: 2 additions & 2 deletions)

@@ -268,8 +268,8 @@ def align(
start, end, score = None, None, None
if cdx in clean_cdx:
    char_seg = char_segments[clean_cdx.index(cdx)]
-    start = char_seg.start * ratio + t1
-    end = char_seg.end * ratio + t1
+    start = round(char_seg.start * ratio + t1, 3)
+    end = round(char_seg.end * ratio + t1, 3)
    score = char_seg.score

char_segments_arr["char"].append(char)
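For context, the only behavioural change here is that aligned character times are now rounded to millisecond precision before being stored. A minimal sketch of the effect, with hypothetical `ratio` and `t1` values standing in for the frame-to-seconds factor and segment offset used in `align`:

```python
# Hypothetical values: ratio converts alignment frames to seconds,
# t1 is the segment's start time within the full audio.
ratio, t1 = 0.02004, 12.5
char_start = 103.7  # frame position of a character from forced alignment

old_start = char_start * ratio + t1            # 14.578148 (noisy float precision)
new_start = round(char_start * ratio + t1, 3)  # 14.578 (3 decimals, i.e. milliseconds)
print(old_start, new_start)
```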
whisperx/asr.py (5 changes: 1 addition & 4 deletions)

@@ -181,9 +181,6 @@ def _sanitize_parameters(self, **kwargs):

def preprocess(self, audio):
    audio = audio['inputs']
-    if isinstance(audio, np.ndarray):
-        audio = torch.from_numpy(audio)
-
    features = log_mel_spectrogram(audio, padding=N_SAMPLES - audio.shape[0])
    return {'inputs': features}

@@ -256,7 +253,7 @@ def data(audio, segments):
def detect_language(self, audio: np.ndarray):
    if audio.shape[0] < N_SAMPLES:
        print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
-    segment = log_mel_spectrogram(torch.from_numpy(audio[:N_SAMPLES]),
+    segment = log_mel_spectrogram(audio[: N_SAMPLES],
                                  padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
    encoder_output = self.model.encode(segment)
    results = self.model.model.detect_language(encoder_output)
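Both hunks lean on `log_mel_spectrogram` now accepting a path, NumPy array, or tensor directly (see the `whisperx/audio.py` changes below), so callers no longer convert with `torch.from_numpy` themselves. A small sketch of the accepted inputs, assuming the updated `whisperx.audio` module; the file path is a placeholder:

```python
import numpy as np
from whisperx.audio import log_mel_spectrogram, SAMPLE_RATE

wav = np.zeros(SAMPLE_RATE, dtype=np.float32)  # one second of silence

mel_a = log_mel_spectrogram(wav)           # np.ndarray is converted to a tensor internally
mel_b = log_mel_spectrogram("audio.wav")   # str path is loaded via load_audio (placeholder file)
```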
whisperx/audio.py (44 changes: 29 additions & 15 deletions)

@@ -22,12 +22,6 @@
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token

-with np.load(
-    os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
-) as f:
-    MEL_FILTERS = torch.from_numpy(f[f"mel_{80}"])
-
-

def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
@@ -85,9 +79,27 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    return array


-@torch.compile(fullgraph=True)
+@lru_cache(maxsize=None)
+def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
+    """
+    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
+    Allows decoupling librosa dependency; saved using:
+        np.savez_compressed(
+            "mel_filters.npz",
+            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
+        )
+    """
+    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
+    with np.load(
+        os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
+    ) as f:
+        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(
-    audio: torch.Tensor,
+    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = N_MELS,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
@@ -96,7 +108,7 @@ def log_mel_spectrogram(
    Parameters
    ----------
-    audio: torch.Tensor, shape = (*)
+    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
    n_mels: int
@@ -113,19 +125,21 @@
    torch.Tensor, shape = (80, n_frames)
        A Tensor that contains the Mel spectrogram
    """
-    global MEL_FILTERS
    if not torch.is_tensor(audio):
+        if isinstance(audio, str):
+            audio = load_audio(audio)
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
-    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=False)
-    # Square the real and imaginary components and sum them together, similar to torch.abs() on complex tensors
-    magnitudes = (stft[:, :-1, :] ** 2).sum(dim=-1)
+    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
+    magnitudes = stft[..., :-1].abs() ** 2

-    MEL_FILTERS = MEL_FILTERS.to(audio.device)
-    mel_spec = MEL_FILTERS @ magnitudes
+    filters = mel_filters(audio.device, n_mels)
+    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
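Two changes worth unpacking: the module-level `MEL_FILTERS` global (and the `@torch.compile` decorator, dropped "for now" per the commit title) is replaced by an `lru_cache`d `mel_filters(device, n_mels)` helper that caches one filterbank tensor per device, and the STFT magnitude is now computed from a complex tensor via `.abs() ** 2`, which matches the old sum of squared real and imaginary parts. A standalone sketch checking that equivalence, with made-up audio and Whisper's N_FFT=400, HOP_LENGTH=160:

```python
import torch

x = torch.randn(16000)          # one second of fake audio at 16 kHz
window = torch.hann_window(400)

# Old path: real-valued output with a trailing (real, imag) dimension.
# return_complex=False is deprecated in recent PyTorch but still available.
stft_real = torch.stft(x, 400, 160, window=window, return_complex=False)
mag_old = (stft_real[:, :-1, :] ** 2).sum(dim=-1)

# New path: complex output; |.|**2 yields the same power values.
stft_cplx = torch.stft(x, 400, 160, window=window, return_complex=True)
mag_new = stft_cplx[..., :-1].abs() ** 2

print(torch.allclose(mag_old, mag_new, atol=1e-6))  # expected: True
```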
whisperx/transcribe.py (7 changes: 1 addition & 6 deletions)

@@ -72,7 +72,6 @@ def cli():

parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
# parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input if not .wav format (only for VAD).")
# fmt: on

args = parser.parse_args().__dict__
@@ -86,10 +85,6 @@ def cli():
    # model_flush: bool = args.pop("model_flush")
    os.makedirs(output_dir, exist_ok=True)

-    tmp_dir: str = args.pop("tmp_dir")
-    if tmp_dir is not None:
-        os.makedirs(tmp_dir, exist_ok=True)
-
    align_model: str = args.pop("align_model")
    interpolate_method: str = args.pop("interpolate_method")
    no_align: bool = args.pop("no_align")
@@ -195,7 +190,7 @@ def cli():
        tmp_results = results
        print(">>Performing diarization...")
        results = []
-        diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
+        diarize_model = DiarizationPipeline(use_auth_token=hf_token)
        for result, input_audio_path in tmp_results:
            diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
            results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
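With the `--tmp_dir` option removed (its help text says it existed only for the VAD path) and the `device` argument dropped from `DiarizationPipeline`, the diarization call reduces to what the loop above shows. A hedged sketch of that usage; the module path, token, audio path, and speaker counts are all assumed inputs:

```python
from whisperx.diarize import DiarizationPipeline  # module path assumed

hf_token = "hf_..."  # placeholder Hugging Face token for the gated pyannote models
diarize_model = DiarizationPipeline(use_auth_token=hf_token)  # no explicit device anymore

# Speaker turns for one file; the speaker bounds are illustrative.
diarize_segments = diarize_model("audio.wav", min_speakers=1, max_speakers=2)
```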
