From 85d8756a8121bb4bb021c4d2b89bc5f290e7c571 Mon Sep 17 00:00:00 2001 From: yaoyu-33 <54727607+yaoyu-33@users.noreply.github.com> Date: Wed, 31 Jan 2024 12:04:14 -0800 Subject: [PATCH 1/4] Add Bert HF checkpoint converter (#8088) * Add Bert HF checkpoint converter Signed-off-by: yaoyu-33 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Reformat Signed-off-by: yaoyu-33 * Add BERT ONNX export * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add NeMo BERT to HF BERT script * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Clean code Signed-off-by: yaoyu-33 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update argument names Signed-off-by: yaoyu-33 * Update build_transformer_config in Bert Signed-off-by: yaoyu-33 --------- Signed-off-by: yaoyu-33 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Bobby Chen --- .../conf/megatron_bert_config.yaml | 8 +- .../language_modeling/megatron/bert_model.py | 10 + .../language_modeling/megatron_bert_model.py | 63 ++++ .../modules/common/megatron/transformer.py | 27 +- .../convert_bert_hf_to_nemo.py | 289 ++++++++++++++++++ .../convert_bert_nemo_to_hf.py | 269 ++++++++++++++++ .../export_nemo_bert_to_onnx.py | 83 +++++ 7 files changed, 745 insertions(+), 4 deletions(-) create mode 100644 scripts/nlp_language_modeling/convert_bert_hf_to_nemo.py create mode 100644 scripts/nlp_language_modeling/convert_bert_nemo_to_hf.py create mode 100644 scripts/nlp_language_modeling/export_nemo_bert_to_onnx.py diff --git a/examples/nlp/language_modeling/conf/megatron_bert_config.yaml b/examples/nlp/language_modeling/conf/megatron_bert_config.yaml index d388fe35b963..b3e3912fffd4 100644 --- a/examples/nlp/language_modeling/conf/megatron_bert_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_bert_config.yaml @@ -2,7 +2,7 @@ name: megatron_bert restore_from_path: null # used when starting from a .nemo file trainer: - devices: 2 + devices: 1 num_nodes: 1 accelerator: gpu precision: 16 @@ -56,15 +56,19 @@ model: hidden_size: 768 ffn_hidden_size: 3072 # Transformer FFN hidden size. Usually 4 * hidden_size. num_attention_heads: 12 + skip_head: False + transformer_block_type: post_ln init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') hidden_dropout: 0.1 # Dropout probability for hidden state transformer. kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: layernorm layernorm_epsilon: 1e-5 make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. pre_process: True # add embedding post_process: True # add pooler bert_binary_head: True # BERT binary head + megatron_legacy: False tokenizer: library: 'megatron' @@ -128,7 +132,7 @@ model: # - /raid/data/pile/my-gpt3_00_text_document # - .5 # - /raid/data/pile/my-gpt3_01_text_document - data_prefix: ??? + data_prefix: [1.0, /path/to/data] index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix data_impl: mmap splits_string: 900,50,50 diff --git a/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py index 22cfd7fb8efa..7e928a4e893b 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py @@ -65,6 +65,9 @@ def bert_extended_attention_mask(attention_mask): # [b, 1, s, s] extended_attention_mask = attention_mask_bss.unsqueeze(1) + # HF Masking is equivalent to the one below + # extended_attention_mask = (attention_mask.unsqueeze(1) * torch.ones_like(attention_mask).unsqueeze(2)).unsqueeze(1) + # Convert attention mask to binary: extended_attention_mask = extended_attention_mask < 0.5 @@ -182,12 +185,15 @@ def __init__( activations_checkpoint_num_layers=1, activations_checkpoint_layers_per_pipeline=None, layernorm_epsilon=1e-5, + normalization='layernorm', + transformer_block_type='pre_ln', masked_softmax_fusion=False, bias_gelu_fusion=True, bias_dropout_add_fusion=True, openai_gelu=False, onnx_safe=False, add_binary_head=True, + skip_head=False, megatron_legacy=False, sequence_parallel=False, position_embedding_type='learned_absolute', @@ -229,6 +235,8 @@ def __init__( activations_checkpoint_num_layers=activations_checkpoint_num_layers, activations_checkpoint_layers_per_pipeline=activations_checkpoint_layers_per_pipeline, layernorm_epsilon=layernorm_epsilon, + normalization=normalization, + transformer_block_type=transformer_block_type, masked_softmax_fusion=masked_softmax_fusion, bias_activation_fusion=bias_gelu_fusion, bias_dropout_add_fusion=bias_dropout_add_fusion, @@ -242,6 +250,8 @@ def __init__( init_method=init_method_normal(init_method_std), vocab_size=vocab_size, hidden_size=hidden_size ) + if skip_head: + self.post_process = False if self.post_process: self.lm_head = BertLMHead( config, diff --git a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py index e4ae0f87d353..bef13367eb10 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py @@ -184,10 +184,13 @@ def model_provider_func(self, pre_process, post_process): ), layernorm_epsilon=cfg.get('layernorm_epsilon', 1e-5), masked_softmax_fusion=cfg.get('masked_softmax_fusion', True), + normalization=cfg.get('normalization', 'layernorm'), + transformer_block_type=cfg.get('transformer_block_type', 'pre_ln'), bias_gelu_fusion=cfg.get('bias_gelu_fusion', True), bias_dropout_add_fusion=cfg.get("bias_dropout_add_fusion", True), onnx_safe=cfg.get('onnx_safe', False), add_binary_head=cfg.bert_binary_head, + skip_head=cfg.get('skip_head', False), megatron_legacy=cfg.get('megatron_legacy', False), position_embedding_type=self.cfg.get("position_embedding_type", "learned_absolute"), ) @@ -1034,5 +1037,65 @@ def build_transformer_config(self) -> TransformerConfig: """ activation = self.cfg.get('activation', 'gelu') assert activation == 'gelu', "Only gelu activation is support for BERT at the moment." + + normalization = self.cfg.get('normalization', 'layernorm') + + layernorm_zero_centered_gamma = self.cfg.get('normalization', 'layernorm') == 'layernorm1p' + if normalization == 'layernorm': + normalization = 'LayerNorm' + elif normalization == 'rmsnorm': + normalization = 'RMSNorm' + elif normalization == 'layernorm1p': + normalization = 'LayerNorm' + layernorm_zero_centered_gamma = True + else: + logging.warning( + f"The normalization type: {normalization} might not be supported in megatron core." + f"Supported types are LayerNorm and RMSNorm." + ) + + # any configs that are not in the nemo model config will be added here + model_specific_configs = { + 'layernorm_zero_centered_gamma': layernorm_zero_centered_gamma, + 'normalization': normalization, + } + transformer_config = super().build_transformer_config() + + for key, value in model_specific_configs.items(): + setattr(transformer_config, key, value) + + # pass mcore customization configs directly to mcore + mcore_customization_config_dict = self.cfg.get('mcore_customization_config', {}) + for key, value in mcore_customization_config_dict.items(): + setattr(transformer_config, key, value) + return transformer_config + + +class MegatronBertTextEmbeddingModel(MegatronBertModel): + """ + Megatron Bert Text Embedding. + Model returns [batch, hidden] shape + """ + + def average_pool(self, last_hidden_states, attention_mask): + last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + + def forward( + self, + input_ids, + attention_mask, + token_type_ids, + lm_labels=None, + checkpoint_activations_all_layers=None, + model=None, + ): + outputs = super().forward( + input_ids, attention_mask, token_type_ids, lm_labels, checkpoint_activations_all_layers, model + ) + embeddings = self.average_pool(outputs[0], attention_mask) + embeddings = F.normalize(embeddings, p=2, dim=1) + + return embeddings diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index ca8c0ecafefd..9e9c7b526782 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -625,7 +625,6 @@ def forward( ) output = bias_dropout_add_func(mlp_output, mlp_bias, residual, self.hidden_dropout) - # print(f"Layer: {self.layer_number} MLP + Dropout + Residual checksum {output.sum()}") if self.transformer_block_type == 'post_ln': output = self.post_attention_layernorm(output) @@ -1158,6 +1157,27 @@ def build_layer(layer_number): offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)]) + if self.pre_process and self.transformer_block_type == 'post_ln': + # Final layer norm before output. + if normalization == 'layernorm': + self.initial_layernorm = get_layer_norm( + hidden_size, layernorm_epsilon, persist_layer_norm, sequence_parallel=config.sequence_parallel + ) + + elif normalization == 'layernorm1p': + self.initial_layernorm = LayerNorm1P( + hidden_size, layernorm_epsilon, sequence_parallel_enabled=config.sequence_parallel + ) + elif normalization == 'low_precision_layernorm': + self.initial_layernorm = LPLayerNorm(hidden_size, layernorm_epsilon) + else: + self.initial_layernorm = MixedFusedRMSNorm(hidden_size, layernorm_epsilon) + # for architectures such as MPT, there is no bias term even on the layernorms + # this code allows us to remove the bias terms from the layernorm module + # so that we can support MPT. However, certain apex-based LNs don't support + # removing bias, so we also have to check for that + if not bias and normalization not in ['layernorm', 'layernorm1p']: + remove_bias_from_layernorm(self.initial_layernorm) if self.post_process and self.transformer_block_type != 'post_ln': # Final layer norm before output. @@ -1435,7 +1455,10 @@ def forward( 'get_key_value does not work with ' 'activation checkpointing' ) - if not self.pre_process: + if self.pre_process: + if self.transformer_block_type == 'post_ln': + hidden_states = self.initial_layernorm(hidden_states) + else: # See set_input_tensor() hidden_states = self.input_tensor diff --git a/scripts/nlp_language_modeling/convert_bert_hf_to_nemo.py b/scripts/nlp_language_modeling/convert_bert_hf_to_nemo.py new file mode 100644 index 000000000000..cc9483b68c8a --- /dev/null +++ b/scripts/nlp_language_modeling/convert_bert_hf_to_nemo.py @@ -0,0 +1,289 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example to run this conversion script: +``` + python convert_bert_hf_to_nemo.py \ + --input_name_or_path "thenlper/gte-large" \ + --output_path /path/to/output/nemo/file.nemo \ + --precision 32 +``` +""" + +import os +from argparse import ArgumentParser + +import torch +import torch.nn.functional as F +from omegaconf import OmegaConf +from transformers import AutoModel, AutoTokenizer + +from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.utils import logging + + +def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + + +def create_rename_keys(num_hidden_layers): + rename_keys = [] + for i in range(num_hidden_layers): + # encoder layers: attention mechanism, 2 feedforward neural networks, and 2 layernorms + rename_keys.extend( + [ + ( + f"encoder.layer.{i}.attention.self.query.weight", + f"model.language_model.encoder.layers.{i}.self_attention.query.weight", + ), + ( + f"encoder.layer.{i}.attention.self.query.bias", + f"model.language_model.encoder.layers.{i}.self_attention.query.bias", + ), + ( + f"encoder.layer.{i}.attention.self.key.weight", + f"model.language_model.encoder.layers.{i}.self_attention.key.weight", + ), + ( + f"encoder.layer.{i}.attention.self.key.bias", + f"model.language_model.encoder.layers.{i}.self_attention.key.bias", + ), + ( + f"encoder.layer.{i}.attention.self.value.weight", + f"model.language_model.encoder.layers.{i}.self_attention.value.weight", + ), + ( + f"encoder.layer.{i}.attention.self.value.bias", + f"model.language_model.encoder.layers.{i}.self_attention.value.bias", + ), + ( + f"encoder.layer.{i}.attention.output.dense.weight", + f"model.language_model.encoder.layers.{i}.self_attention.dense.weight", + ), + ( + f"encoder.layer.{i}.attention.output.dense.bias", + f"model.language_model.encoder.layers.{i}.self_attention.dense.bias", + ), + ( + f"encoder.layer.{i}.attention.output.LayerNorm.weight", + f"model.language_model.encoder.layers.{i}.input_layernorm.weight", + ), + ( + f"encoder.layer.{i}.attention.output.LayerNorm.bias", + f"model.language_model.encoder.layers.{i}.input_layernorm.bias", + ), + ( + f"encoder.layer.{i}.intermediate.dense.weight", + f"model.language_model.encoder.layers.{i}.mlp.dense_h_to_4h.weight", + ), + ( + f"encoder.layer.{i}.intermediate.dense.bias", + f"model.language_model.encoder.layers.{i}.mlp.dense_h_to_4h.bias", + ), + ( + f"encoder.layer.{i}.output.dense.weight", + f"model.language_model.encoder.layers.{i}.mlp.dense_4h_to_h.weight", + ), + ( + f"encoder.layer.{i}.output.dense.bias", + f"model.language_model.encoder.layers.{i}.mlp.dense_4h_to_h.bias", + ), + ( + f"encoder.layer.{i}.output.LayerNorm.weight", + f"model.language_model.encoder.layers.{i}.post_attention_layernorm.weight", + ), + ( + f"encoder.layer.{i}.output.LayerNorm.bias", + f"model.language_model.encoder.layers.{i}.post_attention_layernorm.bias", + ), + ] + ) + + # Non-layer dependent keys + rename_keys.extend( + [ + ("embeddings.word_embeddings.weight", "model.language_model.embedding.word_embeddings.weight"), + ("embeddings.position_embeddings.weight", "model.language_model.embedding.position_embeddings.weight"), + ("embeddings.token_type_embeddings.weight", "model.language_model.embedding.tokentype_embeddings.weight"), + ("embeddings.LayerNorm.weight", "model.language_model.encoder.initial_layernorm.weight"), + ("embeddings.LayerNorm.bias", "model.language_model.encoder.initial_layernorm.bias"), + ("pooler.dense.weight", "model.language_model.pooler.dense.weight"), + ("pooler.dense.bias", "model.language_model.pooler.dense.bias"), + ] + ) + + return rename_keys + + +def rename_model_keys(model_state_dict, rename_keys): + """ + Rename keys in the model's state dictionary based on the provided mappings. + + Parameters: + model_state_dict (dict): The state dictionary of the model. + rename_keys (list): A list of tuples with the mapping (old_key, new_key). + + Returns: + dict: A new state dictionary with updated key names. + """ + + # Create a new state dictionary with updated key names + new_state_dict = {} + + # Track keys from the original state dict to ensure all are processed + remaining_keys = set(model_state_dict.keys()) + + # Iterate over the rename mappings + for old_key, new_key in rename_keys: + if old_key in model_state_dict: + # Rename the key and remove it from the tracking set + new_state_dict[new_key] = model_state_dict[old_key] + remaining_keys.remove(old_key) + else: + print(f"Warning: Key '{old_key}' not found in the model state dictionary.") + + # Check if any keys were not converted from old to new + for old_key in remaining_keys: + print(f"Warning: Key '{old_key}' was not converted.") + + return new_state_dict + + +def adjust_tensor_shapes(model, nemo_state_dict): + """ + Adapt tensor shapes in the state dictionary to ensure compatibility with a different model structure. + + Parameters: + nemo_state_dict (dict): The state dictionary of the model. + + Returns: + dict: The updated state dictionary with modified tensor shapes for compatibility. + """ + + # Note: For 'key' and 'value' weights and biases, NeMo uses a consolidated tensor 'query_key_value'. + for key_ in list(nemo_state_dict.keys()): + if "self_attention.query" in key_: + key_q = key_ + key_k = key_.replace('self_attention.query', 'self_attention.key') + key_v = key_.replace('self_attention.query', 'self_attention.value') + key_new = key_.replace('self_attention.query', 'self_attention.query_key_value') + value_new = torch.concat((nemo_state_dict[key_q], nemo_state_dict[key_k], nemo_state_dict[key_v]), dim=0) + nemo_state_dict[key_new] = value_new + del nemo_state_dict[key_q], nemo_state_dict[key_k], nemo_state_dict[key_v] + + # Padding to new vocab size + original_embedding = nemo_state_dict['model.language_model.embedding.word_embeddings.weight'] + vocab_size = original_embedding.size(0) + if model.padded_vocab_size > vocab_size: + zeros_to_add = torch.zeros( + model.padded_vocab_size - vocab_size, + original_embedding.size(1), + dtype=original_embedding.dtype, + device=original_embedding.device, + ) + # Concatenate the two tensors along rows + padded_embedding = torch.cat([original_embedding, zeros_to_add], dim=0) + nemo_state_dict['model.language_model.embedding.word_embeddings.weight'] = padded_embedding + + return nemo_state_dict + + +def adjust_nemo_config(model_config, ref_config): + model_config.tokenizer["type"] = "intfloat/e5-large-unsupervised" # ref_config["_input_name_or_path"] + model_config["num_layers"] = ref_config["num_hidden_layers"] + model_config["hidden_size"] = ref_config["hidden_size"] + model_config["ffn_hidden_size"] = ref_config["intermediate_size"] + model_config["num_attention_heads"] = ref_config["num_attention_heads"] + model_config["layernorm_epsilon"] = ref_config["layer_norm_eps"] + model_config["normalization"] = "layernorm" + model_config["transformer_block_type"] = "post_ln" + model_config["apply_query_key_layer_scaling"] = False + model_config["skip_head"] = True + model_config["megatron_legacy"] = True + return model_config + + +def get_args(): + parser = ArgumentParser() + parser.add_argument("--input_name_or_path", type=str, default="thenlper/gte-large") + parser.add_argument( + "--hparams_file", + type=str, + default=f"{os.path.dirname(__file__)}/../../examples/nlp/language_modeling/conf/megatron_bert_config.yaml", + required=False, + help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml", + ) + parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.") + parser.add_argument( + "--precision", type=str, default="32", choices=["bf16", "32"], help="Precision for checkpoint weights saved" + ) + + args = parser.parse_args() + return args + + +def convert(args): + logging.info(f"Loading checkpoint from HF: `{args.input_name_or_path}`") + hf_tokenizer = AutoTokenizer.from_pretrained(args.input_name_or_path) + hf_model = AutoModel.from_pretrained(args.input_name_or_path) + + nemo_config = OmegaConf.load(args.hparams_file) + nemo_config.model = adjust_nemo_config(nemo_config.model, hf_model.config.to_dict()) + + nemo_config.trainer["precision"] = args.precision + trainer = MegatronTrainerBuilder(nemo_config).create_trainer() + model = MegatronBertModel(nemo_config.model, trainer) + + old_state_dict = hf_model.state_dict() + rename_keys = create_rename_keys(nemo_config.model.num_layers) + new_state_dict = rename_model_keys(model_state_dict=old_state_dict, rename_keys=rename_keys) + nemo_state_dict = adjust_tensor_shapes(model, new_state_dict) + model.load_state_dict(nemo_state_dict, strict=True) + + logging.info(f'=' * 50) + # Verifications + input_texts = [ + 'query: how much protein should a female eat', + 'query: summit define', + "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.", + "passage: Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.", + ] + + # Tokenize the input texts + batch_dict = hf_tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt') + batch_dict_cuda = {k: v.cuda() for k, v in batch_dict.items()} + hf_model = hf_model.cuda().eval() + model = model.eval() + with torch.no_grad(): + hf_outputs = hf_model(**batch_dict_cuda) + embeddings_hf = average_pool(hf_outputs.last_hidden_state, batch_dict_cuda['attention_mask']) + embeddings_hf = F.normalize(embeddings_hf, p=2, dim=1) + + outputs = model(**batch_dict_cuda) + embeddings = average_pool(outputs[0], batch_dict_cuda['attention_mask']) + embeddings = F.normalize(embeddings, p=2, dim=1) + # Print difference between two embeddings + print("Difference between reference embedding and converted embedding results:") + print(embeddings - embeddings_hf) + + model.save_to(args.output_path) + logging.info(f'NeMo model saved to: {args.output_path}') + + +if __name__ == '__main__': + args = get_args() + convert(args) diff --git a/scripts/nlp_language_modeling/convert_bert_nemo_to_hf.py b/scripts/nlp_language_modeling/convert_bert_nemo_to_hf.py new file mode 100644 index 000000000000..e970ea29fca2 --- /dev/null +++ b/scripts/nlp_language_modeling/convert_bert_nemo_to_hf.py @@ -0,0 +1,269 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example to run this conversion script: +``` + python convert_bert_hf_to_nemo.py \ + --input_name_or_path /path/to/input/nemo/file.nemo \ + --output_path /path/to/output/huggingface/file \ + --precision 32 +``` +""" + +from argparse import ArgumentParser + +import torch +import torch.nn.functional as F +from pytorch_lightning import Trainer +from transformers import AutoTokenizer, BertConfig, BertModel + +from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy +from nemo.utils import logging + + +def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + + +def create_rename_keys(num_hidden_layers): + rename_keys = [] + for i in range(num_hidden_layers): + # encoder layers: attention mechanism, 2 feedforward neural networks, and 2 layernorms + rename_keys.extend( + [ + ( + f"encoder.layer.{i}.attention.self.query.weight", + f"model.language_model.encoder.layers.{i}.self_attention.query.weight", + ), + ( + f"encoder.layer.{i}.attention.self.query.bias", + f"model.language_model.encoder.layers.{i}.self_attention.query.bias", + ), + ( + f"encoder.layer.{i}.attention.self.key.weight", + f"model.language_model.encoder.layers.{i}.self_attention.key.weight", + ), + ( + f"encoder.layer.{i}.attention.self.key.bias", + f"model.language_model.encoder.layers.{i}.self_attention.key.bias", + ), + ( + f"encoder.layer.{i}.attention.self.value.weight", + f"model.language_model.encoder.layers.{i}.self_attention.value.weight", + ), + ( + f"encoder.layer.{i}.attention.self.value.bias", + f"model.language_model.encoder.layers.{i}.self_attention.value.bias", + ), + ( + f"encoder.layer.{i}.attention.output.dense.weight", + f"model.language_model.encoder.layers.{i}.self_attention.dense.weight", + ), + ( + f"encoder.layer.{i}.attention.output.dense.bias", + f"model.language_model.encoder.layers.{i}.self_attention.dense.bias", + ), + ( + f"encoder.layer.{i}.attention.output.LayerNorm.weight", + f"model.language_model.encoder.layers.{i}.input_layernorm.weight", + ), + ( + f"encoder.layer.{i}.attention.output.LayerNorm.bias", + f"model.language_model.encoder.layers.{i}.input_layernorm.bias", + ), + ( + f"encoder.layer.{i}.intermediate.dense.weight", + f"model.language_model.encoder.layers.{i}.mlp.dense_h_to_4h.weight", + ), + ( + f"encoder.layer.{i}.intermediate.dense.bias", + f"model.language_model.encoder.layers.{i}.mlp.dense_h_to_4h.bias", + ), + ( + f"encoder.layer.{i}.output.dense.weight", + f"model.language_model.encoder.layers.{i}.mlp.dense_4h_to_h.weight", + ), + ( + f"encoder.layer.{i}.output.dense.bias", + f"model.language_model.encoder.layers.{i}.mlp.dense_4h_to_h.bias", + ), + ( + f"encoder.layer.{i}.output.LayerNorm.weight", + f"model.language_model.encoder.layers.{i}.post_attention_layernorm.weight", + ), + ( + f"encoder.layer.{i}.output.LayerNorm.bias", + f"model.language_model.encoder.layers.{i}.post_attention_layernorm.bias", + ), + ] + ) + + # Non-layer dependent keys + rename_keys.extend( + [ + ("embeddings.word_embeddings.weight", "model.language_model.embedding.word_embeddings.weight"), + ("embeddings.position_embeddings.weight", "model.language_model.embedding.position_embeddings.weight"), + ("embeddings.token_type_embeddings.weight", "model.language_model.embedding.tokentype_embeddings.weight"), + ("embeddings.LayerNorm.weight", "model.language_model.encoder.initial_layernorm.weight"), + ("embeddings.LayerNorm.bias", "model.language_model.encoder.initial_layernorm.bias"), + ("pooler.dense.weight", "model.language_model.pooler.dense.weight"), + ("pooler.dense.bias", "model.language_model.pooler.dense.bias"), + ] + ) + + return rename_keys + + +def rename_model_keys(model_state_dict, rename_keys): + """ + Rename keys in the model's state dictionary based on the provided mappings. + + Parameters: + model_state_dict (dict): The state dictionary of the model. + rename_keys (list): A list of tuples with the mapping (new_key, old_key). + + Returns: + dict: A new state dictionary with updated key names. + """ + + # Create a new state dictionary with updated key names + new_state_dict = {} + + # Track keys from the original state dict to ensure all are processed + remaining_keys = set(model_state_dict.keys()) + + # Iterate over the rename mappings + for new_key, old_key in rename_keys: + if old_key in model_state_dict: + # Rename the key and remove it from the tracking set + new_state_dict[new_key] = model_state_dict[old_key] + remaining_keys.remove(old_key) + else: + print(f"Warning: Key '{old_key}' not found in the model state dictionary.") + + # Check if any keys were not converted from old to new + for old_key in remaining_keys: + print(f"Warning: Key '{old_key}' was not converted.") + + return new_state_dict + + +def adjust_tensor_shapes(model_state_dict): + """ + Adapt tensor shapes in the state dictionary to ensure compatibility with a different model structure. + + Parameters: + nemo_state_dict (dict): The state dictionary of the model. + + Returns: + dict: The updated state dictionary with modified tensor shapes for compatibility. + """ + + # Note: For 'key' and 'value' weights and biases, NeMo uses a consolidated tensor 'query_key_value'. + for key_ in list(model_state_dict.keys()): + if "self_attention.query_key_value" in key_: + key_q = key_.replace('self_attention.query_key_value', 'self_attention.query') + key_k = key_.replace('self_attention.query_key_value', 'self_attention.key') + key_v = key_.replace('self_attention.query_key_value', 'self_attention.value') + local_dim = model_state_dict[key_].shape[0] // 3 + q, k, v = model_state_dict[key_].split(local_dim) + model_state_dict[key_q] = q + model_state_dict[key_k] = k + model_state_dict[key_v] = v + del model_state_dict[key_] + + return model_state_dict + + +def convert_config(ref_config, hf_state_dict): + vocab_size = hf_state_dict['embeddings.word_embeddings.weight'].shape[0] + new_config = { + "vocab_size": vocab_size, + "num_hidden_layers": ref_config["num_layers"], + "hidden_size": ref_config["hidden_size"], + "intermediate_size": ref_config["ffn_hidden_size"], + "num_attention_heads": ref_config["num_attention_heads"], + "layer_norm_eps": ref_config["layernorm_epsilon"], + "max_position_embeddings": ref_config["max_position_embeddings"], + } + hf_config = BertConfig(**new_config) + return hf_config + + +def get_args(): + parser = ArgumentParser() + parser.add_argument( + "--input_name_or_path", type=str, required=True, help="Path to .nemo file", + ) + parser.add_argument( + "--output_path", type=str, required=True, help="Output HF model path", + ) + + args = parser.parse_args() + return args + + +def convert(args): + logging.info(f"Loading checkpoint from: `{args.input_name_or_path}`") + dummy_trainer = Trainer(devices=1, accelerator='cpu', strategy=NLPDDPStrategy()) + nemo_model = MegatronBertModel.restore_from(args.input_name_or_path, trainer=dummy_trainer) + nemo_config = nemo_model.cfg + + old_state_dict = nemo_model.state_dict() + rename_keys = create_rename_keys(nemo_config.num_layers) + new_state_dict = adjust_tensor_shapes(old_state_dict) + hf_state_dict = rename_model_keys(model_state_dict=new_state_dict, rename_keys=rename_keys) + + hf_config = convert_config(nemo_config, hf_state_dict) + hf_model = BertModel(hf_config) + + hf_model.load_state_dict(hf_state_dict, strict=True) + + logging.info(f'=' * 50) + # Verifications + input_texts = [ + 'query: how much protein should a female eat', + 'query: summit define', + "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.", + "passage: Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.", + ] + + # Tokenize the input texts + hf_tokenizer = AutoTokenizer.from_pretrained(nemo_config.tokenizer["type"]) + batch_dict = hf_tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt') + batch_dict_cuda = {k: v.cuda() for k, v in batch_dict.items()} + hf_model = hf_model.cuda().eval() + nemo_model = nemo_model.eval() + with torch.no_grad(): + hf_outputs = hf_model(**batch_dict_cuda) + embeddings_hf = average_pool(hf_outputs.last_hidden_state, batch_dict_cuda['attention_mask']) + embeddings_hf = F.normalize(embeddings_hf, p=2, dim=1) + + outputs = nemo_model(**batch_dict_cuda) + embeddings = average_pool(outputs[0], batch_dict_cuda['attention_mask']) + embeddings = F.normalize(embeddings, p=2, dim=1) + # Print difference between two embeddings + print("Difference between reference embedding and converted embedding results:") + print(embeddings - embeddings_hf) + + hf_model.save_pretrained(args.output_path) + logging.info(f'Full HF model model saved to: {args.output_path}') + + +if __name__ == '__main__': + args = get_args() + convert(args) diff --git a/scripts/nlp_language_modeling/export_nemo_bert_to_onnx.py b/scripts/nlp_language_modeling/export_nemo_bert_to_onnx.py new file mode 100644 index 000000000000..c6b3f351cc07 --- /dev/null +++ b/scripts/nlp_language_modeling/export_nemo_bert_to_onnx.py @@ -0,0 +1,83 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +from argparse import ArgumentParser + +import torch +from omegaconf import OmegaConf + +from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertTextEmbeddingModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.utils import logging + + +def get_args(): + parser = ArgumentParser() + parser.add_argument("--nemo_path", type=str, required=True) + parser.add_argument( + "--hparams_file", + type=str, + default=f"{os.path.dirname(__file__)}/../../examples/nlp/language_modeling/conf/megatron_bert_config.yaml", + required=False, + help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml", + ) + parser.add_argument( + "--onnx_path", type=str, default="bert.onnx", required=False, help="Path to output .nemo file." + ) + parser.add_argument( + "--precision", type=str, default="32", choices=["bf16", "32"], help="Precision for checkpoint weights saved" + ) + + args = parser.parse_args() + return args + + +def export(args): + nemo_config = OmegaConf.load(args.hparams_file) + nemo_config.trainer["precision"] = args.precision + + trainer = MegatronTrainerBuilder(nemo_config).create_trainer() + model = MegatronBertTextEmbeddingModel.restore_from(args.nemo_path, trainer=trainer) + + hf_tokenizer = model.tokenizer.tokenizer + + logging.info(f'=' * 50) + # Verifications + input_texts = [ + 'query: how much protein should a female eat', + 'query: summit define', + "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.", + "passage: Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.", + ] + + # Tokenize the input texts + batch_dict = hf_tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt') + batch_dict_cuda = {k: v.cuda() for k, v in batch_dict.items()} + model = model.eval() + + input_names = ["input_ids", "attention_mask", "token_type_ids"] + output_names = ["outputs"] + export_input = tuple([batch_dict_cuda[name] for name in input_names]) + + torch.onnx.export( + model, export_input, args.onnx_path, verbose=False, input_names=input_names, output_names=output_names, + ) + logging.info(f'NeMo model saved to: {args.onnx_path}') + + +if __name__ == '__main__': + args = get_args() + export(args) From f6e64859174920e6a7fb26c4308cbe4c1d44206e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 31 Jan 2024 18:33:47 -0500 Subject: [PATCH 2/4] Pin lhotse version to 1.19.2 (#8291) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- requirements/requirements_asr.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements_asr.txt b/requirements/requirements_asr.txt index cb6681a08243..671f06f3dcca 100644 --- a/requirements/requirements_asr.txt +++ b/requirements/requirements_asr.txt @@ -5,7 +5,7 @@ ipywidgets jiwer kaldi-python-io kaldiio -lhotse>=1.19.2 +lhotse==1.19.2 librosa>=0.10.0 marshmallow matplotlib From a4f1f1cbc80e7e37f743200af9ddcf4153b2ca28 Mon Sep 17 00:00:00 2001 From: Vladimir Bataev Date: Fri, 2 Feb 2024 01:07:50 +0400 Subject: [PATCH 3/4] Fix documentation build (#8308) * Fix docs build Signed-off-by: Vladimir Bataev * Clean up Signed-off-by: Vladimir Bataev * Fix mock imports Signed-off-by: Vladimir Bataev * Add comment Signed-off-by: Vladimir Bataev --------- Signed-off-by: Vladimir Bataev --- docs/source/conf.py | 2 ++ docs/update_docs_docker.sh | 2 +- nemo/core/neural_types/elements.py | 4 ++++ nemo/core/neural_types/neural_type.py | 18 +++++++++++++++--- 4 files changed, 22 insertions(+), 4 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 586f6cf47675..0596b15e3de5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -64,6 +64,8 @@ 'PIL', 'boto3', 'taming', + 'cytoolz', # for adapters + 'megatron', # for nlp ] _skipped_autodoc_mock_imports = ['wrapt', 'numpy'] diff --git a/docs/update_docs_docker.sh b/docs/update_docs_docker.sh index 19f9b46d3ef1..653894630a3c 100755 --- a/docs/update_docs_docker.sh +++ b/docs/update_docs_docker.sh @@ -1,5 +1,5 @@ cd ../ -docker run --rm -v $PWD:/workspace python:3.8 /bin/bash -c "cd /workspace && \ +docker run --rm -v $PWD:/workspace python:3.10 /bin/bash -c "cd /workspace && \ pip install -r requirements/requirements_docs.txt && cd docs/ && rm -rf build && make clean && make html && make html" echo "To start web server just run in docs directory:" echo "python3 -m http.server 8000 --directory ./build/html/" diff --git a/nemo/core/neural_types/elements.py b/nemo/core/neural_types/elements.py index 5d5697d80e46..7e95acebd91f 100644 --- a/nemo/core/neural_types/elements.py +++ b/nemo/core/neural_types/elements.py @@ -96,6 +96,10 @@ def fields(self): return None def compare(self, second) -> NeuralTypeComparisonResult: + if torch.jit.is_scripting(): + # Neural types for TorchScript are suppressed + # This is a stub to make TorchScript happy + return NeuralTypeComparisonResult.SAME # First, check general compatibility first_t = type(self) second_t = type(second) diff --git a/nemo/core/neural_types/neural_type.py b/nemo/core/neural_types/neural_type.py index 49345d6d234f..d00ba72df043 100644 --- a/nemo/core/neural_types/neural_type.py +++ b/nemo/core/neural_types/neural_type.py @@ -51,8 +51,21 @@ def __str__(self): else: return f"axes: None; elements_type: {self.elements_type.__class__.__name__}" + def __init__(self, axes: Optional[Any] = None, elements_type: Optional[Any] = None, optional: bool = False): + """ + Args: + axes: a tuple of AxisTypes objects representing the semantics of what varying each axis means + elements_type: None or ElementType; we need Any annotation here to avoid problems with TorchScript (it is checked in _init_internal) + optional: If input to the port of this type can be optional (False by default). + """ + if not torch.jit.is_scripting(): + self._init_internal(axes=axes, elements_type=elements_type, optional=optional) + @torch.jit.unused - def __init__(self, axes: Optional[Any] = None, elements_type: Optional[ElementType] = None, optional=False): + def _init_internal( + self, axes: Optional[Any] = None, elements_type: Optional[ElementType] = None, optional: bool = False + ): + """Internals of __init__, separated to make TorchScript and autodoc work""" if elements_type is None: elements_type = VoidType() if not isinstance(elements_type, ElementType): @@ -62,8 +75,7 @@ def __init__(self, axes: Optional[Any] = None, elements_type: Optional[ElementTy ) self.elements_type = elements_type if axes is not None: - if not torch.jit.is_scripting(): - NeuralType.__check_sanity(axes) + NeuralType.__check_sanity(axes) axes_list = [] for axis in axes: if isinstance(axis, str): From 5fdd12e9a9711b241023eb4ed0922733d69ded5e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 1 Feb 2024 16:07:06 -0800 Subject: [PATCH 4/4] Cache Aware Streaming tutorial notebook (#8296) (#8311) * add notebook * rename old notebook to Buffered_Streaming * call setup_streaming_params in set_default_att_context_size method * update links in docs * update links to tutorials in docs * remove hard-coding * rename var --------- Signed-off-by: Elena Rastorgueva Co-authored-by: Elena Rastorgueva <80532067+erastorgueva-nv@users.noreply.github.com> --- README.rst | 2 +- docs/source/asr/intro.rst | 4 +- docs/source/asr/models.rst | 2 + docs/source/starthere/tutorials.rst | 7 +- .../asr/modules/conformer_encoder.py | 2 + ..._Microphone_Demo_Buffered_Streaming.ipynb} | 0 ...icrophone_Demo_Cache_Aware_Streaming.ipynb | 433 ++++++++++++++++++ 7 files changed, 444 insertions(+), 6 deletions(-) rename tutorials/asr/{Online_ASR_Microphone_Demo.ipynb => Online_ASR_Microphone_Demo_Buffered_Streaming.ipynb} (100%) create mode 100644 tutorials/asr/Online_ASR_Microphone_Demo_Cache_Aware_Streaming.ipynb diff --git a/README.rst b/README.rst index 05cf7d8d7124..78396d80dc45 100644 --- a/README.rst +++ b/README.rst @@ -101,7 +101,7 @@ Key Features * Hybrid Transducer/CTC * NeMo Original `Multi-blank Transducers `_ and `Token-and-Duration Transducers (TDT) `_ * Streaming/Buffered ASR (CTC/Transducer) - `Chunked Inference Examples `_ - * `Cache-aware Streaming Conformer `_ with multiple lookaheads. + * `Cache-aware Streaming Conformer `_ with multiple lookaheads (including microphone streaming `tutorial `_). * Beam Search decoding * `Language Modelling for ASR (CTC and RNNT) `_: N-gram LM in fusion with Beam Search decoding, Neural Rescoring with Transformer * `Support of long audios for Conformer with memory efficient local attention `_ diff --git a/docs/source/asr/intro.rst b/docs/source/asr/intro.rst index 2ac27c4312dc..b0c923a85c77 100644 --- a/docs/source/asr/intro.rst +++ b/docs/source/asr/intro.rst @@ -108,9 +108,7 @@ See more information about LM decoding :doc:`here <./asr_language_modeling>`. Use real-time transcription --------------------------- -It is possible to use NeMo to transcribe speech in real-time. You can find an example of how to do -this in the following `notebook tutorial `_. - +It is possible to use NeMo to transcribe speech in real-time. We provide tutorial notebooks for `Cache Aware Streaming `_ and `Buffered Streaming `_. Try different ASR models ------------------------ diff --git a/docs/source/asr/models.rst b/docs/source/asr/models.rst index 02dff9c598e4..4752ff931af7 100644 --- a/docs/source/asr/models.rst +++ b/docs/source/asr/models.rst @@ -159,6 +159,8 @@ You may find more examples under ``/examples/asr/conf/fastconform Cache-aware Streaming Conformer ------------------------------- +Try real-time ASR with the Cache-aware Streaming Conformer `tutorial notebook `_. + Buffered streaming uses overlapping chunks to make an offline ASR model to be used for streaming with reasonable accuracy. However, it uses significant amount of duplication in computations due to the overlapping chunks. Also there is a accuracy gap between the offline model and the streaming one as there is inconsistency between how we train the model and how we perform inference for streaming. The Cache-aware Streaming Conformer models would tackle and address these disadvantages. These streaming Conformers are trained with limited right context that it would make it possible to match how the model is being used in both the training and inference. diff --git a/docs/source/starthere/tutorials.rst b/docs/source/starthere/tutorials.rst index 29ba2f300b74..f48662de937f 100644 --- a/docs/source/starthere/tutorials.rst +++ b/docs/source/starthere/tutorials.rst @@ -47,8 +47,11 @@ To run a tutorial: - Offline ASR Inference with Beam Search and External Language Model Rescoring - `Offline ASR `_ * - ASR - - Online ASR inference with Microphone - - `Online ASR Microphone `_ + - Online ASR inference with Microphone (Cache-Aware Streaming) + - `Online ASR Microphone Cache Aware Streaming `_ + * - ASR + - Online ASR inference with Microphone (Buffered Streaming) + - `Online ASR Microphone Buffered Streaming `_ * - ASR - Fine-tuning CTC Models on New Languages - `ASR CTC Language Fine-Tuning `_ diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index 6d2879e02a11..8488cf1b3812 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -786,6 +786,8 @@ def set_default_att_context_size(self, att_context_size): if att_context_size is not None: self.att_context_size = att_context_size + self.setup_streaming_params() + def setup_streaming_params( self, chunk_size: int = None, diff --git a/tutorials/asr/Online_ASR_Microphone_Demo.ipynb b/tutorials/asr/Online_ASR_Microphone_Demo_Buffered_Streaming.ipynb similarity index 100% rename from tutorials/asr/Online_ASR_Microphone_Demo.ipynb rename to tutorials/asr/Online_ASR_Microphone_Demo_Buffered_Streaming.ipynb diff --git a/tutorials/asr/Online_ASR_Microphone_Demo_Cache_Aware_Streaming.ipynb b/tutorials/asr/Online_ASR_Microphone_Demo_Cache_Aware_Streaming.ipynb new file mode 100644 index 000000000000..fb676af7dbb7 --- /dev/null +++ b/tutorials/asr/Online_ASR_Microphone_Demo_Cache_Aware_Streaming.ipynb @@ -0,0 +1,433 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Introduction\n", + "\n", + "This notebook allows you to do real-time (\"streaming\") speech recognition using audio recorded from your microphone. This notebook shows how to use a NeMo chunk-aware FastConformer model with caching enabled.\n", + "\n", + "## Installation\n", + "\n", + "The notebook requires PyAudio library, which is used to capture an audio stream from your machine. This means that you need to run this notebook locally. This notebook will not be able to record your audio if you run it in Google Colab or in a Docker container.\n", + "\n", + "For Ubuntu, please run the following commands to install it:\n", + "\n", + "```\n", + "sudo apt install python3-pyaudio\n", + "pip install pyaudio\n", + "```\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## Install dependencies\n", + "!pip install wget\n", + "!apt-get install sox libsndfile1 ffmpeg portaudio19-dev\n", + "!pip install text-unidecode\n", + "!pip install pyaudio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ## Uncomment this cell to install NeMo if it has not been installed\n", + "# BRANCH = 'main'\n", + "# !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# import dependencies\n", + "import copy\n", + "import time\n", + "import pyaudio as pa\n", + "import numpy as np\n", + "import torch\n", + "\n", + "from omegaconf import OmegaConf, open_dict\n", + "\n", + "import nemo.collections.asr as nemo_asr\n", + "from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE\n", + "from nemo.collections.asr.parts.utils.streaming_utils import CacheAwareStreamingAudioBuffer\n", + "from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis\n", + "\n", + "# specify sample rate we will use for recording audio\n", + "SAMPLE_RATE = 16000 # Hz" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cache-aware streaming Fastconformer\n", + "In this tutorial, we will do streaming transcription using NeMo models that were specially trained for use in streaming applications. These models are described in the paper released by the NeMo team: [*Noroozi et al.* \"Stateful FastConformer with Cache-based Inference for Streaming Automatic Speech Recognition](https://arxiv.org/abs/2312.17279)\" (accepted to ICASSP 2024).\n", + "\n", + "These models have the following features:\n", + "* They were trained such that at each timestep, the decoder (either RNNT or CTC) would receive a limited amount of context on the left and (most importantly) the right side. Keeping the right side context small means that in a real time streaming scenario, we do not need to keep recording for very long before we are able to compute the output token at that timestep - thus we are able to get transcriptions with a low latency.\n", + "* The model implementation has **caching** enabled, meaning we do not need to recalculate activations that were obtained in previous timesteps, thus reducing latency further.\n", + "\n", + "\n", + "## Model checkpoints\n", + "The following checkpoints of these models are currently available, and are compatible with this notebook. The meaning of \"lookahead\" and \"chunk size\" is described in the following section.\n", + "\n", + "1) [`stt_en_fastconformer_hybrid_large_streaming_80ms`](https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_80ms) - 80ms lookahead / 160ms chunk size\n", + "\n", + "2) [`stt_en_fastconformer_hybrid_large_streaming_480ms`](https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_480ms) - 480ms lookahead / 540ms chunk size\n", + "\n", + "3) [`stt_en_fastconformer_hybrid_large_streaming_1040ms`](https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_1040ms) - 1040ms lookahead / 1120ms chunk size\n", + "\n", + "4) [`stt_en_fastconformer_hybrid_large_streaming_multi`](https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_multi) - 0ms, 80ms, 480ms, 1040ms lookahead / 80ms, 160ms, 540ms, 1120ms chunk size\n", + "\n", + "## Model inference explanation\n", + "We run inference by continuously recording our audio in chunks, and feeding the chunks into the chosen ASR model. In this notebook we use `pyaudio` to open an audio input stream, and pass the audio to a `stream_callback` function every \"chunk-sized\" number of seconds. In the `stream_callback` function, we pass the audio signal to a `transcribe` function (which we will specify in this notebook), and print the resulting transcription.\n", + "\n", + "As mentioned, the \"chunk size\" is the duration of audio that we feed into the ASR model at a time (and we keep doing this continuously, to allow for real-time, streaming speech recognition).\n", + "\n", + "\"Lookahead\" size is the \"chunk size\" minus the duration of a single output timestep from the decoder. For FastConformer models, the duration of an output timestep is always 80ms, hence in this notebook always `lookahead size = chunk size - 80 ms`." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model selection\n", + "In the next cell, you can select which pretrained `model_name` and `lookahead_size` you would like to try.\n", + "\n", + "Additionally, note that all of the available models are [Hybrid RNNT-CTC models](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/models.html#hybrid-transducer-ctc). Inference is by default done using the RNNT decoder (which tends to produce a higher transcription accuracy), but you may choose to use the CTC decoder instead. For this, we also provide a `decoder_type` variable in the cell below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# You may wish to try different values of model_name and lookahead_size\n", + "\n", + "# Choose a the name of a model to use.\n", + "# Currently available options:\n", + "# 1) \"stt_en_fastconformer_hybrid_large_streaming_multi\"\n", + "# 2) \"stt_en_fastconformer_hybrid_large_streaming_80ms\"\n", + "# 3) \"stt_en_fastconformer_hybrid_large_streaming_480ms\"\n", + "# 4) \"stt_en_fastconformer_hybrid_large_streaming_1040ms\"\n", + "\n", + "model_name = \"stt_en_fastconformer_hybrid_large_streaming_multi\"\n", + "\n", + "# Specify the lookahead_size.\n", + "# If model_name == \"stt_en_fastconformer_hybrid_large_streaming_multi\" then\n", + "# lookahead_size can be 0, 80, 480 or 1040 (ms)\n", + "# Else, lookahead_size should be whatever is written in the model_name:\n", + "# \"stt_en_fastconformer_hybrid_large_streaming_ms\"\n", + "\n", + "lookahead_size = 80 # in milliseconds\n", + "\n", + "# Specify the decoder to use.\n", + "# Can be \"rnnt\" or \"ctc\"\n", + "decoder_type = \"rnnt\"" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model set-up\n", + "Next we:\n", + "* set up the `asr_model` according to the chosen `model_name` and `lookahead_size`\n", + "* make sure we use the specified `decoder_type`\n", + "* make sure the model's decoding strategy has suitable parameters\n", + "* instantiate a `CacheAwareStreamingAudioBuffer`\n", + "* get some parameters to use as the initial cache state" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# setting up model and validating the choice of model_name and lookahead size\n", + "asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)\n", + "\n", + "\n", + "# specify ENCODER_STEP_LENGTH (which is 80 ms for FastConformer models)\n", + "ENCODER_STEP_LENGTH = 80 # ms\n", + "\n", + "# update att_context_size if using multi-lookahead model\n", + "# (for single-lookahead models, the default context size will be used and the\n", + "# `lookahead_size` variable will be ignored)\n", + "if model_name == \"stt_en_fastconformer_hybrid_large_streaming_multi\":\n", + " # check that lookahead_size is one of the valid ones\n", + " if lookahead_size not in [0, 80, 480, 1040]:\n", + " raise ValueError(\n", + " f\"specified lookahead_size {lookahead_size} is not one of the \"\n", + " \"allowed lookaheads (can select 0, 80, 480 or 1040 ms)\"\n", + " )\n", + "\n", + " # update att_context_size\n", + " left_context_size = asr_model.encoder.att_context_size[0]\n", + " asr_model.encoder.set_default_att_context_size([left_context_size, int(lookahead_size / ENCODER_STEP_LENGTH)])\n", + "\n", + "\n", + "# make sure we use the specified decoder_type\n", + "asr_model.change_decoding_strategy(decoder_type=decoder_type)\n", + "\n", + "# make sure the model's decoding strategy is optimal\n", + "decoding_cfg = asr_model.cfg.decoding\n", + "with open_dict(decoding_cfg):\n", + " # save time by doing greedy decoding and not trying to record the alignments\n", + " decoding_cfg.strategy = \"greedy\"\n", + " decoding_cfg.preserve_alignments = False\n", + " if hasattr(asr_model, 'joint'): # if an RNNT model\n", + " # restrict max_symbols to make sure not stuck in infinite loop\n", + " decoding_cfg.greedy.max_symbols = 10\n", + " # sensible default parameter, but not necessary since batch size is 1\n", + " decoding_cfg.fused_batch_size = -1\n", + " asr_model.change_decoding_strategy(decoding_cfg)\n", + "\n", + "\n", + "# set model to eval mode\n", + "asr_model.eval()\n", + "\n", + "\n", + "# get parameters to use as the initial cache state\n", + "cache_last_channel, cache_last_time, cache_last_channel_len = asr_model.encoder.get_initial_cache_state(\n", + " batch_size=1\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Transcribing a single chunk\n", + "In the following code block we specify the `transcribe_chunk` function that transcribes a single chunk." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# init params we will use for streaming\n", + "previous_hypotheses = None\n", + "pred_out_stream = None\n", + "step_num = 0\n", + "pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size[1]\n", + "# cache-aware models require some small section of the previous processed_signal to\n", + "# be fed in at each timestep - we initialize this to a tensor filled with zeros\n", + "# so that we will do zero-padding for the very first chunk(s)\n", + "num_channels = asr_model.cfg.preprocessor.features\n", + "cache_pre_encode = torch.zeros((1, num_channels, pre_encode_cache_size), device=asr_model.device)\n", + "\n", + "\n", + "# helper function for extracting transcriptions\n", + "def extract_transcriptions(hyps):\n", + " \"\"\"\n", + " The transcribed_texts returned by CTC and RNNT models are different.\n", + " This method would extract and return the text section of the hypothesis.\n", + " \"\"\"\n", + " if isinstance(hyps[0], Hypothesis):\n", + " transcriptions = []\n", + " for hyp in hyps:\n", + " transcriptions.append(hyp.text)\n", + " else:\n", + " transcriptions = hyps\n", + " return transcriptions\n", + "\n", + "# define functions to init audio preprocessor and to\n", + "# preprocess the audio (ie obtain the mel-spectrogram)\n", + "def init_preprocessor(asr_model):\n", + " cfg = copy.deepcopy(asr_model._cfg)\n", + " OmegaConf.set_struct(cfg.preprocessor, False)\n", + "\n", + " # some changes for streaming scenario\n", + " cfg.preprocessor.dither = 0.0\n", + " cfg.preprocessor.pad_to = 0\n", + " cfg.preprocessor.normalize = \"None\"\n", + " \n", + " preprocessor = EncDecCTCModelBPE.from_config_dict(cfg.preprocessor)\n", + " preprocessor.to(asr_model.device)\n", + " \n", + " return preprocessor\n", + "\n", + "preprocessor = init_preprocessor(asr_model)\n", + "\n", + "def preprocess_audio(audio, asr_model):\n", + " device = asr_model.device\n", + "\n", + " # doing audio preprocessing\n", + " audio_signal = torch.from_numpy(audio).unsqueeze_(0).to(device)\n", + " audio_signal_len = torch.Tensor([audio.shape[0]]).to(device)\n", + " processed_signal, processed_signal_length = preprocessor(\n", + " input_signal=audio_signal, length=audio_signal_len\n", + " )\n", + " return processed_signal, processed_signal_length\n", + "\n", + "\n", + "def transcribe_chunk(new_chunk):\n", + " \n", + " global cache_last_channel, cache_last_time, cache_last_channel_len\n", + " global previous_hypotheses, pred_out_stream, step_num\n", + " global cache_pre_encode\n", + " \n", + " # new_chunk is provided as np.int16, so we convert it to np.float32\n", + " # as that is what our ASR models expect\n", + " audio_data = new_chunk.astype(np.float32)\n", + " audio_data = audio_data / 32768.0\n", + "\n", + " # get mel-spectrogram signal & length\n", + " processed_signal, processed_signal_length = preprocess_audio(audio_data, asr_model)\n", + " \n", + " # prepend with cache_pre_encode\n", + " processed_signal = torch.cat([cache_pre_encode, processed_signal], dim=-1)\n", + " processed_signal_length += cache_pre_encode.shape[1]\n", + " \n", + " # save cache for next time\n", + " cache_pre_encode = processed_signal[:, :, -pre_encode_cache_size:]\n", + " \n", + " with torch.no_grad():\n", + " (\n", + " pred_out_stream,\n", + " transcribed_texts,\n", + " cache_last_channel,\n", + " cache_last_time,\n", + " cache_last_channel_len,\n", + " previous_hypotheses,\n", + " ) = asr_model.conformer_stream_step(\n", + " processed_signal=processed_signal,\n", + " processed_signal_length=processed_signal_length,\n", + " cache_last_channel=cache_last_channel,\n", + " cache_last_time=cache_last_time,\n", + " cache_last_channel_len=cache_last_channel_len,\n", + " keep_all_outputs=False,\n", + " previous_hypotheses=previous_hypotheses,\n", + " previous_pred_out=pred_out_stream,\n", + " drop_extra_pre_encoded=None,\n", + " return_transcription=True,\n", + " )\n", + " \n", + " final_streaming_tran = extract_transcriptions(transcribed_texts)\n", + " step_num += 1\n", + " \n", + " return final_streaming_tran[0]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Simple streaming with microphone\n", + "We use `pyaudio` to record audio from an input audio device on your local machine. We use a `stream_callback` which will be called every `frames_per_buffer` number of frames, and conduct the transcription, which will be printed in the output of the cell below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# calculate chunk_size in milliseconds\n", + "chunk_size = lookahead_size + ENCODER_STEP_LENGTH\n", + "\n", + "p = pa.PyAudio()\n", + "print('Available audio input devices:')\n", + "input_devices = []\n", + "for i in range(p.get_device_count()):\n", + " dev = p.get_device_info_by_index(i)\n", + " if dev.get('maxInputChannels'):\n", + " input_devices.append(i)\n", + " print(i, dev.get('name'))\n", + "\n", + "if len(input_devices):\n", + " dev_idx = -2\n", + " while dev_idx not in input_devices:\n", + " print('Please type input device ID:')\n", + " dev_idx = int(input())\n", + "\n", + " def callback(in_data, frame_count, time_info, status):\n", + " signal = np.frombuffer(in_data, dtype=np.int16)\n", + " text = transcribe_chunk(signal)\n", + " print(text, end='\\r')\n", + " return (in_data, pa.paContinue)\n", + "\n", + " stream = p.open(format=pa.paInt16,\n", + " channels=1,\n", + " rate=SAMPLE_RATE,\n", + " input=True,\n", + " input_device_index=dev_idx,\n", + " stream_callback=callback,\n", + " frames_per_buffer=int(SAMPLE_RATE * chunk_size / 1000) - 1\n", + " )\n", + "\n", + " print('Listening...')\n", + "\n", + " stream.start_stream()\n", + " \n", + " # Interrupt kernel and then speak for a few more words to exit the pyaudio loop !\n", + " try:\n", + " while stream.is_active():\n", + " time.sleep(0.1)\n", + " finally: \n", + " stream.stop_stream()\n", + " stream.close()\n", + " p.terminate()\n", + "\n", + " print()\n", + " print(\"PyAudio stopped\")\n", + " \n", + "else:\n", + " print('ERROR: No audio input device found.')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}