Skip to content

Commit

Permalink
Merge pull request #217 from jhj0517/feature/use-factory-pattern
Browse files Browse the repository at this point in the history
Refactor to factory pattern
  • Loading branch information
jhj0517 authored Jul 16, 2024
2 parents 48382b6 + 2db409c commit 825f362
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 35 deletions.
42 changes: 7 additions & 35 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import argparse
import gradio as gr

from modules.whisper.whisper_Inference import WhisperInference
from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.faster_whisper_inference import FasterWhisperInference
from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
from modules.translation.nllb_inference import NLLBInference
Expand All @@ -16,7 +16,12 @@ class App:
def __init__(self, args):
self.args = args
self.app = gr.Blocks(css=CSS, theme=self.args.theme)
self.whisper_inf = self.init_whisper()
self.whisper_inf = WhisperFactory.create_whisper_inference(
whisper_type=self.args.whisper_type,
model_dir=self.args.faster_whisper_model_dir,
output_dir=self.args.output_dir,
args=self.args
)
print(f"Use \"{self.args.whisper_type}\" implementation")
print(f"Device \"{self.whisper_inf.device}\" is detected")
self.nllb_inf = NLLBInference(
Expand All @@ -27,39 +32,6 @@ def __init__(self, args):
output_dir=os.path.join(self.args.output_dir, "translations")
)

def init_whisper(self):
# Temporal fix of the issue : https://github.com/jhj0517/Whisper-WebUI/issues/144
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

whisper_type = self.args.whisper_type.lower().strip()

if whisper_type in ["faster_whisper", "faster-whisper", "fasterwhisper"]:
whisper_inf = FasterWhisperInference(
model_dir=self.args.faster_whisper_model_dir,
output_dir=self.args.output_dir,
args=self.args
)
elif whisper_type in ["whisper"]:
whisper_inf = WhisperInference(
model_dir=self.args.whisper_model_dir,
output_dir=self.args.output_dir,
args=self.args
)
elif whisper_type in ["insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
"insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"]:
whisper_inf = InsanelyFastWhisperInference(
model_dir=self.args.insanely_fast_whisper_model_dir,
output_dir=self.args.output_dir,
args=self.args
)
else:
whisper_inf = FasterWhisperInference(
model_dir=self.args.faster_whisper_model_dir,
output_dir=self.args.output_dir,
args=self.args
)
return whisper_inf

def create_whisper_parameters(self):
with gr.Row():
dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value="large-v2",
Expand Down
60 changes: 60 additions & 0 deletions modules/whisper/whisper_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from argparse import Namespace
import os

from modules.whisper.faster_whisper_inference import FasterWhisperInference
from modules.whisper.whisper_Inference import WhisperInference
from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
from modules.whisper.whisper_base import WhisperBase


class WhisperFactory:
@staticmethod
def create_whisper_inference(
whisper_type: str,
model_dir: str,
output_dir: str,
args: Namespace
) -> "WhisperBase":
"""
Create a whisper inference class based on the provided whisper_type.
Parameters
----------
whisper_type: str
The repository name of whisper inference to use. Supported values are:
- "faster-whisper" from
- "whisper"
- insanely-fast-whisper", "insanely_fast_whisper", "insanelyfastwhisper",
"insanely-faster-whisper", "insanely_faster_whisper", "insanelyfasterwhisper"
model_dir: str
The directory path where the whisper model is located.
output_dir: str
The directory path where the output files will be saved.
args: Any
Additional arguments to be passed to the whisper inference object.
Returns
-------
WhisperBase
An instance of the appropriate whisper inference class based on the whisper_type.
"""
# Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

whisper_type = whisper_type.lower().strip()

faster_whisper_typos = ["faster_whisper", "faster-whisper", "fasterwhisper"]
whisper_typos = ["whisper"]
insanely_fast_whisper_typos = [
"insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
"insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"
]

if whisper_type in faster_whisper_typos:
return FasterWhisperInference(model_dir, output_dir, args)
elif whisper_type in whisper_typos:
return WhisperInference(model_dir, output_dir, args)
elif whisper_type in insanely_fast_whisper_typos:
return InsanelyFastWhisperInference(model_dir, output_dir, args)
else:
return FasterWhisperInference(model_dir, output_dir, args)

0 comments on commit 825f362

Please sign in to comment.