diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 6ddf6e5f..7325a6c3 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -119,6 +119,58 @@ def __init__( self.model: WhisperModel = model self.last_speech_timestamp = 0.0 + def forward(self, features, tokenizer, chunks_metadata, options): + encoder_output, outputs = self.generate_segment_batched( + features, tokenizer, options + ) + + segmented_outputs = [] + segment_sizes = [] + for chunk_metadata, output in zip(chunks_metadata, outputs): + duration = chunk_metadata["end_time"] - chunk_metadata["start_time"] + segment_size = int(ceil(duration) * self.model.frames_per_second) + segment_sizes.append(segment_size) + ( + subsegments, + seek, + single_timestamp_ending, + ) = self.model._split_segments_by_timestamps( + tokenizer=tokenizer, + tokens=output["tokens"], + time_offset=chunk_metadata["start_time"], + segment_size=segment_size, + segment_duration=duration, + seek=0, + ) + segmented_outputs.append( + [ + dict( + text=tokenizer.decode(subsegment["tokens"]), + avg_logprob=output["avg_logprob"], + no_speech_prob=output["no_speech_prob"], + tokens=subsegment["tokens"], + start=subsegment["start"], + end=subsegment["end"], + compression_ratio=get_compression_ratio( + tokenizer.decode(subsegment["tokens"]) + ), + ) + for subsegment in subsegments + ] + ) + if options.word_timestamps: + self.last_speech_timestamp = self.model.add_word_timestamps( + segmented_outputs, + tokenizer, + encoder_output, + segment_sizes, + options.prepend_punctuations, + options.append_punctuations, + self.last_speech_timestamp, + ) + + return segmented_outputs + def generate_segment_batched( self, features: np.ndarray, @@ -223,9 +275,6 @@ def forward(self, features, tokenizer, chunks_metadata, options): compression_ratio=get_compression_ratio( tokenizer.decode(subsegment["tokens"]) ), - seek=int( - chunk_metadata["start_time"] * self.model.frames_per_second - ), ) for subsegment in subsegments ] @@ -542,7 +591,7 @@ def transcribe( vad_options=vad_parameters, all_language_probs=all_language_probs, ) - + audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps) features = ( np.stack(