Skip to content

Commit

Permalink
fix: resolve pre-commit linters warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
andreabak committed Mar 16, 2024
1 parent 470d76c commit 431ca2f
Showing 1 changed file with 21 additions and 18 deletions.
39 changes: 21 additions & 18 deletions whispersubs/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from itertools import chain
from pathlib import Path
from time import monotonic_ns
from typing import TYPE_CHECKING, ContextManager, Iterable, Iterator, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, ContextManager, Iterable, Iterator, Sequence

import av
import blessed
Expand All @@ -29,10 +29,10 @@
import nvidia.cublas.lib
import nvidia.cudnn.lib
except ImportError:
warnings.warn("NVIDIA CUDA libraries not found, inference will run on CPU only.")
warnings.warn("NVIDIA CUDA libraries not found, inference will run on CPU only.", stacklevel=1)
else:
_cublas_libs = os.path.dirname(nvidia.cublas.lib.__file__)
_cudnn_libs = os.path.dirname(nvidia.cudnn.lib.__file__)
_cublas_libs = os.path.dirname(nvidia.cublas.lib.__file__) # noqa: PTH120
_cudnn_libs = os.path.dirname(nvidia.cudnn.lib.__file__) # noqa: PTH120
_ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
if _cudnn_libs not in _ld_library_path or _cublas_libs not in _ld_library_path:
os.environ["LD_LIBRARY_PATH"] = ":".join([
Expand All @@ -41,7 +41,7 @@
*_ld_library_path.split(":"),
])
# restart the script with the updated environment
os.execve(sys.executable, [sys.executable] + sys.argv, os.environ)
os.execve(sys.executable, [sys.executable, *sys.argv], os.environ) # noqa: S606


_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -82,7 +82,7 @@ def extract_audio(file: Path) -> Iterator[tuple[NDArray[np.float32], float]]:
yield frame_np[0], duration


def transcribe_segments(
def transcribe_segments( # noqa: PLR0914
audio_chunks: Iterable[NDArray[np.float32]],
*,
model_size: str = DEFAULT_WHISPER_MODEL_SIZE,
Expand All @@ -109,7 +109,7 @@ def transcribe_segments(
if progress:

@contextlib.contextmanager
def pbar_cm():
def pbar_cm() -> Iterator[enlighten.Counter]:
manager = enlighten.get_manager(set_scroll=False)
counter = manager.counter(
total=total_duration,
Expand All @@ -136,11 +136,12 @@ def pbar_cm():
_logger.log(logging.DEBUG if progress else logging.INFO, "Transcribing audio stream")
pbar: enlighten.Counter | None
with pbar_context as pbar:
audio_chunks_iterator: Iterator[NDArray[np.float32]] = iter(audio_chunks)
last_chunk: bool = False
start_t: int = monotonic_ns()
while True:
try:
chunk = next(audio_chunks)
chunk = next(audio_chunks_iterator)
except StopIteration:
last_chunk = True
else:
Expand Down Expand Up @@ -211,9 +212,9 @@ def create_subtitles(
"""
Create a subtitle file from the given segments.
"""
segments_count = 0
index = 0
for segments_count, segment in enumerate(segments):
segments_count: int = 0
index: int = 0
for segments_count, segment in enumerate(segments): # noqa: B007
segment_start = segment.start
segment_end = segment.end
segment_duration = segment_end - segment_start
Expand Down Expand Up @@ -321,26 +322,28 @@ class LogFormatter(logging.Formatter):
Custom formatter that adds colors using blessed and uses relative timestamps.
"""

COLOR_MAP = {
COLOR_MAP: ClassVar[dict[int, str]] = {
logging.DEBUG: "blue",
logging.INFO: "green",
logging.WARNING: "yellow",
logging.ERROR: "red",
logging.CRITICAL: "bold_red",
}
LEVEL_TAG_MAP = {
LEVEL_TAG_MAP: ClassVar[dict[int, str]] = {
logging.DEBUG: "[D]",
logging.INFO: "[I]",
logging.WARNING: "[W]",
logging.ERROR: "[E]",
logging.CRITICAL: "[C]",
}
FORMAT = "%(reltime)s %(color)s%(leveltag)s%(color_reset)s%(condname)s %(message)s"
FORMAT: ClassVar[str] = (
"%(reltime)s %(color)s%(leveltag)s%(color_reset)s%(condname)s %(message)s"
)

def __init__(self, *args, **kwargs):
super().__init__(fmt=self.FORMAT, *args, **kwargs)
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(self.FORMAT, *args, **kwargs)

def format(self, record):
def format(self, record: logging.LogRecord) -> str:
record.condname = f" ({record.name})" if record.name != "__main__" else ""
record.leveltag = self.LEVEL_TAG_MAP.get(record.levelno, "[?]")
record.color = getattr(term, self.COLOR_MAP.get(record.levelno, "white"))
Expand All @@ -349,7 +352,7 @@ def format(self, record):
return super().format(record)


def main():
def main() -> None:
logging.basicConfig(level=logging.INFO)

parser = argparse.ArgumentParser(description="Transcribe audio/video files into subtitles")
Expand Down

0 comments on commit 431ca2f

Please sign in to comment.