From b5cb7831ddac10c6a6be57d21cfd09dbc3b10c36 Mon Sep 17 00:00:00 2001
From: Boris Fomitchev
Date: Fri, 7 Jul 2023 14:31:31 -0700
Subject: [PATCH 1/2] Hybrid conformer export (#6983)

* Implemented generic kv-pair setting of export_config from args

Signed-off-by: Boris Fomitchev

* Hybrid conformer export

Signed-off-by: Boris Fomitchev

* Hybrid decoder export

Signed-off-by: Boris Fomitchev

* Cleanup

Signed-off-by: Boris Fomitchev

* Changed from **kwargs

Signed-off-by: Boris Fomitchev

* Docstring

Signed-off-by: Boris Fomitchev

* Docs added

Signed-off-by: Boris Fomitchev

* Stringify args

Signed-off-by: Boris Fomitchev

* Added docs for ASR export configs

Signed-off-by: Boris Fomitchev

* lowercase ctc

Signed-off-by: Boris Fomitchev

---------

Signed-off-by: Boris Fomitchev
---
 docs/source/asr/models.rst | 10 ++++++
 docs/source/core/export.rst | 31 +++++++++++++++++++
 nemo/collections/asr/models/asr_model.py | 8 +++++
 .../asr/models/hybrid_rnnt_ctc_models.py | 14 +++++++++
 nemo/collections/asr/models/rnnt_models.py | 12 +++++--
 nemo/core/classes/exportable.py | 14 +++++++++
 scripts/export.py | 19 +++++++++---
 7 files changed, 102 insertions(+), 6 deletions(-)

diff --git a/docs/source/asr/models.rst b/docs/source/asr/models.rst
index 80a0fd90f0fb..697a89827145 100644
--- a/docs/source/asr/models.rst
+++ b/docs/source/asr/models.rst
@@ -215,6 +215,11 @@ It is recommended to train a model in streaming model with limited context for t
 
 You may find FastConformer variants of cache-aware streaming models under ``/examples/asr/conf/fastconformer/``.
 
+Note that cache-aware streaming models are exported without caching support by default.
+To include caching support, call `model.set_export_config({'cache_support' : 'True'})` before export.
+Or, if ``/scripts/export.py`` is used:
+`python export.py cache_aware_conformer.nemo cache_aware_conformer.onnx --config cache_support=True`
+
 .. _LSTM-Transducer_model:
 
 LSTM-Transducer
@@ -291,6 +296,11 @@ Similar example configs for FastConformer variants of Hybrid models can be found
 ``/examples/asr/conf/fastconformer/hybrid_transducer_ctc/``
 ``/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/``
 
+Note that Hybrid models are exported as RNNT (encoder and decoder+joint parts) by default.
+To export as CTC (single encoder+decoder graph), call `model.set_export_config({'decoder_type' : 'ctc'})` before export.
+Or, if ``/scripts/export.py`` is used:
+`python export.py hybrid_transducer.nemo hybrid_transducer.onnx --config decoder_type=ctc`
+
 .. _Conformer-HAT_model:
 
 Conformer-HAT (Hybrid Autoregressive Transducer)
diff --git a/docs/source/core/export.rst b/docs/source/core/export.rst
index 0e598e215dbf..f54daffe9c9c 100644
--- a/docs/source/core/export.rst
+++ b/docs/source/core/export.rst
@@ -177,6 +177,37 @@ Another common requirement for models that are being exported is to run certain
         # call base method for common set of modifications
         Exportable._prepare_for_export(self, **kwargs)
 
+Some models that require control flow need to be exported in multiple parts. Typical examples are RNNT nets.
+To facilitate that, the hooks below are provided. To export, for example, the 'encoder' and 'decoder' subnets of the model, overload `list_export_subnets` to return ['encoder', 'decoder'].
+
+.. 
code-block:: Python
+
+    def get_export_subnet(self, subnet=None):
+        """
+        Returns Exportable subnet model/module to export
+        """
+
+
+    def list_export_subnets(self):
+        """
+        Returns default set of subnet names exported for this model
+        First goes the one receiving input (input_example)
+        """
+
+Some networks may be exported differently according to user-settable options (like ragged batch support for TTS or cache support for ASR). To facilitate that, `Exportable` provides the `set_export_config()` method, which sets key/value pairs in the predefined `model.export_config` dictionary to be used during export:
+
+.. code-block:: Python
+    def set_export_config(self, args):
+        """
+        Sets/updates export_config dictionary
+        """
+Also, if an action hook on setting the config is desired, this method can be overridden by `Exportable` descendants to include one.
+An example can be found in ``/nemo/collections/asr/models/rnnt_models.py``.
+
+Here is an example of how the `set_export_config()` call is tied to command-line arguments in ``/scripts/export.py``:
+
+.. code-block:: bash
+    python scripts/export.py hybrid_conformer.nemo hybrid_conformer.onnx --config decoder_type=ctc
 
 Exportable Model Code
 ~~~~~~~~~~~~~~~~~~~~~
diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py
index c0f4c1cd0a70..20be6cc16203 100644
--- a/nemo/collections/asr/models/asr_model.py
+++ b/nemo/collections/asr/models/asr_model.py
@@ -239,3 +239,11 @@ def disabled_deployment_input_names(self):
     @property
     def disabled_deployment_output_names(self):
         return self.encoder.disabled_deployment_output_names
+
+    def set_export_config(self, args):
+        if 'cache_support' in args:
+            enable = bool(args['cache_support'])
+            self.encoder.export_cache_support = enable
+            logging.info(f"Caching support enabled: {enable}")
+            self.encoder.setup_streaming_params()
+        super().set_export_config(args)
diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
index 5ca6124ecfd7..11c616b1257f 100644
--- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
+++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -645,6 +645,20 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
         self.finalize_interctc_metrics(metrics, outputs, prefix="test_")
         return metrics
 
+    # When cur_decoder is 'rnnt', the model is exported in 2 parts (encoder, decoder_joint)
+    def list_export_subnets(self):
+        if self.cur_decoder == 'rnnt':
+            return ['encoder', 'decoder_joint']
+        else:
+            return ['self']
+
+    @property
+    def output_module(self):
+        if self.cur_decoder == 'rnnt':
+            return self.decoder
+        else:
+            return self.ctc_decoder
+
     @classmethod
     def list_available_models(cls) -> Optional[PretrainedModelInfo]:
         """
diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py
index 92bb04fd2a3e..0c1da97c5012 100644
--- a/nemo/collections/asr/models/rnnt_models.py
+++ b/nemo/collections/asr/models/rnnt_models.py
@@ -28,7 +28,7 @@
 from nemo.collections.asr.data.audio_to_text_dali import AudioToCharDALIDataset, DALIOutputs
 from nemo.collections.asr.losses.rnnt import RNNTLoss, resolve_rnnt_default_loss_name
 from nemo.collections.asr.metrics.rnnt_wer import RNNTWER, RNNTDecoding, RNNTDecodingConfig
-from nemo.collections.asr.models.asr_model import ASRModel
+from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
 from nemo.collections.asr.modules.rnnt import RNNTDecoderJoint
 from nemo.collections.asr.parts.mixins import ASRModuleMixin
 from 
nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType @@ -39,7 +39,7 @@ from nemo.utils import logging -class EncDecRNNTModel(ASRModel, ASRModuleMixin, Exportable): +class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecModel): """Base class for encoder decoder RNNT-based models.""" def __init__(self, cfg: DictConfig, trainer: Trainer = None): @@ -960,6 +960,14 @@ def list_export_subnets(self): def decoder_joint(self): return RNNTDecoderJoint(self.decoder, self.joint) + def set_export_config(self, args): + if 'decoder_type' in args: + if hasattr(self, 'change_decoding_strategy'): + self.change_decoding_strategy(decoder_type=args['decoder_type']) + else: + raise Exception("Model does not have decoder type option") + super().set_export_config(args) + @classmethod def list_available_models(cls) -> List[PretrainedModelInfo]: """ diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 3d2682f2304e..8469e80219d6 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -302,3 +302,17 @@ def list_export_subnets(self): First goes the one receiving input (input_example) """ return ['self'] + + def get_export_config(self): + """ + Returns export_config dictionary + """ + return getattr(self, 'export_config', {}) + + def set_export_config(self, args): + """ + Sets/updates export_config dictionary + """ + ex_config = self.get_export_config() + ex_config.update(args) + self.export_config = ex_config diff --git a/scripts/export.py b/scripts/export.py index fe3b79ebdf28..4b21bc4ffd73 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -62,6 +62,15 @@ def get_args(argv): ) parser.add_argument("--device", default="cuda", help="Device to export for") parser.add_argument("--check-tolerance", type=float, default=0.01, help="tolerance for verification") + parser.add_argument( + "--config", + metavar="KEY=VALUE", + nargs='+', + help="Set a number of key-value pairs to model.export_config dictionary " + "(do not put spaces before or after the = sign). 
" + "Note that values are always treated as strings.", + ) + args = parser.parse_args(argv) return args @@ -130,10 +139,12 @@ def nemo_export(argv): in_args["max_dim"] = args.max_dim max_dim = args.max_dim - if args.cache_support and hasattr(model, "encoder") and hasattr(model.encoder, "export_cache_support"): - model.encoder.export_cache_support = True - logging.info("Caching support is enabled.") - model.encoder.setup_streaming_params() + if args.cache_support: + model.set_export_config({"cache_support": "True"}) + + if args.config: + kv = dict(map(lambda s: s.split('='), args.config)) + model.set_export_config(kv) autocast = nullcontext if args.autocast: From f08cb214199f0171e1cc56255b0636e2d1a7ce1e Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Fri, 7 Jul 2023 14:34:55 -0700 Subject: [PATCH 2/2] Cache handling without input tensors mutation (#6980) * Cache handling without input tensors mutation Signed-off-by: Boris Fomitchev * Cleanup Signed-off-by: Boris Fomitchev * Cleanup#2 Signed-off-by: Boris Fomitchev * Cleanup#3 Signed-off-by: Boris Fomitchev --------- Signed-off-by: Boris Fomitchev Co-authored-by: Somshubra Majumdar --- nemo/collections/asr/models/asr_model.py | 64 ++++++----------- .../asr/modules/conformer_encoder.py | 48 ++++++------- .../multi_head_attention_adapter_module.py | 16 ++--- .../asr/parts/submodules/causal_convs.py | 28 ++++---- .../asr/parts/submodules/conformer_modules.py | 70 +++++++------------ .../parts/submodules/multi_head_attention.py | 53 ++++++++------ 6 files changed, 118 insertions(+), 161 deletions(-) diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py index 20be6cc16203..7e03d587139f 100644 --- a/nemo/collections/asr/models/asr_model.py +++ b/nemo/collections/asr/models/asr_model.py @@ -161,7 +161,7 @@ def output_module(self): @property def output_names(self): otypes = self.output_module.output_types - if hasattr(self.input_module, 'export_cache_support') and self.input_module.export_cache_support: + if getattr(self.input_module, 'export_cache_support', False): in_types = self.input_module.output_types otypes = {n: t for (n, t) in list(otypes.items())[:1]} for (n, t) in list(in_types.items())[1:]: @@ -174,7 +174,6 @@ def forward_for_export( """ This forward is used when we need to export the model to ONNX format. Inputs cache_last_channel and cache_last_time are needed to be passed for exporting streaming models. - When they are passed, it just passes the inputs through the encoder part and currently the ONNX conversion does not fully work for this case. Args: input: Tensor that represents a batch of raw audio signals, of shape [B, T]. T here represents timesteps. 
@@ -187,49 +186,26 @@ def forward_for_export( Returns: the output of the model """ - if hasattr(self.input_module, 'forward_for_export'): - if cache_last_channel is None and cache_last_time is None: - encoder_output = self.input_module.forward_for_export(audio_signal=input, length=length) - else: - encoder_output = self.input_module.forward_for_export( - audio_signal=input, - length=length, - cache_last_channel=cache_last_channel, - cache_last_time=cache_last_time, - cache_last_channel_len=cache_last_channel_len, - ) + enc_fun = getattr(self.input_module, 'forward_for_export', self.input_module.forward) + if cache_last_channel is None: + encoder_output = enc_fun(audio_signal=input, length=length) + if isinstance(encoder_output, tuple): + encoder_output = encoder_output[0] else: - if cache_last_channel is None and cache_last_time is None: - encoder_output = self.input_module(audio_signal=input, length=length) - else: - encoder_output = self.input_module( - audio_signal=input, - length=length, - cache_last_channel=cache_last_channel, - cache_last_time=cache_last_time, - cache_last_channel_len=cache_last_channel_len, - ) - if isinstance(encoder_output, tuple): - decoder_input = encoder_output[0] - else: - decoder_input = encoder_output - if hasattr(self.output_module, 'forward_for_export'): - if cache_last_channel is None and cache_last_time is None: - ret = self.output_module.forward_for_export(encoder_output=decoder_input) - else: - ret = self.output_module.forward_for_export(encoder_output=decoder_input) - else: - if cache_last_channel is None and cache_last_time is None: - ret = self.output_module(encoder_output=decoder_input) - else: - ret = self.output_module(encoder_output=decoder_input) - if cache_last_channel is None and cache_last_time is None: - pass - else: - if isinstance(ret, tuple): - ret = (ret[0], encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4]) - else: - ret = (ret, encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4]) + encoder_output, length, cache_last_channel, cache_last_time, cache_last_channel_len = enc_fun( + audio_signal=input, + length=length, + cache_last_channel=cache_last_channel, + cache_last_time=cache_last_time, + cache_last_channel_len=cache_last_channel_len, + ) + + dec_fun = getattr(self.output_module, 'forward_for_export', self.output_module.forward) + ret = dec_fun(encoder_output=encoder_output) + if isinstance(ret, tuple): + ret = ret[0] + if cache_last_channel is not None: + ret = (ret, length, cache_last_channel, cache_last_time, cache_last_channel_len) return cast_all(ret, from_dtype=torch.float16, to_dtype=torch.float32) @property diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index 74c255741039..8f429c25806d 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -505,11 +505,6 @@ def forward_internal( (audio_signal.size(0),), audio_signal.size(-1), dtype=torch.int64, device=audio_signal.device ) - if cache_last_time is not None: - cache_last_time_next = torch.zeros_like(cache_last_time) - else: - cache_last_time_next = None - # select a random att_context_size with the distribution specified by att_context_probs during training # for non-validation cases like test, validation or inference, it uses the first mode in self.att_context_size if self.training and len(self.att_context_size_all) > 1: @@ -536,7 +531,6 @@ def forward_internal( if cache_last_channel is not None: 
cache_len = self.streaming_cfg.last_channel_cache_size cache_keep_size = max_audio_length - self.streaming_cfg.cache_drop_size - cache_last_channel_next = torch.zeros_like(cache_last_channel) max_audio_length = max_audio_length + cache_len padding_length = length + cache_len offset = torch.neg(cache_last_channel_len) + cache_len @@ -561,19 +555,32 @@ def forward_internal( pad_mask = pad_mask[:, cache_len:] if att_mask is not None: att_mask = att_mask[:, cache_len:] + # Convert caches from the tensor to list + cache_last_time_next = [] + cache_last_channel_next = [] for lth, (drop_prob, layer) in enumerate(zip(self.layer_drop_probs, self.layers)): original_signal = audio_signal + if cache_last_channel is not None: + cache_last_channel_cur = cache_last_channel[lth] + cache_last_time_cur = cache_last_time[lth] + else: + cache_last_channel_cur = None + cache_last_time_cur = None audio_signal = layer( x=audio_signal, att_mask=att_mask, pos_emb=pos_emb, pad_mask=pad_mask, - cache_last_channel=cache_last_channel, - cache_last_time=cache_last_time, - cache_last_channel_next=cache_last_channel_next, - cache_last_time_next=cache_last_time_next, + cache_last_channel=cache_last_channel_cur, + cache_last_time=cache_last_time_cur, ) + + if cache_last_channel_cur is not None: + (audio_signal, cache_last_channel_cur, cache_last_time_cur) = audio_signal + cache_last_channel_next.append(cache_last_channel_cur) + cache_last_time_next.append(cache_last_time_cur) + # applying stochastic depth logic from https://arxiv.org/abs/2102.03216 if self.training and drop_prob > 0.0: should_drop = torch.rand(1) < drop_prob @@ -626,6 +633,8 @@ def forward_internal( length = length.to(dtype=torch.int64) if cache_last_channel is not None: + cache_last_channel_next = torch.stack(cache_last_channel_next, dim=0) + cache_last_time_next = torch.stack(cache_last_time_next, dim=0) return ( audio_signal, length, @@ -860,20 +869,12 @@ def setup_streaming_params( else: streaming_cfg.drop_extra_pre_encoded = streaming_cfg.pre_encode_cache_size // self.subsampling_factor - # counting the number of the layers need caching - streaming_cfg.last_channel_num = 0 - streaming_cfg.last_time_num = 0 for m in self.layers.modules(): if hasattr(m, "_max_cache_len"): if isinstance(m, MultiHeadAttention): - m._cache_id = streaming_cfg.last_channel_num m.cache_drop_size = streaming_cfg.cache_drop_size - streaming_cfg.last_channel_num += 1 - if isinstance(m, CausalConv1D): - m._cache_id = streaming_cfg.last_time_num m.cache_drop_size = streaming_cfg.cache_drop_size - streaming_cfg.last_time_num += 1 self.streaming_cfg = streaming_cfg @@ -886,19 +887,12 @@ def get_initial_cache_state(self, batch_size=1, dtype=torch.float32, device=None create_tensor = torch.zeros last_time_cache_size = self.conv_context_size[0] cache_last_channel = create_tensor( - ( - self.streaming_cfg.last_channel_num, - batch_size, - self.streaming_cfg.last_channel_cache_size, - self.d_model, - ), + (len(self.layers), batch_size, self.streaming_cfg.last_channel_cache_size, self.d_model,), device=device, dtype=dtype, ) cache_last_time = create_tensor( - (self.streaming_cfg.last_time_num, batch_size, self.d_model, last_time_cache_size), - device=device, - dtype=dtype, + (len(self.layers), batch_size, self.d_model, last_time_cache_size), device=device, dtype=dtype, ) if max_dim > 0: cache_last_channel_len = torch.randint( diff --git a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py 
b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py index 169dde48602f..563d4219baa7 100644 --- a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py +++ b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py @@ -147,18 +147,18 @@ def __init__( # reset parameters for Q to be identity operation self.reset_parameters() - def forward(self, query, key, value, mask, pos_emb=None, cache=None, cache_next=None): + def forward(self, query, key, value, mask, pos_emb=None, cache=None): """Compute 'Scaled Dot Product Attention'. Args: query (torch.Tensor): (batch, time1, size) key (torch.Tensor): (batch, time2, size) value(torch.Tensor): (batch, time2, size) mask (torch.Tensor): (batch, time1, time2) - cache (torch.Tensor) : (cache_nums, batch, time_cache, size) - cache_next (torch.Tensor) : (cache_nums, batch, time_cache_next, size) + cache (torch.Tensor) : (batch, time_cache, size) returns: output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention + cache (torch.Tensor) : (batch, time_cache_next, size) """ # Need to perform duplicate computations as at this point the tensors have been # separated by the adapter forward @@ -166,7 +166,7 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None, cache_next= key = self.pre_norm(key) value = self.pre_norm(value) - return super().forward(query, key, value, mask, pos_emb, cache=cache, cache_next=cache_next) + return super().forward(query, key, value, mask, pos_emb, cache=cache) def reset_parameters(self): with torch.no_grad(): @@ -242,7 +242,7 @@ def __init__( # reset parameters for Q to be identity operation self.reset_parameters() - def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None): + def forward(self, query, key, value, mask, pos_emb, cache=None): """Compute 'Scaled Dot Product Attention' with rel. positional encoding. 
Args: query (torch.Tensor): (batch, time1, size) @@ -250,10 +250,10 @@ def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None) value(torch.Tensor): (batch, time2, size) mask (torch.Tensor): (batch, time1, time2) pos_emb (torch.Tensor) : (batch, time1, size) - cache (torch.Tensor) : (cache_nums, batch, time_cache, size) - cache_next (torch.Tensor) : (cache_nums, batch, time_cache_next, size) + cache (torch.Tensor) : (batch, time_cache, size) Returns: output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention + cache_next (torch.Tensor) : (batch, time_cache_next, size) """ # Need to perform duplicate computations as at this point the tensors have been # separated by the adapter forward @@ -261,7 +261,7 @@ def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None) key = self.pre_norm(key) value = self.pre_norm(value) - return super().forward(query, key, value, mask, pos_emb, cache=cache, cache_next=cache_next) + return super().forward(query, key, value, mask, pos_emb, cache=cache) def reset_parameters(self): with torch.no_grad(): diff --git a/nemo/collections/asr/parts/submodules/causal_convs.py b/nemo/collections/asr/parts/submodules/causal_convs.py index 25f841802154..c6251690b1b1 100644 --- a/nemo/collections/asr/parts/submodules/causal_convs.py +++ b/nemo/collections/asr/parts/submodules/causal_convs.py @@ -45,7 +45,6 @@ def __init__( raise ValueError("Argument padding should be set to None for CausalConv2D.") self._left_padding = kernel_size - 1 self._right_padding = stride - 1 - self._cache_id = None padding = 0 super(CausalConv2D, self).__init__( @@ -113,7 +112,6 @@ def __init__( raise ValueError(f"Invalid padding param: {padding}!") self._max_cache_len = self._left_padding - self._cache_id = None super(CausalConv1D, self).__init__( in_channels=in_channels, @@ -129,21 +127,21 @@ def __init__( dtype=dtype, ) - def update_cache(self, x, cache=None, cache_next=None): + def update_cache(self, x, cache=None): if cache is None: new_x = F.pad(x, pad=(self._left_padding, self._right_padding)) else: new_x = F.pad(x, pad=(0, self._right_padding)) - new_x = torch.cat([cache[self._cache_id], new_x], dim=-1) - # todo: we should know input_x.size(-1) at config time - if cache_next is not None: - cache_keep_size = torch.tensor(x.size(-1) - self.cache_drop_size, dtype=torch.int64, device=x.device) - cache_keep_size = torch.clip(cache_keep_size, min=1, max=cache_next.size(-1)) - cache_next[self._cache_id, :, :, :-cache_keep_size] = cache[self._cache_id, :, :, cache_keep_size:] - cache_next[self._cache_id, :, :, -cache_keep_size:] = x[:, :, :cache_keep_size] - return new_x - - def forward(self, x, cache=None, cache_next=None): - x = self.update_cache(x, cache=cache, cache_next=cache_next) + new_x = torch.cat([cache, new_x], dim=-1) + if self.cache_drop_size > 0: + x = x[:, :, : -self.cache_drop_size] + cache = torch.cat([cache[:, :, x.size(-1) :], x], dim=-1) + return new_x, cache + + def forward(self, x, cache=None): + x, cache = self.update_cache(x, cache=cache) x = super().forward(x) - return x + if cache is None: + return x + else: + return x, cache diff --git a/nemo/collections/asr/parts/submodules/conformer_modules.py b/nemo/collections/asr/parts/submodules/conformer_modules.py index 579b78a8f5a8..677d2acd9f2e 100644 --- a/nemo/collections/asr/parts/submodules/conformer_modules.py +++ b/nemo/collections/asr/parts/submodules/conformer_modules.py @@ -138,29 +138,19 @@ def __init__( self.dropout = 
nn.Dropout(dropout) self.norm_out = LayerNorm(d_model) - def forward( - self, - x, - att_mask=None, - pos_emb=None, - pad_mask=None, - cache_last_channel=None, - cache_last_time=None, - cache_last_channel_next=None, - cache_last_time_next=None, - ): + def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_channel=None, cache_last_time=None): """ Args: x (torch.Tensor): input signals (B, T, d_model) att_mask (torch.Tensor): attention masks(B, T, T) pos_emb (torch.Tensor): (L, 1, d_model) pad_mask (torch.tensor): padding mask - cache_last_channel (torch.tensor) : cache for MHA layers (N, B, T_cache, d_model) - cache_last_time (torch.tensor) : cache for convolutional layers (N, B, d_model, T_cache) - cache_last_channel_next (torch.tensor) : next cache for MHA layers (N, B, T_cache, d_model) - cache_last_time_next (torch.tensor) : next cache for convolutional layers (N, B, d_model, T_cache) + cache_last_channel (torch.tensor) : cache for MHA layers (B, T_cache, d_model) + cache_last_time (torch.tensor) : cache for convolutional layers (B, d_model, T_cache) Returns: x (torch.Tensor): (B, T, d_model) + cache_last_channel (torch.tensor) : next cache for MHA layers (B, T_cache, d_model) + cache_last_time (torch.tensor) : next cache for convolutional layers (B, d_model, T_cache) """ residual = x x = self.norm_feed_forward1(x) @@ -169,31 +159,17 @@ def forward( x = self.norm_self_att(residual) if self.self_attention_model == 'rel_pos': - x = self.self_attn( - query=x, - key=x, - value=x, - mask=att_mask, - pos_emb=pos_emb, - cache=cache_last_channel, - cache_next=cache_last_channel_next, - ) + x = self.self_attn(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb, cache=cache_last_channel) elif self.self_attention_model == 'rel_pos_local_attn': - x = self.self_attn( - query=x, - key=x, - value=x, - pad_mask=pad_mask, - pos_emb=pos_emb, - cache=cache_last_channel, - cache_next=cache_last_channel_next, - ) + x = self.self_attn(query=x, key=x, value=x, pad_mask=pad_mask, pos_emb=pos_emb, cache=cache_last_channel) elif self.self_attention_model == 'abs_pos': - x = self.self_attn( - query=x, key=x, value=x, mask=att_mask, cache=cache_last_channel, cache_next=cache_last_channel_next - ) + x = self.self_attn(query=x, key=x, value=x, mask=att_mask, cache=cache_last_channel) else: x = None + + if x is not None and cache_last_channel is not None: + (x, cache_last_channel) = x + residual = residual + self.dropout(x) if self.is_adapter_available(): @@ -208,7 +184,9 @@ def forward( residual = pack_ip['x'] x = self.norm_conv(residual) - x = self.conv(x, pad_mask=pad_mask, cache=cache_last_time, cache_next=cache_last_time_next) + x = self.conv(x, pad_mask=pad_mask, cache=cache_last_time) + if cache_last_time is not None: + (x, cache_last_time) = x residual = residual + self.dropout(x) x = self.norm_feed_forward2(residual) @@ -228,8 +206,10 @@ def forward( if self.is_access_enabled() and self.access_cfg.get('save_encoder_tensors', False): self.register_accessible_tensor(name='encoder', tensor=x) - - return x + if cache_last_channel is None: + return x + else: + return x, cache_last_channel, cache_last_time def forward_single_enabled_adapter_( self, @@ -355,7 +335,7 @@ def __init__( in_channels=dw_conv_input_dim, out_channels=d_model, kernel_size=1, stride=1, padding=0, bias=True ) - def forward(self, x, pad_mask=None, cache=None, cache_next=None): + def forward(self, x, pad_mask=None, cache=None): x = x.transpose(1, 2) x = self.pointwise_conv1(x) @@ -368,10 +348,9 @@ def forward(self, x, 
pad_mask=None, cache=None, cache_next=None): if pad_mask is not None: x = x.float().masked_fill(pad_mask.unsqueeze(1), 0.0) + x = self.depthwise_conv(x, cache=cache) if cache is not None: - x = self.depthwise_conv(x, cache=cache, cache_next=cache_next) - else: - x = self.depthwise_conv(x) + x, cache = x if self.norm_type == "layer_norm": x = x.transpose(1, 2) @@ -383,7 +362,10 @@ def forward(self, x, pad_mask=None, cache=None, cache_next=None): x = self.activation(x) x = self.pointwise_conv2(x) x = x.transpose(1, 2) - return x + if cache is None: + return x + else: + return x, cache def reset_parameters_conv(self): pw1_max = pw2_max = self.d_model ** -0.5 diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py index b7356ffe87e4..a0253524419e 100644 --- a/nemo/collections/asr/parts/submodules/multi_head_attention.py +++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py @@ -73,7 +73,6 @@ def __init__(self, n_head, n_feat, dropout_rate, max_cache_len=0): self.dropout = nn.Dropout(p=dropout_rate) self._max_cache_len = max_cache_len - self._cache_id = None def forward_qkv(self, query, key, value): """Transforms query, key and value. @@ -119,20 +118,20 @@ def forward_attention(self, value, scores, mask): return self.linear_out(x) # (batch, time1, d_model) - def forward(self, query, key, value, mask, pos_emb=None, cache=None, cache_next=None): + def forward(self, query, key, value, mask, pos_emb=None, cache=None): """Compute 'Scaled Dot Product Attention'. Args: query (torch.Tensor): (batch, time1, size) key (torch.Tensor): (batch, time2, size) value(torch.Tensor): (batch, time2, size) mask (torch.Tensor): (batch, time1, time2) - cache (torch.Tensor) : (cache_nums, batch, time_cache, size) - cache_next (torch.Tensor) : (cache_nums, batch, time_cache_next, size) + cache (torch.Tensor) : (batch, time_cache, size) returns: output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention + cache (torch.Tensor) : (batch, time_cache_next, size) """ - key, value, query = self.update_cache(key=key, value=value, query=query, cache=cache, cache_next=cache_next) + key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache) if torch.is_autocast_enabled(): query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32) @@ -142,17 +141,17 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None, cache_next= q, k, v = self.forward_qkv(query, key, value) scores = torch.matmul(q, k.transpose(-2, -1)) / self.s_d_k out = self.forward_attention(v, scores, mask) + if cache is None: + return out + else: + return out, cache - return out - - def update_cache(self, key, value, query, cache, cache_next): + def update_cache(self, key, value, query, cache): if cache is not None: - key = value = torch.cat([cache[self._cache_id], key], dim=1) + key = value = torch.cat([cache, key], dim=1) q_keep_size = query.shape[1] - self.cache_drop_size - if cache_next is not None: - cache_next[self._cache_id, :, :-q_keep_size, :] = cache[self._cache_id, :, q_keep_size:, :] - cache_next[self._cache_id, :, -q_keep_size:, :] = query[:, :q_keep_size, :] - return key, value, query + cache = torch.cat([cache[:, q_keep_size:, :], query[:, :q_keep_size, :]], dim=1) + return key, value, query, cache class RelPositionMultiHeadAttention(MultiHeadAttention): @@ -195,7 +194,7 @@ def rel_shift(self, x): x = x[:, :, 1:].view(b, h, qlen, 
pos_len) # (b, h, t1, t2) return x - def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None): + def forward(self, query, key, value, mask, pos_emb, cache=None): """Compute 'Scaled Dot Product Attention' with rel. positional encoding. Args: query (torch.Tensor): (batch, time1, size) @@ -203,12 +202,13 @@ def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None) value(torch.Tensor): (batch, time2, size) mask (torch.Tensor): (batch, time1, time2) pos_emb (torch.Tensor) : (batch, time1, size) - cache (torch.Tensor) : (cache_nums, batch, time_cache, size) - cache_next (torch.Tensor) : (cache_nums, batch, time_cache_next, size) + cache (torch.Tensor) : (batch, time_cache, size) + Returns: output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention + cache (torch.Tensor) : (batch, time_cache_next, size) """ - key, value, query = self.update_cache(key=key, value=value, query=query, cache=cache, cache_next=cache_next) + key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache) if torch.is_autocast_enabled(): query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32) @@ -244,7 +244,10 @@ def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None) out = self.forward_attention(v, scores, mask) - return out + if cache is None: + return out + else: + return out, cache class RelPositionMultiHeadAttentionLongformer(RelPositionMultiHeadAttention): @@ -298,7 +301,7 @@ def __init__( self.global_k = nn.Linear(n_feat, n_feat) self.global_v = nn.Linear(n_feat, n_feat) - def forward(self, query, key, value, pad_mask, pos_emb, cache=None, cache_next=None): + def forward(self, query, key, value, pad_mask, pos_emb, cache=None): """Compute Scaled Dot Product Local Attention with rel. positional encoding. using overlapping chunks Args: query (torch.Tensor): (batch, time, size) @@ -306,13 +309,13 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None, cache_next=N value(torch.Tensor): (batch, time, size) pad_mask (torch.Tensor): (batch, time) pos_emb (torch.Tensor) : (batch, 2w + 1, size) - cache (torch.Tensor) : (cache_nums, batch, time_cache, size) - cache_next (torch.Tensor) : (cache_nums, batch, time_cache_next, size) + cache (torch.Tensor) : (batch, time_cache, size) Returns: output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention + cache (torch.Tensor) : (batch, time_cache_next, size) """ - key, value, query = self.update_cache(key=key, value=value, query=query, cache=cache, cache_next=cache_next) + key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache) if torch.is_autocast_enabled(): query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32) @@ -453,7 +456,11 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None, cache_next=N out[is_index_global_attn_nonzero] += out_global_to_all - return self.linear_out(out.reshape(n_batch, -1, self.h * self.d_k)[:, :T]) + ret = self.linear_out(out.reshape(n_batch, -1, self.h * self.d_k)[:, :T]) + if cache is None: + return ret + else: + return ret, cache def _get_global_attn_indices(self, is_index_global_attn: torch.Tensor) -> Tuple: """