Commit 705d2b3: initial cleanup
MahmoudAshraf97 committed Nov 13, 2024
1 parent: c2bf036
Showing 1 changed file with 63 additions and 76 deletions.

faster_whisper/transcribe.py: 139 changes (63 additions, 76 deletions)
@@ -71,7 +71,6 @@ def _asdict(self):
return asdict(self)


-# Added additional parameters for multilingual videos and fixes below
@dataclass
class TranscriptionOptions:
beam_size: int
@@ -115,18 +114,7 @@ class TranscriptionInfo:
vad_options: VadOptions


-# The code below is originally from HF pipeline and is used in whisper-x
-# (https://github.com/m-bain/whisperX) and adapted for faster_whisper


class BatchedInferencePipeline:
"""
Huggingface Pipeline wrapper for WhisperModel.
Copyright (c) 2022, Max Bain
All rights reserved.
Modified by Mobius Labs GmbH
"""

def __init__(
self,
model,
@@ -140,9 +128,9 @@ def __init__(
self.preset_language = language
self.last_speech_timestamp = 0.0

-def forward(self, features, chunks_metadata, **forward_params):
+def forward(self, features, chunks_metadata, options):
encoder_output, outputs = self.model.generate_segment_batched(
-features, self.tokenizer, forward_params
+features, self.tokenizer, options
)

segmented_outputs = []
@@ -179,14 +167,14 @@ def forward(self, features, chunks_metadata, **forward_params):
for subsegment in subsegments
]
)
-if forward_params["word_timestamps"]:
+if options.word_timestamps:
self.last_speech_timestamp = self.model.add_word_timestamps(
segmented_outputs,
self.tokenizer,
encoder_output,
segment_sizes,
-forward_params["prepend_punctuations"],
-forward_params["append_punctuations"],
+options.prepend_punctuations,
+options.append_punctuations,
self.last_speech_timestamp,
)

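The hunk above is one piece of a broader change in this commit: the loose `**forward_params` dict gives way to the typed `TranscriptionOptions` object. A minimal sketch of the pattern, with a hypothetical trimmed-down `Options` class standing in for `TranscriptionOptions`:

from dataclasses import dataclass

# Hypothetical stand-in for faster_whisper's TranscriptionOptions.
@dataclass
class Options:
    word_timestamps: bool
    prepend_punctuations: str
    append_punctuations: str

def forward_old(**forward_params):
    # dict-style access: a mistyped key fails only at runtime
    return forward_params["word_timestamps"]

def forward_new(options: Options):
    # attribute access: checkable by linters and easier to refactor
    return options.word_timestamps

opts = Options(True, "\"'([{-", "\"'.,!?)]}")
assert forward_old(**vars(opts)) == forward_new(opts)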
@@ -263,7 +251,7 @@ def transcribe(
max_new_tokens: Optional[int] = None,
chunk_length: Optional[int] = None,
clip_timestamps: Optional[List[dict]] = None,
-batch_size: int = 16,
+batch_size: int = 8,
hotwords: Optional[str] = None,
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
"""transcribe audio in chunks in batched fashion and return with language info.
@@ -282,22 +270,11 @@ def transcribe(
repetition_penalty: Penalty applied to the score of previously generated tokens
(set > 1 to penalize).
no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
-temperature: Temperature for sampling. It can be a tuple of temperatures,
-which will be successively used upon failures according to either
-`compression_ratio_threshold` or `log_prob_threshold`.
-compression_ratio_threshold: If the gzip compression ratio is above this value,
-treat as failed.
-log_prob_threshold: If the average log probability over sampled tokens is
-below this value, treat as failed.
-log_prob_low_threshold: This parameter alone is sufficient to skip an output text,
-whereas log_prob_threshold also looks for appropriate no_speech_threshold value.
-This value should be less than log_prob_threshold.
-no_speech_threshold: If the no_speech probability is higher than this value AND
-the average log probability over sampled tokens is below `log_prob_threshold`,
-consider the segment as silent.
+temperature: Temperature for sampling. If a list or tuple is passed,
+only the first value is used.
initial_prompt: Optional text string or iterable of token ids to provide as a
-prompt for the first window.
-prefix: Optional text to provide as a prefix for the first window.
+prompt for each window.
+prefix: Optional text to provide as a prefix at the beginning of each window.
suppress_blank: Suppress blank outputs at the beginning of the sampling.
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
of symbols as defined in `tokenizer.non_speech_tokens()`.
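The rewritten docstring above says a temperature schedule collapses to its first value in batched mode. A quick illustration of the clamping expression that appears later in this diff:

# Only the first temperature survives; fallback retries are not used here.
temperature = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
temperatures = (
    temperature[:1] if isinstance(temperature, (list, tuple)) else [temperature]
)
print(temperatures)  # (0.0,)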
@@ -325,28 +302,34 @@ def transcribe(
hotwords:
Hotwords/hint phrases to the model. Has no effect if prefix is not None.
-Static params: (Fixed for batched version)
+Unused Arguments
max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
multilingual: If True, perform transcription on multilingual videos. Set as False.
output_language: Valid only if multilingual is set to True.
Specifies the string representing the output language. One of
'en' (English) or 'hybrid' (code-switched transcription). set as None.
-language_detection_threshold: If the maximum probability of the language tokens is
-higher than this value, the language is detected.
-language_detection_segments: Number of segments to consider for the language detection.
+compression_ratio_threshold: If the gzip compression ratio is above this value,
+treat as failed.
+log_prob_threshold: If the average log probability over sampled tokens is
+below this value, treat as failed.
+log_prob_low_threshold: This parameter alone is sufficient to skip an output text,
+whereas log_prob_threshold also looks for appropriate no_speech_threshold value.
+This value should be less than log_prob_threshold.
+no_speech_threshold: If the no_speech probability is higher than this value AND
+the average log probability over sampled tokens is below `log_prob_threshold`,
+consider the segment as silent.
+hallucination_silence_threshold: Optional[float]
+When word_timestamps is True, skip silent periods longer than this threshold
+(in seconds) when a possible hallucination is detected. set as None.
condition_on_previous_text: If True, the previous output of the model is provided
as a prompt for the next window; disabling may make the text inconsistent across
windows, but the model becomes less prone to getting stuck in a failure loop,
such as repetition looping or timestamps going out of sync. Set as False
prompt_reset_on_temperature: Resets prompt if temperature is above this value.
Arg has effect only if condition_on_previous_text is True. Set at 0.5
-#TODO: support "hallucination_silence_threshold" when "word_timestamps=True"
-hallucination_silence_threshold: Optional[float]
-When word_timestamps is True, skip silent periods longer than this threshold
-(in seconds) when a possible hallucination is detected. set as None.
+unused:
+language_detection_threshold: If the maximum probability of the language tokens is
+higher than this value, the language is detected.
+language_detection_segments: Number of segments to consider for the language detection.
Returns:
A tuple with:
Expand Down Expand Up @@ -412,8 +395,7 @@ def transcribe(
/ sampling_rate
)

-# batched options: see the difference with default options in WhisperModel
-batched_options = TranscriptionOptions(
+options = TranscriptionOptions(
beam_size=beam_size,
best_of=best_of,
patience=patience,
@@ -425,7 +407,9 @@ def transcribe(
no_speech_threshold=no_speech_threshold,
compression_ratio_threshold=compression_ratio_threshold,
temperatures=(
-temperature if isinstance(temperature, (list, tuple)) else [temperature]
+temperature[:1]
+if isinstance(temperature, (list, tuple))
+else [temperature]
),
initial_prompt=initial_prompt,
prefix=prefix,
@@ -438,7 +422,7 @@ def transcribe(
word_timestamps=word_timestamps,
hallucination_silence_threshold=None,
condition_on_previous_text=False,
-clip_timestamps="0",
+clip_timestamps=clip_timestamps,
prompt_reset_on_temperature=0.5,
multilingual=False,
output_language=None,
@@ -451,7 +435,7 @@ def transcribe(
language_probability=language_probability,
duration=duration,
duration_after_vad=duration_after_vad,
-transcription_options=batched_options,
+transcription_options=options,
vad_options=None,
all_language_probs=all_language_probs,
)
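For orientation, a hedged usage sketch of the batched API this commit is cleaning up, assuming `BatchedInferencePipeline` is exported at package level and with illustrative file and model names:

from faster_whisper import WhisperModel, BatchedInferencePipeline

model = WhisperModel("large-v3")
batched_model = BatchedInferencePipeline(model)

# transcribe() yields segments lazily and returns a TranscriptionInfo
segments, info = batched_model.transcribe("audio.wav", batch_size=8)
for segment in segments:
    print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")
print(info.language, info.language_probability)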
Expand Down Expand Up @@ -480,7 +464,7 @@ def transcribe(
features,
chunks_metadata,
batch_size,
-batched_options,
+options,
log_progress,
)

@@ -495,7 +479,7 @@ def _batched_segments_generator(
results = self.forward(
features[i : i + batch_size],
chunks_metadata[i : i + batch_size],
-**asdict(options),
+options,
)

for result in results:
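`_batched_segments_generator` above walks the padded features in strides of `batch_size` and forwards one slice at a time. A self-contained sketch of that slicing pattern, with illustrative names:

def batched(items, batch_size):
    # Mirrors the features[i : i + batch_size] stride in the generator above.
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]

features = list(range(10))
print([len(chunk) for chunk in batched(features, 4)])  # [4, 4, 2]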
@@ -1735,50 +1719,53 @@ def generate_segment_batched(
self,
features: torch.Tensor,
tokenizer: Tokenizer,
-options: dict,
+options: TranscriptionOptions,
):
batch_size = features.shape[0]
-all_tokens = []
-prompt_reset_since = 0

-if options["initial_prompt"] is not None:
-initial_prompt = " " + options["initial_prompt"].strip()
-initial_prompt_tokens = tokenizer.encode(initial_prompt)
-all_tokens.extend(initial_prompt_tokens)
-previous_tokens = all_tokens[prompt_reset_since:]
prompt = self.get_prompt(
tokenizer,
-previous_tokens,
-without_timestamps=options["without_timestamps"],
-prefix=options["prefix"],
+previous_tokens=(
+tokenizer.encode(options.initial_prompt)
+if options.initial_prompt is not None
+else []
+),
+without_timestamps=options.without_timestamps,
+prefix=options.prefix,
+hotwords=options.hotwords,
)

encoder_output = self.encode(features)

-result = self.model.generate(
+results = self.model.generate(
encoder_output,
[prompt] * batch_size,
-beam_size=options["beam_size"],
-patience=options["patience"],
-length_penalty=options["length_penalty"],
+beam_size=options.beam_size,
+patience=options.patience,
+length_penalty=options.length_penalty,
max_length=self.max_length,
-suppress_blank=options["suppress_blank"],
-suppress_tokens=options["suppress_tokens"],
+suppress_blank=options.suppress_blank,
+suppress_tokens=options.suppress_tokens,
return_scores=True,
return_no_speech_prob=True,
+sampling_temperature=options.temperatures[0],
+repetition_penalty=options.repetition_penalty,
+no_repeat_ngram_size=options.no_repeat_ngram_size,
)

output = []
-for res in result:
-output.append({})
+for result in results:
# return scores
-seq_len = len(res.sequences_ids[0])
-cum_logprob = res.scores[0] * (seq_len ** options["length_penalty"])
-output[-1]["avg_logprob"] = cum_logprob / (seq_len + 1)
+seq_len = len(result.sequences_ids[0])
+cum_logprob = result.scores[0] * (seq_len**options.length_penalty)

-# return no speech prob
-output[-1]["no_speech_prob"] = res.no_speech_prob
-output[-1]["tokens"] = res.sequences_ids[0]
+output.append(
+dict(
+avg_logprob=cum_logprob / (seq_len + 1),
+no_speech_prob=result.no_speech_prob,
+tokens=result.sequences_ids[0],
+)
+)

return encoder_output, output

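The score arithmetic in this final hunk undoes CTranslate2's length normalization and then averages over the sequence plus its end token. A worked example, assuming `result.scores[0]` is the cumulative log-probability divided by `seq_len ** length_penalty`:

length_penalty = 1.0
seq_len = 9
normalized_score = -0.25  # assumed value of result.scores[0]
cum_logprob = normalized_score * (seq_len**length_penalty)  # -2.25
avg_logprob = cum_logprob / (seq_len + 1)  # -0.225, averaged over 10 tokens
print(avg_logprob)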