Skip to content

Commit

Permalink
Fixes #212 : workaround that disable SPD attention in latest version …
Browse files Browse the repository at this point in the history
…of openai-whisper (20240930) which prevents from accessing attention weights
  • Loading branch information
Jeronymous committed Oct 30, 2024
1 parent b6f035f commit ee35e7c
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions whisper_timestamped/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
__author__ = "Jérôme Louradour"
__credits__ = ["Jérôme Louradour"]
__license__ = "GPLv3"
__version__ = "1.15.4"
__version__ = "1.15.5"

# Set some environment variables
import os
Expand Down Expand Up @@ -46,6 +46,20 @@
AUDIO_TIME_PER_TOKEN = AUDIO_SAMPLES_PER_TOKEN / SAMPLE_RATE # 0.02 (sec)
SEGMENT_DURATION = N_FRAMES * HOP_LENGTH / SAMPLE_RATE # 30.0 (sec)

# Access attention in latest versions...
if whisper.__version__ >= "20240930":
from whisper.model import disable_sdpa
else:
from contextlib import contextmanager

# Dummy context manager that does nothing
@contextmanager
def disable_sdpa():
try:
yield
finally:
pass

# Logs
import logging
logger = logging.getLogger("whisper_timestamped")
Expand Down Expand Up @@ -885,7 +899,8 @@ def hook_output_logits(layer, ins, outs):
if compute_word_confidence or no_speech_threshold is not None:
all_hooks.append(model.decoder.ln.register_forward_hook(hook_output_logits))

transcription = model.transcribe(audio, **whisper_options)
with disable_sdpa():
transcription = model.transcribe(audio, **whisper_options)

finally:

Expand Down Expand Up @@ -1047,7 +1062,8 @@ def hook_output_logits(layer, ins, outs):

try:
model.alignment_heads = alignment_heads # Avoid exception "AttributeError: 'WhisperUntied' object has no attribute 'alignment_heads'. Did you mean: 'set_alignment_heads'?""
transcription = model.transcribe(audio, **whisper_options)
with disable_sdpa():
transcription = model.transcribe(audio, **whisper_options)
finally:
for hook in all_hooks:
hook.remove()
Expand Down

0 comments on commit ee35e7c

Please sign in to comment.