Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added encoder output #1136

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions faster_whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import asdict, dataclass
from inspect import signature
from math import ceil
from typing import BinaryIO, Iterable, List, Optional, Tuple, Union
from typing import BinaryIO, Iterable, List, Optional, Tuple, Union, Any
from warnings import warn

import ctranslate2
Expand Down Expand Up @@ -115,6 +115,21 @@ class TranscriptionInfo:
vad_options: VadOptions


@dataclass
class TranscriptionResult:
    """Return value of ``transcribe`` that also carries the encoder output.

    ``transcribe`` used to return a plain ``(segments, info)`` tuple. This
    class keeps that two-element shape for unpacking, indexing and ``len()``
    so existing callers keep working, and exposes the extra
    ``encoder_output`` value through the attribute or :meth:`all`.

    Attributes:
      segments: Iterable of transcribed segments.
      info: Information about the transcription run.
      encoder_output: Value passed in by ``transcribe``; its concrete type is
        not fixed here (annotated as ``Any``).
    """

    segments: Iterable["Segment"]
    info: "TranscriptionInfo"
    encoder_output: Any

    def __iter__(self) -> Iterable:
        """Yield ``segments`` then ``info``, like the legacy 2-tuple."""
        # encoder_output is deliberately left out so that
        # ``segments, info = model.transcribe(...)`` still unpacks cleanly.
        yield self.segments
        yield self.info

    def __getitem__(self, index: Union[int, slice]) -> Any:
        """Index into the legacy ``(segments, info)`` pair.

        Supports the same integer and slice indices a 2-tuple would, and
        raises ``IndexError`` for anything outside that pair.
        """
        return (self.segments, self.info)[index]

    def __len__(self) -> int:
        """Return 2, the length of the legacy ``(segments, info)`` pair."""
        return 2

    def all(self) -> Tuple[Iterable["Segment"], "TranscriptionInfo", Any]:
        """Return ``(segments, info, encoder_output)`` as a 3-tuple."""
        return self.segments, self.info, self.encoder_output

# The code below is originally from HF pipeline and is used in whisper-x
# (https://github.com/m-bain/whisperX) and adapted for faster_whisper

Expand Down Expand Up @@ -694,7 +709,7 @@ def transcribe(
hotwords: Optional[str] = None,
language_detection_threshold: Optional[float] = 0.5,
language_detection_segments: int = 1,
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
) -> TranscriptionResult:
"""Transcribes an input file.

Arguments:
Expand Down Expand Up @@ -970,7 +985,7 @@ def transcribe(
vad_options=vad_parameters,
all_language_probs=all_language_probs,
)
return segments, info
return TranscriptionResult(segments=segments, info=info, encoder_output=encoder_output)

def _split_segments_by_timestamps(
self,
Expand Down
Loading