
Commit

. [skip ci]
MahmoudAshraf97 committed Nov 16, 2024
1 parent cce3c67 commit aededdf
Showing 1 changed file with 53 additions and 4 deletions.
57 changes: 53 additions & 4 deletions faster_whisper/transcribe.py
@@ -119,6 +119,58 @@ def __init__(
        self.model: WhisperModel = model
        self.last_speech_timestamp = 0.0

    def forward(self, features, tokenizer, chunks_metadata, options):
        encoder_output, outputs = self.generate_segment_batched(
            features, tokenizer, options
        )

        segmented_outputs = []
        segment_sizes = []
        for chunk_metadata, output in zip(chunks_metadata, outputs):
            duration = chunk_metadata["end_time"] - chunk_metadata["start_time"]
            segment_size = int(ceil(duration) * self.model.frames_per_second)
            segment_sizes.append(segment_size)
            (
                subsegments,
                seek,
                single_timestamp_ending,
            ) = self.model._split_segments_by_timestamps(
                tokenizer=tokenizer,
                tokens=output["tokens"],
                time_offset=chunk_metadata["start_time"],
                segment_size=segment_size,
                segment_duration=duration,
                seek=0,
            )
            segmented_outputs.append(
                [
                    dict(
                        text=tokenizer.decode(subsegment["tokens"]),
                        avg_logprob=output["avg_logprob"],
                        no_speech_prob=output["no_speech_prob"],
                        tokens=subsegment["tokens"],
                        start=subsegment["start"],
                        end=subsegment["end"],
                        compression_ratio=get_compression_ratio(
                            tokenizer.decode(subsegment["tokens"])
                        ),
                    )
                    for subsegment in subsegments
                ]
            )
        if options.word_timestamps:
            self.last_speech_timestamp = self.model.add_word_timestamps(
                segmented_outputs,
                tokenizer,
                encoder_output,
                segment_sizes,
                options.prepend_punctuations,
                options.append_punctuations,
                self.last_speech_timestamp,
            )

        return segmented_outputs

    def generate_segment_batched(
        self,
        features: np.ndarray,
@@ -223,9 +275,6 @@ def forward(self, features, tokenizer, chunks_metadata, options):
                        compression_ratio=get_compression_ratio(
                            tokenizer.decode(subsegment["tokens"])
                        ),
                        seek=int(
                            chunk_metadata["start_time"] * self.model.frames_per_second
                        ),
                    )
                    for subsegment in subsegments
                ]
@@ -542,7 +591,7 @@ def transcribe(
            vad_options=vad_parameters,
            all_language_probs=all_language_probs,
        )

        audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
        features = (
            np.stack(
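A minimal sketch (not part of this commit) of the chunk-to-frame arithmetic used in the new forward method above: each chunk's duration in seconds is rounded up and multiplied by the model's frames per second to get the segment size in feature frames. The frames_per_second value of 100 and the example timestamps below are assumptions for illustration; the real code reads the rate from self.model and the metadata from collect_chunks.

    from math import ceil

    frames_per_second = 100  # assumed; the real code uses self.model.frames_per_second

    chunks_metadata = [
        {"start_time": 0.0, "end_time": 17.3},
        {"start_time": 17.3, "end_time": 30.0},
    ]

    for chunk in chunks_metadata:
        duration = chunk["end_time"] - chunk["start_time"]
        segment_size = int(ceil(duration) * frames_per_second)
        print(f"{duration:.1f}s -> {segment_size} frames")  # 17.3s -> 1800, 12.7s -> 1300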
