Skip to content

Commit

Permalink
fix model download and vad fail
Browse files (browse the repository at this point in the history)
  • Loading branch information
arda-argmax committed Dec 5, 2024
1 parent e56cd0a commit 4e3610e
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions whisperkit/pipelines.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Optional

from argmaxtools.utils import _maybe_git_clone, get_logger
from huggingface_hub import snapshot_download

from whisperkit import _constants

Expand Down Expand Up @@ -132,8 +133,29 @@ def clone_models(self):
""" Download WhisperKit model files from Hugging Face Hub
(only the files needed for `self.whisper_version`)
"""
self.models_dir = os.path.join(self.repo_dir, "models") # dummy
self.results_dir = os.path.join(self.repo_dir, "results")
self.models_dir = os.path.join(
self.repo_dir, "Models", self.whisper_version.replace("/", "_"))

os.makedirs(self.models_dir, exist_ok=True)

snapshot_download(
repo_id=_constants.MODEL_REPO_ID,
allow_patterns=f"{self.whisper_version.replace('/', '_')}/*",
revision=self.model_commit_hash,
local_dir=os.path.dirname(self.models_dir),
local_dir_use_symlinks=True
)

if self.model_commit_hash is None:
self.model_commit_hash = subprocess.run(
f"git ls-remote git@hf.co:{_constants.MODEL_REPO_ID}",
shell=True, stdout=subprocess.PIPE
).stdout.decode("utf-8").rsplit("\n")[0].rsplit("\t")[0]
logger.info(
"--model-commit-hash not specified, "
f"imputing with HEAD={self.model_commit_hash}")

self.results_dir = os.path.join(self.models_dir, "results")
os.makedirs(self.results_dir, exist_ok=True)

def transcribe(self, audio_file_path: str, forced_language: Optional[str] = None) -> str:
Expand All @@ -143,11 +165,10 @@ def transcribe(self, audio_file_path: str, forced_language: Optional[str] = None
self.cli_path,
"transcribe",
"--audio-path", audio_file_path,
"--model-prefix", self.whisper_version.rsplit("/")[0],
"--model", self.whisper_version.rsplit("/")[1],
"--model-path", self.models_dir,
"--text-decoder-compute-units", self._text_decoder_compute_units,
"--audio-encoder-compute-units", self._audio_encoder_compute_units,
"--chunking-strategy", "vad",
# "--chunking-strategy", "vad",
"--report-path", self.results_dir, "--report",
"--word-timestamps" if self._word_timestamps else "",
"" if forced_language is None else f"--use-prefill-prompt --language {forced_language}",
Expand Down Expand Up @@ -182,8 +203,7 @@ def transcribe_folder(self, audio_folder_path: str, forced_language: Optional[st
self.cli_path,
"transcribe",
"--audio-folder", audio_folder_path,
"--model-prefix", self.whisper_version.rsplit("/")[0],
"--model", self.whisper_version.rsplit("/")[1],
"--model-path", self.models_dir,
"--text-decoder-compute-units", self._text_decoder_compute_units,
"--audio-encoder-compute-units", self._audio_encoder_compute_units,
"--report-path", self.results_dir, "--report",
Expand Down

0 comments on commit 4e3610e

Please sign in to comment.