Skip to content

Commit

Permalink
replace NamedTuple with dataclass (#1105)
Browse files Browse the repository at this point in the history
* replace `NamedTuple` with `dataclass`

* add deprecation warnings
  • Loading branch information
MahmoudAshraf97 authored Nov 5, 2024
1 parent 814472f commit 203dddb
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 32 deletions.
75 changes: 45 additions & 30 deletions faster_whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import zlib

from collections import Counter, defaultdict
from dataclasses import asdict, dataclass
from inspect import signature
from math import ceil
from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union
from typing import BinaryIO, Iterable, List, Optional, Tuple, Union
from warnings import warn

import ctranslate2
import numpy as np
Expand All @@ -30,14 +32,24 @@
)


class Word(NamedTuple):
@dataclass
class Word:
    """A single word record with timing and probability fields.

    This is a mutable dataclass; it replaces an earlier ``NamedTuple``
    definition, so fields can be assigned in place.

    Attributes:
        start: Start timestamp of the word.
        end: End timestamp of the word.
        word: Text of the word.
        probability: Probability value associated with the word.
    """

    start: float
    end: float
    word: str
    probability: float

    def _asdict(self):
        """Return the fields as a ``dict`` (deprecated compatibility shim).

        Kept so that callers written against the ``NamedTuple`` API keep
        working. Emits a ``DeprecationWarning`` attributed to the caller.
        """
        # The replacement must be called on the instance: passing the class
        # itself to dataclasses.asdict() raises TypeError, so the message
        # must not suggest ``asdict(Word)``.
        warn(
            "Word._asdict() method is deprecated, "
            "use dataclasses.asdict() on the Word instance instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return asdict(self)


class Segment(NamedTuple):
@dataclass
class Segment:
id: int
seek: int
start: float
Expand All @@ -50,9 +62,18 @@ class Segment(NamedTuple):
words: Optional[List[Word]]
temperature: Optional[float] = 1.0

def _asdict(self):
warn(
"Segment._asdict() method is deprecated, use dataclasses.asdict(Segment) instead",
DeprecationWarning,
2,
)
return asdict(self)


# Added additional parameters for multilingual videos and fixes below
class TranscriptionOptions(NamedTuple):
@dataclass
class TranscriptionOptions:
beam_size: int
best_of: int
patience: float
Expand Down Expand Up @@ -83,7 +104,8 @@ class TranscriptionOptions(NamedTuple):
hotwords: Optional[str]


class TranscriptionInfo(NamedTuple):
@dataclass
class TranscriptionInfo:
language: str
language_probability: float
duration: float
Expand All @@ -108,7 +130,7 @@ class BatchedInferencePipeline:
def __init__(
self,
model,
options: Optional[NamedTuple] = None,
options: Optional[TranscriptionOptions] = None,
tokenizer=None,
language: Optional[str] = None,
):
Expand Down Expand Up @@ -473,7 +495,7 @@ def _batched_segments_generator(
results = self.forward(
features[i : i + batch_size],
chunks_metadata[i : i + batch_size],
**options._asdict(),
**asdict(options),
)

for result in results:
Expand Down Expand Up @@ -1043,16 +1065,15 @@ def generate_segments(
content_duration = float(content_frames * self.feature_extractor.time_per_frame)

if isinstance(options.clip_timestamps, str):
options = options._replace(
clip_timestamps=[
float(ts)
for ts in (
options.clip_timestamps.split(",")
if options.clip_timestamps
else []
)
]
)
options.clip_timestamps = [
float(ts)
for ts in (
options.clip_timestamps.split(",")
if options.clip_timestamps
else []
)
]

seek_points: List[int] = [
round(ts * self.frames_per_second) for ts in options.clip_timestamps
]
Expand Down Expand Up @@ -1999,23 +2020,17 @@ def restore_speech_timestamps(
# Ensure the word start and end times are resolved to the same chunk.
middle = (word.start + word.end) / 2
chunk_index = ts_map.get_chunk_index(middle)
word = word._replace(
start=ts_map.get_original_time(word.start, chunk_index),
end=ts_map.get_original_time(word.end, chunk_index),
)
word.start = ts_map.get_original_time(word.start, chunk_index)
word.end = ts_map.get_original_time(word.end, chunk_index)
words.append(word)

segment = segment._replace(
start=words[0].start,
end=words[-1].end,
words=words,
)
segment.start = words[0].start
segment.end = words[-1].end
segment.words = words

else:
segment = segment._replace(
start=ts_map.get_original_time(segment.start),
end=ts_map.get_original_time(segment.end),
)
segment.start = ts_map.get_original_time(segment.start)
segment.end = ts_map.get_original_time(segment.end)

yield segment

Expand Down
6 changes: 4 additions & 2 deletions faster_whisper/vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import functools
import os

from typing import Dict, List, NamedTuple, Optional, Tuple
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
Expand All @@ -11,7 +12,8 @@


# The code below is adapted from https://github.com/snakers4/silero-vad.
class VadOptions(NamedTuple):
@dataclass
class VadOptions:
"""VAD options.
Attributes:
Expand Down

0 comments on commit 203dddb

Please sign in to comment.