Skip to content

Commit

Permalink
Support dumping nbest into manifest in ASR (NVIDIA#8662)
Browse files Browse the repository at this point in the history
* support nbest decoding

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* remove extract_nbest and just dump the n-best hypotheses if they exist

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* introduce extract_nbest so as to ease the user configuration of nbest

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove breakpoints

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

* use flag to check instead of checking nbest item type

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>

---------

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
zhehuaichen and pre-commit-ci[bot] authored Mar 18, 2024
1 parent 13f0f23 commit 92d4aa0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
24 changes: 22 additions & 2 deletions examples/asr/transcribe_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ class TranscriptionConfig:
# Only use transcribe_partial_audio() when the audio is too long to fit in memory
# Your manifest input should have `offset` field to use transcribe_partial_audio()
allow_partial_transcribe: bool = False
extract_nbest: bool = False # Extract n-best hypotheses from the model


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
Expand Down Expand Up @@ -279,13 +280,19 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
if isinstance(asr_model.decoding, MultiTaskDecoding):
cfg.multitask_decoding.compute_langs = cfg.compute_langs
cfg.multitask_decoding.preserve_alignments = cfg.preserve_alignment
if cfg.extract_nbest:
cfg.multitask_decoding.beam.return_best_hypothesis = False
cfg.return_hypotheses = True
asr_model.change_decoding_strategy(cfg.multitask_decoding)
elif cfg.decoder_type is not None:
# TODO: Support compute_langs in CTC eventually
if cfg.compute_langs and cfg.decoder_type == 'ctc':
raise ValueError("CTC models do not support `compute_langs` at the moment")

decoding_cfg = cfg.rnnt_decoding if cfg.decoder_type == 'rnnt' else cfg.ctc_decoding
if cfg.extract_nbest:
decoding_cfg.beam.return_best_hypothesis = False
cfg.return_hypotheses = True
decoding_cfg.compute_timestamps = cfg.compute_timestamps # both ctc and rnnt support it
if 'preserve_alignments' in decoding_cfg:
decoding_cfg.preserve_alignments = preserve_alignment
Expand All @@ -298,6 +305,9 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis

# Check if ctc or rnnt model
elif hasattr(asr_model, 'joint'): # RNNT model
if cfg.extract_nbest:
cfg.rnnt_decoding.beam.return_best_hypothesis = False
cfg.return_hypotheses = True
cfg.rnnt_decoding.fused_batch_size = -1
cfg.rnnt_decoding.compute_timestamps = cfg.compute_timestamps
cfg.rnnt_decoding.compute_langs = cfg.compute_langs
Expand All @@ -309,6 +319,9 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
if cfg.compute_langs:
raise ValueError("CTC models do not support `compute_langs` at the moment.")
cfg.ctc_decoding.compute_timestamps = cfg.compute_timestamps
if cfg.extract_nbest:
cfg.ctc_decoding.beam.return_best_hypothesis = False
cfg.return_hypotheses = True

asr_model.change_decoding_strategy(cfg.ctc_decoding)

Expand All @@ -318,6 +331,8 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
isinstance(asr_model, EncDecHybridRNNTCTCModel) and cfg.decoder_type == "ctc"
):
cfg.decoding = cfg.ctc_decoding
elif isinstance(asr_model.decoding, MultiTaskDecoding):
cfg.decoding = cfg.multitask_decoding
else:
cfg.decoding = cfg.rnnt_decoding

Expand Down Expand Up @@ -402,9 +417,14 @@ def autocast(dtype=None):
logging.info(f"Finished transcribing {len(filepaths)} files !")
logging.info(f"Writing transcriptions into file: {cfg.output_filename}")

# if transcriptions form a tuple of (best_hypotheses, all_hypotheses), extract just best hypothesis
# if transcriptions form a tuple of (best_hypotheses, all_hypotheses)
if type(transcriptions) == tuple and len(transcriptions) == 2:
transcriptions = transcriptions[0]
if cfg.extract_nbest:
# extract all hypotheses if exists
transcriptions = transcriptions[1]
else:
# extract just best hypothesis
transcriptions = transcriptions[0]

if cfg.return_transcriptions:
return transcriptions
Expand Down
3 changes: 2 additions & 1 deletion nemo/collections/asr/parts/utils/transcribe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,8 @@ def write_transcription(
if not cfg.decoding.beam.return_best_hypothesis:
beam = []
for hyp in hyps:
beam.append((hyp.text, hyp.score))
score = hyp.score.numpy().item() if isinstance(hyp.score, torch.Tensor) else hyp.score
beam.append((hyp.text, score))
beams.append(beam)
else:
raise TypeError
Expand Down

0 comments on commit 92d4aa0

Please sign in to comment.