From a64a04b11d1df9c4628b0320b04bccca66ce271c Mon Sep 17 00:00:00 2001 From: Huiying Date: Fri, 19 Apr 2024 09:34:23 -0700 Subject: [PATCH 1/6] change the condition for get qkv tensor from linear_qkv output (#8965) Signed-off-by: HuiyingLi Co-authored-by: Adi Renduchintala --- .../common/megatron/adapters/mcore_mixins.py | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py index 64e8fe44e1e8..2aeb014c1b40 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py @@ -86,29 +86,27 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None): linear_qkv_output, _ = self.linear_qkv(hidden_states) layernorm_output = None - # In megatron/core/models/gpt/gpt_layer_specs.py TELayerNormColumnParallelLinear is used for linear_qkv. - # TELayerNormColumnParallelLinear fused LN and linear, both will be returned. - # In nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_spec.py TEColumnParallelLinear is used for linear_qkv, + # In megatron/core/models/gpt/gpt_layer_specs.py when fused module is used(e.g. TELayerNormColumnParallelLinear) + # both LN and qkv will be returned. + # In nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_spec.py TEColumnParallelLinear(non-fused) is used for linear_qkv, # which only returns linear. 
- if isinstance(self.linear_qkv, TELayerNormColumnParallelLinear): - mixed_qkv, layernorm_output = linear_qkv_output - elif isinstance(self.linear_qkv, TEColumnParallelLinear): # only mixed_qkv + if isinstance(linear_qkv_output, tuple): + if len(linear_qkv_output) == 2: # fused module, qkv&LN + mixed_qkv, layernorm_output = linear_qkv_output + else: + raise ValueError(f"Unexpected number of outputs from linear_qkv output: {len(linear_qkv_output)}") + else: # for qkv&LN not fused only mixed_qkv mixed_qkv = linear_qkv_output - else: - raise ValueError( - f"Unrecognized module type '{type(self.linear_qkv)}' when getting query, key, value tensors for mcore mixins. " - ) # LoRA logic if self.is_adapter_available(): lora_kqv_adapter = self.get_adapter_module(AdapterName.LORA_KQV_ADAPTER) if lora_kqv_adapter and self.adapter_cfg[AdapterName.LORA_KQV_ADAPTER]['enabled']: - if isinstance(self.linear_qkv, TELayerNormColumnParallelLinear): + if layernorm_output is not None: lora_mixed_qkv = lora_kqv_adapter(layernorm_output) - elif isinstance(self.linear_qkv, TEColumnParallelLinear): - lora_mixed_qkv = lora_kqv_adapter(hidden_states) else: - raise ValueError(f"Unrecognized module type '{type(self.linear_qkv)}' when applying lora.") + lora_mixed_qkv = lora_kqv_adapter(hidden_states) + mixed_qkv = mixed_qkv + lora_mixed_qkv # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] From b271e12934a4d658f95c7b1a43b1bc06cf9dd216 Mon Sep 17 00:00:00 2001 From: Shashank Verma Date: Fri, 19 Apr 2024 09:39:53 -0700 Subject: [PATCH 2/6] Update Latest News (#8837) * Update Latest News Adds links to articles on * NeMo framework on GKE * Responsible Gen AI using NeMo and Picasso * NeMo powering Amazon Titan foundation models Signed-off-by: Shashank Verma * Minor updates to latest news in README * Remove bullets * Editing text for clarity Signed-off-by: Shashank Verma * Format latest news as a dropdown list * Uses embedded html to format news to dropdown, hiding lengthy details * Fixes formatting of 
the title Signed-off-by: Shashank Verma * Add break to improve readability of latest news image Signed-off-by: Shashank Verma * Add LLM and MM section in latest news Signed-off-by: Shashank Verma * Add margin in latest news expandable lists Signed-off-by: Shashank Verma * Remove styling of expandable list * Github appears to not render styled elements when embedded as raw html in rst Signed-off-by: Shashank Verma * Fold the first news item by default Signed-off-by: Shashank Verma --------- Signed-off-by: Shashank Verma Signed-off-by: Shashank Verma --- README.rst | 42 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/README.rst b/README.rst index f7374641d66d..f4c2a541960f 100644 --- a/README.rst +++ b/README.rst @@ -41,17 +41,43 @@ Latest News ----------- -- 2023/12/06 `New NVIDIA NeMo Framework Features and NVIDIA H200 `_ +.. raw:: html -.. image:: https://github.com/sbhavani/TransformerEngine/blob/main/docs/examples/H200-NeMo-performance.png - :target: https://developer.nvidia.com/blog/new-nvidia-nemo-framework-features-and-nvidia-h200-supercharge-llm-training-performance-and-versatility - :alt: H200-NeMo-performance - :width: 600 +
+ Large Language Models and Multimodal +
+ Accelerate your generative AI journey with NVIDIA NeMo framework on GKE (2024/03/16) -NeMo Framework has been updated with state-of-the-art features, -such as FSDP, Mixture-of-Experts, and RLHF with TensorRT-LLM to provide speedups up to 4.2x for Llama-2 pre-training on H200. -**All of these features will be available in an upcoming release.** + An end-to-end walkthrough to train generative AI models on the Google Kubernetes Engine (GKE) using the NVIDIA NeMo Framework is available at https://github.com/GoogleCloudPlatform/nvidia-nemo-on-gke. The walkthrough includes detailed instructions on how to set up a Google Cloud Project and pre-train a GPT model using the NeMo Framework. +

+
+
+ Bria Builds Responsible Generative AI for Enterprises Using NVIDIA NeMo, Picasso (2024/03/06) + + Bria, a Tel Aviv startup at the forefront of visual generative AI for enterprises, now leverages the NVIDIA NeMo Framework. The Bria.ai platform uses reference implementations from the NeMo Multimodal collection, trained on NVIDIA Tensor Core GPUs, to enable high-throughput and low-latency image generation. Bria has also adopted NVIDIA Picasso, a foundry for visual generative AI models, to run inference.

+
+ +
+ New NVIDIA NeMo Framework Features and NVIDIA H200 (2023/12/06) + + NVIDIA NeMo Framework now includes several optimizations and enhancements, including: 1) Fully Sharded Data Parallelism (FSDP) to improve the efficiency of training large-scale AI models, 2) Mix of Experts (MoE)-based LLM architectures with expert parallelism for efficient LLM training at scale, 3) Reinforcement Learning from Human Feedback (RLHF) with TensorRT-LLM for inference stage acceleration, and 4) up to 4.2x speedups for Llama 2 pre-training on NVIDIA H200 Tensor Core GPUs. +

+ H200-NeMo-performance +

+
+ +
+ NVIDIA now powers training for Amazon Titan Foundation models (2023/11/28) + + The NVIDIA NeMo Framework now empowers the Amazon Titan foundation models (FMs) with efficient training of large language models (LLMs). The Titan FMs form the basis of Amazon’s generative AI service, Amazon Bedrock. The NeMo Framework provides a versatile framework for building, customizing, and running LLMs.

+
+ +
+ + Introduction From c9c8408fd0901714a8d051b808dd331cf44a69ae Mon Sep 17 00:00:00 2001 From: Shashank Verma Date: Fri, 19 Apr 2024 10:17:06 -0700 Subject: [PATCH 3/6] Fix incorrect link to latest news in README (#8985) Signed-off-by: Shashank Verma --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index f4c2a541960f..41e326cef44c 100644 --- a/README.rst +++ b/README.rst @@ -60,7 +60,7 @@ Latest News
- New NVIDIA NeMo Framework Features and NVIDIA H200 (2023/12/06) + New NVIDIA NeMo Framework Features and NVIDIA H200 (2023/12/06) NVIDIA NeMo Framework now includes several optimizations and enhancements, including: 1) Fully Sharded Data Parallelism (FSDP) to improve the efficiency of training large-scale AI models, 2) Mix of Experts (MoE)-based LLM architectures with expert parallelism for efficient LLM training at scale, 3) Reinforcement Learning from Human Feedback (RLHF) with TensorRT-LLM for inference stage acceleration, and 4) up to 4.2x speedups for Llama 2 pre-training on NVIDIA H200 Tensor Core GPUs.

From 206e84a15f03166a68b3cdd97f4c197d2d9c552a Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva <80532067+erastorgueva-nv@users.noreply.github.com> Date: Fri, 19 Apr 2024 19:47:55 -0700 Subject: [PATCH 4/6] Enable using hybrid asr models in CTC Segmentation tool (#8828) * enable using hybrid asr models in ctc segmentation tool Signed-off-by: Elena Rastorgueva * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Elena Rastorgueva Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../ctc_segmentation/scripts/prepare_data.py | 16 ++++++- .../scripts/run_ctc_segmentation.py | 44 +++++++++++++------ 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/tools/ctc_segmentation/scripts/prepare_data.py b/tools/ctc_segmentation/scripts/prepare_data.py index c6ea024273fb..476e719eb51b 100644 --- a/tools/ctc_segmentation/scripts/prepare_data.py +++ b/tools/ctc_segmentation/scripts/prepare_data.py @@ -26,6 +26,8 @@ from tqdm import tqdm from nemo.collections.asr.models import ASRModel +from nemo.collections.asr.models.ctc_models import EncDecCTCModel +from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel from nemo.utils import model_utils try: @@ -354,7 +356,19 @@ def _split(sentences, delimiter): asr_model = ASRModel.from_pretrained(model_name=args.model) # type: ASRModel model_name = args.model - vocabulary = asr_model.cfg.decoder.vocabulary + if not (isinstance(asr_model, EncDecCTCModel) or isinstance(asr_model, EncDecHybridRNNTCTCModel)): + raise NotImplementedError( + f"Model is not an instance of NeMo EncDecCTCModel or ENCDecHybridRNNTCTCModel." + " Currently only instances of these models are supported" + ) + + # get vocabulary list + if hasattr(asr_model, 'tokenizer'): # i.e. tokenization is BPE-based + vocabulary = asr_model.tokenizer.vocab + elif hasattr(asr_model.decoder, "vocabulary"): # i.e. 
tokenization is character-based + vocabulary = asr_model.cfg.decoder.vocabulary + else: + raise ValueError("Unexpected model type. Vocabulary list not found.") if os.path.isdir(args.in_text): text_files = glob(f"{args.in_text}/*.txt") diff --git a/tools/ctc_segmentation/scripts/run_ctc_segmentation.py b/tools/ctc_segmentation/scripts/run_ctc_segmentation.py index dddeb9a42dc2..c9d9ed2d8731 100644 --- a/tools/ctc_segmentation/scripts/run_ctc_segmentation.py +++ b/tools/ctc_segmentation/scripts/run_ctc_segmentation.py @@ -27,6 +27,8 @@ from utils import get_segments import nemo.collections.asr as nemo_asr +from nemo.collections.asr.models.ctc_models import EncDecCTCModel +from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel parser = argparse.ArgumentParser(description="CTC Segmentation") parser.add_argument("--output_dir", default="output", type=str, help="Path to output directory") @@ -72,18 +74,19 @@ logging.basicConfig(handlers=handlers, level=level) if os.path.exists(args.model): - asr_model = nemo_asr.models.EncDecCTCModel.restore_from(args.model) - elif args.model in nemo_asr.models.EncDecCTCModel.get_available_model_names(): - asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(args.model, strict=False) + asr_model = nemo_asr.models.ASRModel.restore_from(args.model) else: - try: - asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(args.model) - except: - raise ValueError( - f"Provide path to the pretrained checkpoint or choose from {nemo_asr.models.EncDecCTCModel.get_available_model_names()}" - ) + asr_model = nemo_asr.models.ASRModel.from_pretrained(args.model, strict=False) + + if not (isinstance(asr_model, EncDecCTCModel) or isinstance(asr_model, EncDecHybridRNNTCTCModel)): + raise NotImplementedError( + f"Model is not an instance of NeMo EncDecCTCModel or ENCDecHybridRNNTCTCModel." 
+ " Currently only instances of these models are supported" + ) - bpe_model = isinstance(asr_model, nemo_asr.models.EncDecCTCModelBPE) + bpe_model = isinstance(asr_model, nemo_asr.models.EncDecCTCModelBPE) or isinstance( + asr_model, nemo_asr.models.EncDecHybridRNNTCTCBPEModel + ) # get tokenizer used during training, None for char based models if bpe_model: @@ -91,8 +94,18 @@ else: tokenizer = None + if isinstance(asr_model, EncDecHybridRNNTCTCModel): + asr_model.change_decoding_strategy(decoder_type="ctc") + # extract ASR vocabulary and add blank symbol - vocabulary = ["ε"] + list(asr_model.cfg.decoder.vocabulary) + if hasattr(asr_model, 'tokenizer'): # i.e. tokenization is BPE-based + vocabulary = asr_model.tokenizer.vocab + elif hasattr(asr_model.decoder, "vocabulary"): # i.e. tokenization is character-based + vocabulary = asr_model.cfg.decoder.vocabulary + else: + raise ValueError("Unexpected model type. Vocabulary list not found.") + + vocabulary = ["ε"] + list(vocabulary) logging.debug(f"ASR Model vocabulary: {vocabulary}") data = Path(args.data) @@ -136,9 +149,14 @@ logging.debug(f"len(signal): {len(signal)}, sr: {sample_rate}") logging.debug(f"Duration: {original_duration}s, file_name: {path_audio}") - log_probs = asr_model.transcribe(audio=[str(path_audio)], batch_size=1, return_hypotheses=True)[ + hypotheses = asr_model.transcribe([str(path_audio)], batch_size=1, return_hypotheses=True) + # if hypotheses form a tuple (from Hybrid model), extract just "best" hypothesis + if type(hypotheses) == tuple and len(hypotheses) == 2: + hypotheses = hypotheses[0] + log_probs = hypotheses[ 0 - ].alignments + ].alignments # note: "[0]" is for batch dimension unpacking (and here batch size=1) + # move blank values to the first column (ctc-package compatibility) blank_col = log_probs[:, -1].reshape((log_probs.shape[0], 1)) log_probs = np.concatenate((blank_col, log_probs[:, :-1]), axis=1) From 6533e484c7e340d0ee0f1530ce064caeafa87301 Mon Sep 17 00:00:00 2001 From: 
Huiying Date: Sat, 20 Apr 2024 10:06:36 -0700 Subject: [PATCH 5/6] Add safety checks for 'data' key in MegatronGPTModel cfg (#8991) Signed-off-by: HuiyingLi --- .../nlp/models/language_modeling/megatron_gpt_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index a660af46f13d..e5e48cdc10da 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -367,9 +367,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1))) self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0))) self.loss_broadcast_src_rank = None - self.return_output_tensors = cfg.data.get('return_output_tensors', False) - self.validation_drop_last = cfg.data.get('validation_drop_last', True) - self.sample_weight = cfg.data.get('sample_weight', 'token') + data_cfg = cfg.get('data', {}) + self.return_output_tensors = data_cfg.get('return_output_tensors', False) + self.validation_drop_last = data_cfg.get('validation_drop_last', True) + self.sample_weight = data_cfg.get('sample_weight', 'token') self.validation_param_sync_overlap = self.cfg.get('validation_param_sync_overlap', False) self.inference_params = None From 9bafd37e58b34b905fab3ba19daa010ae7c6271a Mon Sep 17 00:00:00 2001 From: Aleksandr Laptev Date: Sun, 21 Apr 2024 17:39:26 +0700 Subject: [PATCH 6/6] TDT confidence fix (#8982) * tdt confidence fix --------- Signed-off-by: Aleksandr Laptev Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../asr/parts/submodules/ctc_decoding.py | 13 +++ .../asr/parts/submodules/rnnt_decoding.py | 101 ++++++++++++++---- .../parts/submodules/rnnt_greedy_decoding.py | 19 +++- .../submodules/tdt_loop_labels_computer.py | 77 
+++++++++++-- .../asr_confidence_benchmarking_utils.py | 2 +- .../asr/parts/utils/asr_confidence_utils.py | 11 +- .../collections/asr/parts/utils/rnnt_utils.py | 12 ++- .../confidence/benchmark_asr_confidence.py | 17 ++- .../test_asr_hybrid_rnnt_ctc_model_char.py | 4 +- .../asr/test_asr_rnnt_encdec_model.py | 4 +- tutorials/asr/ASR_Confidence_Estimation.ipynb | 30 ++++-- 11 files changed, 239 insertions(+), 51 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/ctc_decoding.py b/nemo/collections/asr/parts/submodules/ctc_decoding.py index d331a6c86b53..67559eccf6e2 100644 --- a/nemo/collections/asr/parts/submodules/ctc_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_decoding.py @@ -98,6 +98,10 @@ class AbstractCTCDecoding(ConfidenceMixin): Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, + making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). + method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -911,10 +915,15 @@ class CTCDecoding(AbstractCTCDecoding): exclude_blank: Bool flag indicating that blank token confidence scores are to be excluded from the `token_confidence`. + aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, + making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). + method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. 
@@ -1122,6 +1131,10 @@ class CTCBPEDecoding(AbstractCTCDecoding): Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, + making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). + method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index 7a260f3c6c89..71079f4b6382 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -96,6 +96,9 @@ class AbstractRNNTDecoding(ConfidenceMixin): from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, + making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -209,7 +212,8 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): self.compute_timestamps = self.cfg.get('compute_timestamps', None) self.word_seperator = self.cfg.get('word_seperator', ' ') - if self.durations is not None and self.durations != []: # this means it's a TDT model. + self._is_tdt = self.durations is not None and self.durations != [] # this means it's a TDT model. 
+ if self._is_tdt: if blank_id == 0: raise ValueError("blank_id must equal len(non_blank_vocabs) for TDT models") if self.big_blank_durations is not None and self.big_blank_durations != []: @@ -254,6 +258,12 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): # initialize confidence-related fields self._init_confidence(self.cfg.get('confidence_cfg', None)) + if self._is_tdt: + if self.preserve_frame_confidence is True and self.preserve_alignments is False: + raise ValueError( + "If `preserve_frame_confidence` flag is set, then `preserve_alignments` flag must also be set." + ) + # Confidence estimation is not implemented for these strategies if ( not self.preserve_frame_confidence @@ -264,7 +274,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): if self.cfg.strategy == 'greedy': if self.big_blank_durations is None or self.big_blank_durations == []: - if self.durations is None or self.durations == []: + if not self._is_tdt: self.decoding = rnnt_greedy_decoding.GreedyRNNTInfer( decoder_model=decoder, joint_model=joint, @@ -289,6 +299,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, + include_duration_confidence=self.tdt_include_duration_confidence, confidence_method_cfg=self.confidence_method_cfg, ) else: @@ -307,7 +318,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): elif self.cfg.strategy == 'greedy_batch': if self.big_blank_durations is None or self.big_blank_durations == []: - if self.durations is None or self.durations == []: + if not self._is_tdt: self.decoding = rnnt_greedy_decoding.GreedyBatchedRNNTInfer( decoder_model=decoder, joint_model=joint, @@ -334,6 +345,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, + 
include_duration_confidence=self.tdt_include_duration_confidence, confidence_method_cfg=self.confidence_method_cfg, use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False), ) @@ -530,7 +542,7 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp if self.big_blank_durations is not None and self.big_blank_durations != []: # multi-blank RNNT num_extra_outputs = len(self.big_blank_durations) prediction = [p for p in prediction if p < self.blank_id - num_extra_outputs] - elif self.durations is not None and self.durations != []: # TDT model. + elif self._is_tdt: # TDT model. prediction = [p for p in prediction if p < self.blank_id] else: # standard RNN-T prediction = [p for p in prediction if p != self.blank_id] @@ -569,28 +581,69 @@ def compute_confidence(self, hypotheses_list: List[Hypothesis]) -> List[Hypothes Returns: A list of hypotheses with high-level confidence scores. """ - if self.exclude_blank_from_confidence: - for hyp in hypotheses_list: - hyp.token_confidence = hyp.non_blank_frame_confidence - else: + if self._is_tdt: + # if self.tdt_include_duration_confidence is True then frame_confidence elements consist of two numbers + maybe_pre_aggregate = ( + (lambda x: self._aggregate_confidence(x)) if self.tdt_include_duration_confidence else (lambda x: x) + ) for hyp in hypotheses_list: - offset = 0 token_confidence = [] - if len(hyp.timestep) > 0: - for ts, te in zip(hyp.timestep, hyp.timestep[1:] + [len(hyp.frame_confidence)]): - if ts != te: - # tokens are considered to belong to the last non-blank token, if any. - token_confidence.append( - self._aggregate_confidence( - [hyp.frame_confidence[ts][offset]] - + [fc[0] for fc in hyp.frame_confidence[ts + 1 : te]] + # trying to recover frame_confidence according to alignments + subsequent_blank_confidence = [] + # going backwards since tokens are considered belonging to the last non-blank token. 
+ for fc, fa in zip(hyp.frame_confidence[::-1], hyp.alignments[::-1]): + # there is only one score per frame most of the time + if len(fa) > 1: + for i, a in reversed(list(enumerate(fa))): + if a[-1] == self.blank_id: + if not self.exclude_blank_from_confidence: + subsequent_blank_confidence.append(maybe_pre_aggregate(fc[i])) + elif not subsequent_blank_confidence: + token_confidence.append(maybe_pre_aggregate(fc[i])) + else: + token_confidence.append( + self._aggregate_confidence( + [maybe_pre_aggregate(fc[i])] + subsequent_blank_confidence + ) ) - ) - offset = 0 + subsequent_blank_confidence = [] + else: + i, a = 0, fa[0] + if a[-1] == self.blank_id: + if not self.exclude_blank_from_confidence: + subsequent_blank_confidence.append(maybe_pre_aggregate(fc[i])) + elif not subsequent_blank_confidence: + token_confidence.append(maybe_pre_aggregate(fc[i])) else: - token_confidence.append(hyp.frame_confidence[ts][offset]) - offset += 1 + token_confidence.append( + self._aggregate_confidence([maybe_pre_aggregate(fc[i])] + subsequent_blank_confidence) + ) + subsequent_blank_confidence = [] + token_confidence = token_confidence[::-1] hyp.token_confidence = token_confidence + else: + if self.exclude_blank_from_confidence: + for hyp in hypotheses_list: + hyp.token_confidence = hyp.non_blank_frame_confidence + else: + for hyp in hypotheses_list: + offset = 0 + token_confidence = [] + if len(hyp.timestep) > 0: + for ts, te in zip(hyp.timestep, hyp.timestep[1:] + [len(hyp.frame_confidence)]): + if ts != te: + # tokens are considered to belong to the last non-blank token, if any. 
+ token_confidence.append( + self._aggregate_confidence( + [hyp.frame_confidence[ts][offset]] + + [fc[0] for fc in hyp.frame_confidence[ts + 1 : te]] + ) + ) + offset = 0 + else: + token_confidence.append(hyp.frame_confidence[ts][offset]) + offset += 1 + hyp.token_confidence = token_confidence if self.preserve_word_confidence: for hyp in hypotheses_list: hyp.word_confidence = self._aggregate_token_confidence(hyp) @@ -1010,6 +1063,9 @@ class RNNTDecoding(AbstractRNNTDecoding): from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, + making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -1276,6 +1332,9 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, + making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. 
diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index 464dc46e358c..e5de99cf0776 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -2282,6 +2282,7 @@ class GreedyRNNTInferConfig: max_symbols_per_step: Optional[int] = 10 preserve_alignments: bool = False preserve_frame_confidence: bool = False + tdt_include_duration_confidence: bool = False confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig()) def __post_init__(self): @@ -2298,6 +2299,7 @@ class GreedyBatchedRNNTInferConfig: max_symbols_per_step: Optional[int] = 10 preserve_alignments: bool = False preserve_frame_confidence: bool = False + tdt_include_duration_confidence: bool = False confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig()) loop_labels: bool = True use_cuda_graph_decoder: bool = False @@ -2337,6 +2339,9 @@ class GreedyTDTInfer(_GreedyRNNTInfer): The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more confidence scores. U is the number of target tokens for the current timestep Ti. + include_duration_confidence: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, + making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. 
@@ -2380,6 +2385,7 @@ def __init__( max_symbols_per_step: Optional[int] = None, preserve_alignments: bool = False, preserve_frame_confidence: bool = False, + include_duration_confidence: bool = False, confidence_method_cfg: Optional[DictConfig] = None, ): super().__init__( @@ -2392,6 +2398,7 @@ def __init__( confidence_method_cfg=confidence_method_cfg, ) self.durations = durations + self.include_duration_confidence = include_duration_confidence @typecheck() def forward( @@ -2517,7 +2524,11 @@ def _greedy_decode( if self.preserve_frame_confidence: # insert confidence into last timestep - hypothesis.frame_confidence[-1].append(self._get_confidence(logp)) + hypothesis.frame_confidence[-1].append( + (self._get_confidence_tensor(logp), self._get_confidence_tensor(duration_logp)) + if self.include_duration_confidence + else self._get_confidence_tensor(logp) + ) del logp @@ -2593,6 +2604,9 @@ class GreedyBatchedTDTInfer(_GreedyRNNTInfer): The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more confidence scores. U is the number of target tokens for the current timestep Ti. + include_duration_confidence: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, + making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. 
@@ -2636,6 +2650,7 @@ def __init__( max_symbols_per_step: Optional[int] = None, preserve_alignments: bool = False, preserve_frame_confidence: bool = False, + include_duration_confidence: bool = False, confidence_method_cfg: Optional[DictConfig] = None, use_cuda_graph_decoder: bool = False, ): @@ -2649,6 +2664,7 @@ def __init__( confidence_method_cfg=confidence_method_cfg, ) self.durations = durations + self.include_duration_confidence = include_duration_confidence # Depending on availability of `blank_as_pad` support # switch between more efficient batch decoding technique @@ -2663,6 +2679,7 @@ def __init__( max_symbols_per_step=self.max_symbols, preserve_alignments=preserve_alignments, preserve_frame_confidence=preserve_frame_confidence, + include_duration_confidence=include_duration_confidence, confidence_method_cfg=confidence_method_cfg, allow_cuda_graphs=use_cuda_graph_decoder, ) diff --git a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py index c289ce06cdfa..b136446d97fb 100644 --- a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py +++ b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py @@ -92,6 +92,7 @@ def __init__( logits_dim: int, preserve_alignments=False, preserve_frame_confidence=False, + include_duration_confidence: bool = False, ): """ @@ -105,6 +106,7 @@ def __init__( logits_dim: output dimension for Joint preserve_alignments: if alignments are needed preserve_frame_confidence: if frame confidence is needed + include_duration_confidence: if duration confidence is needed to be added to the frame confidence """ self.device = device self.float_dtype = float_dtype @@ -151,6 +153,7 @@ def __init__( float_dtype=self.float_dtype, store_alignments=preserve_alignments, store_frame_confidence=preserve_frame_confidence, + with_duration_confidence=include_duration_confidence, ) else: self.alignments = None @@ -186,6 +189,7 @@ def __init__( 
max_symbols_per_step: Optional[int] = None, preserve_alignments=False, preserve_frame_confidence=False, + include_duration_confidence: bool = False, confidence_method_cfg: Optional[DictConfig] = None, allow_cuda_graphs: bool = True, ): @@ -199,6 +203,7 @@ def __init__( max_symbols_per_step: max symbols to emit on each step (to avoid infinite looping) preserve_alignments: if alignments are needed preserve_frame_confidence: if frame confidence is needed + include_duration_confidence: if duration confidence is needed to be added to the frame confidence confidence_method_cfg: config for the confidence """ super().__init__() @@ -210,6 +215,7 @@ def __init__( self.max_symbols = max_symbols_per_step self.preserve_alignments = preserve_alignments self.preserve_frame_confidence = preserve_frame_confidence + self.include_duration_confidence = include_duration_confidence self._SOS = self._blank_index self._init_confidence_method(confidence_method_cfg=confidence_method_cfg) assert self._SOS == self._blank_index # "blank as pad" algorithm only @@ -244,6 +250,7 @@ def loop_labels_torch( # do not recalculate joint projection, project only once encoder_output_projected = self.joint.project_encoder(encoder_output) + dtype = encoder_output_projected.dtype # init output structures: BatchedHyps (for results), BatchedAlignments + last decoder state # init empty batched hypotheses @@ -251,7 +258,7 @@ def loop_labels_torch( batch_size=batch_size, init_length=max_time * self.max_symbols if self.max_symbols is not None else max_time, device=device, - float_dtype=encoder_output_projected.dtype, + float_dtype=dtype, ) # sample state, will be replaced further when the decoding for hypothesis is done last_decoder_state = self.decoder.initialize_state(encoder_output_projected) @@ -263,9 +270,10 @@ def loop_labels_torch( logits_dim=self.joint.num_classes_with_blank, init_length=max_time * 2 if use_alignments else 1, # blank for each timestep + text tokens device=device, - 
float_dtype=encoder_output_projected.dtype, + float_dtype=dtype, store_alignments=self.preserve_alignments, store_frame_confidence=self.preserve_frame_confidence, + with_duration_confidence=self.include_duration_confidence, ) # durations @@ -327,7 +335,19 @@ def loop_labels_torch( time_indices=time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)) + confidence=torch.stack( + ( + self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( + dtype=dtype + ), + self._get_confidence_tensor(F.log_softmax(logits[:, -num_durations:], dim=-1)).to( + dtype=dtype + ), + ), + dim=-1, + ) + if self.include_duration_confidence + else self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to(dtype=dtype) if self.preserve_frame_confidence else None, ) @@ -367,7 +387,21 @@ def loop_labels_torch( time_indices=time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=more_labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)) + confidence=torch.stack( + ( + self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( + dtype=dtype + ), + self._get_confidence_tensor(F.log_softmax(logits[:, -num_durations:], dim=-1)).to( + dtype=dtype + ), + ), + dim=-1, + ) + if self.include_duration_confidence + else self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to( + dtype=dtype + ) if self.preserve_frame_confidence else None, ) @@ -520,6 +554,7 @@ def _graph_reinitialize( logits_dim=self.joint.num_classes_with_blank, preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, + include_duration_confidence=self.include_duration_confidence, ) self.state.all_durations = 
self.durations.to(self.state.device) @@ -616,6 +651,7 @@ def _before_inner_loop_get_joint_output(self): # stage 2: get joint output, iteratively seeking for non-blank labels # blank label in `labels` tensor means "end of hypothesis" (for this index) self.state.active_mask_prev.copy_(self.state.active_mask, non_blocking=True) + dtype = self.state.encoder_output_projected.dtype logits = ( self.joint.joint_after_projection( self.state.encoder_output_projected[self.state.batch_indices, self.state.safe_time_indices].unsqueeze( @@ -644,9 +680,21 @@ def _before_inner_loop_get_joint_output(self): time_indices=self.state.time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=self.state.labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor( - F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) + confidence=torch.stack( + ( + self._get_confidence_tensor( + F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) + ).to(dtype=dtype), + self._get_confidence_tensor( + F.log_softmax(logits[:, -self.state.all_durations.shape[0] :], dim=-1) + ).to(dtype=dtype), + ), + dim=-1, ) + if self.include_duration_confidence + else self._get_confidence_tensor( + F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) + ).to(dtype=dtype) if self.preserve_frame_confidence else None, ) @@ -672,6 +720,7 @@ def _inner_loop_code(self): self.state.time_indices_current_labels, out=self.state.time_indices_current_labels, ) + dtype = self.state.encoder_output_projected.dtype logits = ( self.joint.joint_after_projection( self.state.encoder_output_projected[self.state.batch_indices, self.state.safe_time_indices].unsqueeze( @@ -698,9 +747,21 @@ def _inner_loop_code(self): time_indices=self.state.time_indices_current_labels, logits=logits if self.preserve_alignments else None, labels=more_labels if self.preserve_alignments else None, - confidence=self._get_confidence_tensor( - 
F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) + confidence=torch.stack( + ( + self._get_confidence_tensor( + F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) + ).to(dtype=dtype), + self._get_confidence_tensor( + F.log_softmax(logits[:, -self.state.all_durations.shape[0] :], dim=-1) + ).to(dtype=dtype), + ), + dim=-1, ) + if self.include_duration_confidence + else self._get_confidence_tensor( + F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1) + ).to(dtype=dtype) if self.preserve_frame_confidence else None, ) diff --git a/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py b/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py index 8b15bc22eac6..96f90bee363c 100644 --- a/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py +++ b/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py @@ -172,7 +172,7 @@ def apply_confidence_parameters(decoding_cfg, hp): Updated decoding config. """ new_decoding_cfg = copy.deepcopy(decoding_cfg) - confidence_cfg_fields = ("aggregation", "exclude_blank") + confidence_cfg_fields = ("aggregation", "exclude_blank", "tdt_include_duration") confidence_method_cfg_fields = ("name", "alpha", "entropy_type", "entropy_norm") with open_dict(new_decoding_cfg): for p, v in hp.items(): diff --git a/nemo/collections/asr/parts/utils/asr_confidence_utils.py b/nemo/collections/asr/parts/utils/asr_confidence_utils.py index 27ced569b1a9..20f75baf522e 100644 --- a/nemo/collections/asr/parts/utils/asr_confidence_utils.py +++ b/nemo/collections/asr/parts/utils/asr_confidence_utils.py @@ -136,6 +136,9 @@ class ConfidenceConfig: from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. 
+ tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, + making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -175,6 +178,7 @@ class ConfidenceConfig: preserve_word_confidence: bool = False exclude_blank: bool = True aggregation: str = "min" + tdt_include_duration: bool = False method_cfg: ConfidenceMethodConfig = field(default_factory=lambda: ConfidenceMethodConfig()) def __post_init__(self): @@ -361,6 +365,7 @@ def _init_confidence(self, confidence_cfg: Optional[DictConfig] = None): confidence_cfg.get('preserve_frame_confidence', False) | self.preserve_token_confidence ) self.exclude_blank_from_confidence = confidence_cfg.get('exclude_blank', True) + self.tdt_include_duration_confidence = confidence_cfg.get('tdt_include_duration', False) self.word_confidence_aggregation = confidence_cfg.get('aggregation', "min") # define aggregation functions @@ -368,8 +373,8 @@ def _init_confidence(self, confidence_cfg: Optional[DictConfig] = None): self._aggregate_confidence = self.confidence_aggregation_bank[self.word_confidence_aggregation] # Update preserve frame confidence - if self.preserve_frame_confidence is False: - if self.cfg.strategy in ['greedy', 'greedy_batch']: + if self.cfg.strategy in ['greedy', 'greedy_batch']: + if not self.preserve_frame_confidence: self.preserve_frame_confidence = self.cfg.greedy.get('preserve_frame_confidence', False) # OmegaConf.structured ensures that post_init check is always executed confidence_method_cfg = OmegaConf.structured(self.cfg.greedy).get('confidence_method_cfg', None) @@ -378,6 +383,8 @@ def _init_confidence(self, confidence_cfg: Optional[DictConfig] = None): if confidence_method_cfg is None else OmegaConf.structured(ConfidenceMethodConfig(**confidence_method_cfg)) ) + if not 
self.tdt_include_duration_confidence: + self.tdt_include_duration_confidence = self.cfg.greedy.get('tdt_include_duration_confidence', False) @abstractmethod def compute_confidence(self, hypotheses_list: List[Hypothesis]) -> List[Hypothesis]: diff --git a/nemo/collections/asr/parts/utils/rnnt_utils.py b/nemo/collections/asr/parts/utils/rnnt_utils.py index 1cd2d2ddc255..158fe3609286 100644 --- a/nemo/collections/asr/parts/utils/rnnt_utils.py +++ b/nemo/collections/asr/parts/utils/rnnt_utils.py @@ -115,7 +115,7 @@ def non_blank_frame_confidence(self) -> List[float]: non_blank_frame_confidence = [] # self.timestep can be a dict for RNNT timestep = self.timestep['timestep'] if isinstance(self.timestep, dict) else self.timestep - if len(self.timestep) != 0 and self.frame_confidence is not None: + if len(timestep) != 0 and self.frame_confidence is not None: if any(isinstance(i, list) for i in self.frame_confidence): # rnnt t_prev = -1 offset = 0 @@ -405,6 +405,7 @@ def __init__( float_dtype: Optional[torch.dtype] = None, store_alignments: bool = True, store_frame_confidence: bool = False, + with_duration_confidence: bool = False, ): """ @@ -422,6 +423,7 @@ def __init__( if batch_size <= 0: raise ValueError(f"batch_size must be > 0, got {batch_size}") self.with_frame_confidence = store_frame_confidence + self.with_duration_confidence = with_duration_confidence self.with_alignments = store_alignments self._max_length = init_length @@ -442,7 +444,11 @@ def __init__( self.frame_confidence = torch.zeros(0, device=device, dtype=float_dtype) if self.with_frame_confidence: # tensor to store frame confidence - self.frame_confidence = torch.zeros((batch_size, self._max_length), device=device, dtype=float_dtype) + self.frame_confidence = torch.zeros( + [batch_size, self._max_length, 2] if self.with_duration_confidence else [batch_size, self._max_length], + device=device, + dtype=float_dtype, + ) self._batch_indices = torch.arange(batch_size, device=device) def clear_(self): @@ 
-462,7 +468,7 @@ def _allocate_more(self): self.logits = torch.cat((self.logits, torch.zeros_like(self.logits)), dim=1) self.labels = torch.cat((self.labels, torch.zeros_like(self.labels)), dim=-1) if self.with_frame_confidence: - self.frame_confidence = torch.cat((self.frame_confidence, torch.zeros_like(self.frame_confidence)), dim=-1) + self.frame_confidence = torch.cat((self.frame_confidence, torch.zeros_like(self.frame_confidence)), dim=1) self._max_length *= 2 def add_results_( diff --git a/scripts/speech_recognition/confidence/benchmark_asr_confidence.py b/scripts/speech_recognition/confidence/benchmark_asr_confidence.py index 8a3c3f4e47c0..0c119b02ff7b 100644 --- a/scripts/speech_recognition/confidence/benchmark_asr_confidence.py +++ b/scripts/speech_recognition/confidence/benchmark_asr_confidence.py @@ -82,6 +82,7 @@ def get_experiment_params(cfg): String with the experiment name. """ blank = "no_blank" if cfg.exclude_blank else "blank" + duration = "duration" if cfg.tdt_include_duration else "no_duration" aggregation = cfg.aggregation method_name = cfg.method_cfg.name alpha = cfg.method_cfg.alpha @@ -91,15 +92,24 @@ def get_experiment_params(cfg): experiment_param_list = [ aggregation, str(cfg.exclude_blank), + str(cfg.tdt_include_duration), method_name, entropy_type, entropy_norm, str(alpha), ] - experiment_str = "-".join([aggregation, blank, method_name, entropy_type, entropy_norm, str(alpha)]) + experiment_str = "-".join([aggregation, blank, duration, method_name, entropy_type, entropy_norm, str(alpha)]) else: - experiment_param_list = [aggregation, str(cfg.exclude_blank), method_name, "-", "-", str(alpha)] - experiment_str = "-".join([aggregation, blank, method_name, str(alpha)]) + experiment_param_list = [ + aggregation, + str(cfg.exclude_blank), + str(cfg.tdt_include_duration), + method_name, + "-", + "-", + str(alpha), + ] + experiment_str = "-".join([aggregation, blank, duration, method_name, str(alpha)]) return experiment_param_list, 
experiment_str @@ -214,6 +224,7 @@ def main(cfg: ConfidenceBenchmarkingConfig): "model_type", "aggregation", "blank", + "duration", "method_name", "entropy_type", "entropy_norm", diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py index 85156bf9e2c5..018c9bcc4aa2 100644 --- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py +++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py @@ -265,7 +265,7 @@ def test_decoding_type_change(self, hybrid_asr_model): @pytest.mark.unit def test_GreedyRNNTInferConfig(self): - IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index'] + IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index', 'tdt_include_duration_confidence'] result = assert_dataclass_signature_match( greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyRNNTInferConfig, ignore_args=IGNORE_ARGS @@ -279,7 +279,7 @@ def test_GreedyRNNTInferConfig(self): @pytest.mark.unit def test_GreedyBatchedRNNTInferConfig(self): - IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index'] + IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index', 'tdt_include_duration_confidence'] result = assert_dataclass_signature_match( greedy_decode.GreedyBatchedRNNTInfer, greedy_decode.GreedyBatchedRNNTInferConfig, ignore_args=IGNORE_ARGS diff --git a/tests/collections/asr/test_asr_rnnt_encdec_model.py b/tests/collections/asr/test_asr_rnnt_encdec_model.py index d5ab0054ff87..a6e3714f20f5 100644 --- a/tests/collections/asr/test_asr_rnnt_encdec_model.py +++ b/tests/collections/asr/test_asr_rnnt_encdec_model.py @@ -387,7 +387,7 @@ def test_decoding_change(self, asr_model): @pytest.mark.unit def test_GreedyRNNTInferConfig(self): - IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index'] + IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index', 'tdt_include_duration_confidence'] result = assert_dataclass_signature_match( greedy_decode.GreedyRNNTInfer, 
greedy_decode.GreedyRNNTInferConfig, ignore_args=IGNORE_ARGS @@ -401,7 +401,7 @@ def test_GreedyRNNTInferConfig(self): @pytest.mark.unit def test_GreedyBatchedRNNTInferConfig(self): - IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index'] + IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index', 'tdt_include_duration_confidence'] result = assert_dataclass_signature_match( greedy_decode.GreedyBatchedRNNTInfer, greedy_decode.GreedyBatchedRNNTInferConfig, ignore_args=IGNORE_ARGS diff --git a/tutorials/asr/ASR_Confidence_Estimation.ipynb b/tutorials/asr/ASR_Confidence_Estimation.ipynb index ffcec8e16f39..eb8cd7b11688 100644 --- a/tutorials/asr/ASR_Confidence_Estimation.ipynb +++ b/tutorials/asr/ASR_Confidence_Estimation.ipynb @@ -466,6 +466,7 @@ " preserve_word_confidence=True,\n", " aggregation=\"prod\", # How to aggregate frame scores to token scores and token scores to word scores\n", " exclude_blank=False, # If true, only non-blank emissions contribute to confidence scores\n", + " tdt_include_duration=False, # If true, calculate duration confidence for the TDT models\n", " method_cfg=ConfidenceMethodConfig( # Config for per-frame scores calculation (before aggregation)\n", " name=\"max_prob\", # Or \"entropy\" (default), which usually works better\n", " entropy_type=\"gibbs\", # Used only for name == \"entropy\". 
Recommended: \"tsallis\" (default) or \"renyi\"\n", @@ -506,7 +507,7 @@ "outputs": [], "source": [ "current_test_set = test_sets[\"test_other\"]\n", - "transcriptions = model.transcribe(paths2audio_files=current_test_set.filepaths, batch_size=16, return_hypotheses=True, num_workers=4)\n", + "transcriptions = model.transcribe(audio=current_test_set.filepaths, batch_size=16, return_hypotheses=True, num_workers=4)\n", "if is_rnnt:\n", " transcriptions = transcriptions[0]" ] @@ -530,12 +531,25 @@ }, "outputs": [], "source": [ + "def round_confidence(confidence_number, ndigits=3):\n", + " if isinstance(confidence_number, float):\n", + " return round(confidence_number, ndigits)\n", + " elif len(confidence_number.size()) == 0: # torch.tensor with one element\n", + " return round(confidence_number.item(), ndigits)\n", + " elif len(confidence_number.size()) == 1: # torch.tensor with a list of elements\n", + " return [round(c.item(), ndigits) for c in confidence_number]\n", + " else:\n", + " raise RuntimeError(f\"Unexpected confidence_number: `{confidence_number}`\")\n", + "\n", + "\n", "tran = transcriptions[0]\n", "print(\n", " f\"\"\" Recognized text: `{tran.text}`\\n\n", - " Word confidence: {[round(c, 3) for c in tran.word_confidence]}\\n\n", - " Token confidence: {[round(c, 3) for c in tran.token_confidence]}\\n\n", - " Frame confidence: {[([round(cc, 3) for cc in c] if is_rnnt else round(c, 3)) for c in tran.frame_confidence]}\"\"\"\n", + " Word confidence: {[round_confidence(c) for c in tran.word_confidence]}\\n\n", + " Token confidence: {[round_confidence(c) for c in tran.token_confidence]}\\n\n", + " Frame confidence: {\n", + " [([round_confidence(cc) for cc in c] if is_rnnt else round_confidence(c)) for c in tran.frame_confidence]\n", + " }\"\"\"\n", ")" ] }, @@ -726,7 +740,7 @@ " else CTCDecodingConfig(confidence_cfg=confidence_cfg)\n", ")\n", "\n", - "transcriptions = model.transcribe(paths2audio_files=current_test_set.filepaths, batch_size=16,
return_hypotheses=True, num_workers=4)\n", + "transcriptions = model.transcribe(audio=current_test_set.filepaths, batch_size=16, return_hypotheses=True, num_workers=4)\n", "if is_rnnt:\n", " transcriptions = transcriptions[0]" ] @@ -1067,7 +1081,7 @@ " else CTCDecodingConfig(confidence_cfg=confidence_cfg)\n", ")\n", "\n", - "transcriptions = model.transcribe(paths2audio_files=current_test_set.filepaths, batch_size=16, return_hypotheses=True, num_workers=4)\n", + "transcriptions = model.transcribe(audio=current_test_set.filepaths, batch_size=16, return_hypotheses=True, num_workers=4)\n", "if is_rnnt:\n", " transcriptions = transcriptions[0]\n", "\n", @@ -1238,7 +1252,7 @@ ")\n", "\n", "noise_transcriptions = model.transcribe(\n", - " paths2audio_files=noise_data.filepaths, batch_size=4, return_hypotheses=True, num_workers=4\n", + " audio=noise_data.filepaths, batch_size=4, return_hypotheses=True, num_workers=4\n", ")\n", "if is_rnnt:\n", " noise_transcriptions = noise_transcriptions[0]" @@ -1424,7 +1438,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.10.12" } }, "nbformat": 4,