Merge branch 'r1.20.0' into spellmapper_fix_bugs
bene-ges committed Jul 8, 2023
2 parents 03e3481 + f08cb21 commit 35daa87
Showing 12 changed files with 220 additions and 167 deletions.
10 changes: 10 additions & 0 deletions docs/source/asr/models.rst
@@ -215,6 +215,11 @@ It is recommended to train a model in streaming mode with limited context for t

You may find FastConformer variants of cache-aware streaming models under ``<NeMo_git_root>/examples/asr/conf/fastconformer/``.

Note that cache-aware streaming models are exported without caching support by default.
To include caching support, call `model.set_export_config({'cache_support' : 'True'})` before export.
Or, if ``<NeMo_git_root>/scripts/export.py`` is used:
`python export.py cache_aware_conformer.nemo cache_aware_conformer.onnx --config cache_support=True`
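As a minimal end-to-end sketch (the checkpoint and output file names are placeholders; ``ASRModel.restore_from`` and ``export`` are the standard NeMo entry points):

.. code-block:: Python

    import nemo.collections.asr as nemo_asr

    # restore a cache-aware streaming checkpoint (placeholder path)
    model = nemo_asr.models.ASRModel.restore_from("cache_aware_conformer.nemo")
    model.eval()

    # opt in to caching support, then export to ONNX
    model.set_export_config({'cache_support': 'True'})
    model.export("cache_aware_conformer.onnx")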

.. _LSTM-Transducer_model:

LSTM-Transducer
@@ -291,6 +296,11 @@ Similar example configs for FastConformer variants of Hybrid models can be found
``<NeMo_git_root>/examples/asr/conf/fastconformer/hybrid_transducer_ctc/``
``<NeMo_git_root>/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/``

Note that Hybrid models are exported as RNNT (encoder and decoder+joint parts) by default.
To export as CTC (single encoder+decoder graph), call `model.set_export_config({'decoder_type' : 'ctc'})` before export.
Or, if ``<NeMo_git_root>/scripts/export.py`` is used:
`python export.py hybrid_transducer.nemo hybrid_transducer.onnx --config decoder_type=ctc`
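The CTC export path looks similar (a sketch; the checkpoint name is a placeholder, and ``EncDecHybridRNNTCTCModel`` is the Hybrid model class this section refers to):

.. code-block:: Python

    from nemo.collections.asr.models import EncDecHybridRNNTCTCModel

    model = EncDecHybridRNNTCTCModel.restore_from("hybrid_transducer.nemo")
    model.eval()

    # switch from the default two-part RNNT export to a single encoder+decoder CTC graph
    model.set_export_config({'decoder_type': 'ctc'})
    model.export("hybrid_transducer.onnx")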

.. _Conformer-HAT_model:

Conformer-HAT (Hybrid Autoregressive Transducer)
31 changes: 31 additions & 0 deletions docs/source/core/export.rst
@@ -177,6 +177,37 @@ Another common requirement for models that are being exported is to run certain
# call base method for common set of modifications
Exportable._prepare_for_export(self, **kwargs)
Some models that require control flow need to be exported in multiple parts; typical examples are RNNT networks.
To facilitate that, the hooks below are provided. For example, to export the 'encoder' and 'decoder' subnets of a model, overload ``list_export_subnets`` to return ``['encoder', 'decoder']``.

.. code-block:: Python

    def get_export_subnet(self, subnet=None):
        """
        Returns Exportable subnet model/module to export
        """

    def list_export_subnets(self):
        """
        Returns default set of subnet names exported for this model
        First goes the one receiving input (input_example)
        """
Some networks may be exported differently according to user-settable options (such as ragged batch support for TTS or cache support for ASR). To facilitate this, ``Exportable`` provides the `set_export_config()` method to set key/value pairs in the predefined `model.export_config` dictionary, which is used during export:

.. code-block:: Python

    def set_export_config(self, args):
        """
        Sets/updates export_config dictionary
        """
Also, if an action hook on setting the config is desired, `Exportable` descendants may overload this method to include one.
An example can be found in ``<NeMo_git_root>/nemo/collections/asr/models/rnnt_models.py``.
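As a sketch of such a hook (condensed from that RNNT example, which also appears in the diff below):

.. code-block:: Python

    def set_export_config(self, args):
        # react to a recognized option before delegating to the base implementation
        if 'decoder_type' in args:
            self.change_decoding_strategy(decoder_type=args['decoder_type'])
        super().set_export_config(args)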

Here is an example of how the `set_export_config()` call is tied to command-line arguments in ``<NeMo_git_root>/scripts/export.py``:

.. code-block:: bash

    python scripts/export.py hybrid_conformer.nemo hybrid_conformer.onnx --config decoder_type=ctc
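Inside such a script, the ``--config`` pairs can be turned into the dictionary that `set_export_config()` expects roughly like this (a sketch, not the actual ``export.py`` code; the argparse wiring is an assumption):

.. code-block:: Python

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', nargs='+', default=[])  # e.g. --config decoder_type=ctc
    cli = parser.parse_args()

    # split each key=value pair and hand the result to the model before export
    export_config = dict(pair.split('=', 1) for pair in cli.config)
    model.set_export_config(export_config)  # 'model' restored earlier via restore_from(...)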
Exportable Model Code
~~~~~~~~~~~~~~~~~~~~~
72 changes: 28 additions & 44 deletions nemo/collections/asr/models/asr_model.py
@@ -161,7 +161,7 @@ def output_module(self):
@property
def output_names(self):
otypes = self.output_module.output_types
-        if hasattr(self.input_module, 'export_cache_support') and self.input_module.export_cache_support:
+        if getattr(self.input_module, 'export_cache_support', False):
in_types = self.input_module.output_types
otypes = {n: t for (n, t) in list(otypes.items())[:1]}
for (n, t) in list(in_types.items())[1:]:
@@ -174,7 +174,6 @@ def forward_for_export(
"""
This forward is used when we need to export the model to ONNX format.
Inputs cache_last_channel and cache_last_time are needed to be passed for exporting streaming models.
-        When they are passed, it just passes the inputs through the encoder part and currently the ONNX conversion does not fully work for this case.
Args:
input: Tensor that represents a batch of raw audio signals,
of shape [B, T]. T here represents timesteps.
@@ -187,49 +186,26 @@ def forward_for_export(
Returns:
the output of the model
"""
-        if hasattr(self.input_module, 'forward_for_export'):
-            if cache_last_channel is None and cache_last_time is None:
-                encoder_output = self.input_module.forward_for_export(audio_signal=input, length=length)
-            else:
-                encoder_output = self.input_module.forward_for_export(
-                    audio_signal=input,
-                    length=length,
-                    cache_last_channel=cache_last_channel,
-                    cache_last_time=cache_last_time,
-                    cache_last_channel_len=cache_last_channel_len,
-                )
+        enc_fun = getattr(self.input_module, 'forward_for_export', self.input_module.forward)
+        if cache_last_channel is None:
+            encoder_output = enc_fun(audio_signal=input, length=length)
+            if isinstance(encoder_output, tuple):
+                encoder_output = encoder_output[0]
         else:
-            if cache_last_channel is None and cache_last_time is None:
-                encoder_output = self.input_module(audio_signal=input, length=length)
-            else:
-                encoder_output = self.input_module(
-                    audio_signal=input,
-                    length=length,
-                    cache_last_channel=cache_last_channel,
-                    cache_last_time=cache_last_time,
-                    cache_last_channel_len=cache_last_channel_len,
-                )
-        if isinstance(encoder_output, tuple):
-            decoder_input = encoder_output[0]
-        else:
-            decoder_input = encoder_output
-        if hasattr(self.output_module, 'forward_for_export'):
-            if cache_last_channel is None and cache_last_time is None:
-                ret = self.output_module.forward_for_export(encoder_output=decoder_input)
-            else:
-                ret = self.output_module.forward_for_export(encoder_output=decoder_input)
-        else:
-            if cache_last_channel is None and cache_last_time is None:
-                ret = self.output_module(encoder_output=decoder_input)
-            else:
-                ret = self.output_module(encoder_output=decoder_input)
-        if cache_last_channel is None and cache_last_time is None:
-            pass
-        else:
-            if isinstance(ret, tuple):
-                ret = (ret[0], encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])
-            else:
-                ret = (ret, encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])
+            encoder_output, length, cache_last_channel, cache_last_time, cache_last_channel_len = enc_fun(
+                audio_signal=input,
+                length=length,
+                cache_last_channel=cache_last_channel,
+                cache_last_time=cache_last_time,
+                cache_last_channel_len=cache_last_channel_len,
+            )
+
+        dec_fun = getattr(self.output_module, 'forward_for_export', self.output_module.forward)
+        ret = dec_fun(encoder_output=encoder_output)
+        if isinstance(ret, tuple):
+            ret = ret[0]
+        if cache_last_channel is not None:
+            ret = (ret, length, cache_last_channel, cache_last_time, cache_last_channel_len)
         return cast_all(ret, from_dtype=torch.float16, to_dtype=torch.float32)

@property
@@ -239,3 +215,11 @@ def disabled_deployment_input_names(self):
@property
def disabled_deployment_output_names(self):
return self.encoder.disabled_deployment_output_names

+    def set_export_config(self, args):
+        if 'cache_support' in args:
+            enable = bool(args['cache_support'])
+            self.encoder.export_cache_support = enable
+            logging.info(f"Caching support enabled: {enable}")
+            self.encoder.setup_streaming_params()
+        super().set_export_config(args)
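A quick way to check the effect of this new hook after export (a sketch; the file name is a placeholder, and the exact tensor names depend on the encoder's input/output types):

    import onnx

    m = onnx.load("cache_aware_conformer.onnx")
    # with cache_support enabled, cache tensors should appear among inputs and outputs
    print([i.name for i in m.graph.input])
    print([o.name for o in m.graph.output])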
14 changes: 14 additions & 0 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -645,6 +645,20 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
self.finalize_interctc_metrics(metrics, outputs, prefix="test_")
return metrics

+    # EncDecRNNTModel is exported in 2 parts
+    def list_export_subnets(self):
+        if self.cur_decoder == 'rnnt':
+            return ['encoder', 'decoder_joint']
+        else:
+            return ['self']
+
+    @property
+    def output_module(self):
+        if self.cur_decoder == 'rnnt':
+            return self.decoder
+        else:
+            return self.ctc_decoder

@classmethod
def list_available_models(cls) -> Optional[PretrainedModelInfo]:
"""
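A sketch of how `cur_decoder` steers export for the Hybrid model (the checkpoint path is a placeholder; `change_decoding_strategy` is assumed to be the usual way to switch decoders, as in the RNNT hook below):

    model = EncDecHybridRNNTCTCModel.restore_from("hybrid.nemo")
    model.change_decoding_strategy(decoder_type='ctc')   # cur_decoder == 'ctc'
    print(model.list_export_subnets())   # ['self'] -> one encoder+decoder CTC graph
    model.change_decoding_strategy(decoder_type='rnnt')  # cur_decoder == 'rnnt'
    print(model.list_export_subnets())   # ['encoder', 'decoder_joint'] -> two-part RNNT export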
12 changes: 10 additions & 2 deletions nemo/collections/asr/models/rnnt_models.py
@@ -28,7 +28,7 @@
from nemo.collections.asr.data.audio_to_text_dali import AudioToCharDALIDataset, DALIOutputs
from nemo.collections.asr.losses.rnnt import RNNTLoss, resolve_rnnt_default_loss_name
from nemo.collections.asr.metrics.rnnt_wer import RNNTWER, RNNTDecoding, RNNTDecodingConfig
-from nemo.collections.asr.models.asr_model import ASRModel
+from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
from nemo.collections.asr.modules.rnnt import RNNTDecoderJoint
from nemo.collections.asr.parts.mixins import ASRModuleMixin
from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
@@ -39,7 +39,7 @@
from nemo.utils import logging


-class EncDecRNNTModel(ASRModel, ASRModuleMixin, Exportable):
+class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecModel):
"""Base class for encoder decoder RNNT-based models."""

def __init__(self, cfg: DictConfig, trainer: Trainer = None):
@@ -960,6 +960,14 @@ def list_export_subnets(self):
def decoder_joint(self):
return RNNTDecoderJoint(self.decoder, self.joint)

+    def set_export_config(self, args):
+        if 'decoder_type' in args:
+            if hasattr(self, 'change_decoding_strategy'):
+                self.change_decoding_strategy(decoder_type=args['decoder_type'])
+            else:
+                raise Exception("Model does not have decoder type option")
+        super().set_export_config(args)

@classmethod
def list_available_models(cls) -> List[PretrainedModelInfo]:
"""
48 changes: 21 additions & 27 deletions nemo/collections/asr/modules/conformer_encoder.py
@@ -505,11 +505,6 @@ def forward_internal(
(audio_signal.size(0),), audio_signal.size(-1), dtype=torch.int64, device=audio_signal.device
)

-        if cache_last_time is not None:
-            cache_last_time_next = torch.zeros_like(cache_last_time)
-        else:
-            cache_last_time_next = None

# select a random att_context_size with the distribution specified by att_context_probs during training
        # for non-training cases like test, validation or inference, it uses the first mode in self.att_context_size
if self.training and len(self.att_context_size_all) > 1:
@@ -536,7 +531,6 @@
if cache_last_channel is not None:
cache_len = self.streaming_cfg.last_channel_cache_size
cache_keep_size = max_audio_length - self.streaming_cfg.cache_drop_size
-            cache_last_channel_next = torch.zeros_like(cache_last_channel)
max_audio_length = max_audio_length + cache_len
padding_length = length + cache_len
offset = torch.neg(cache_last_channel_len) + cache_len
@@ -561,19 +555,32 @@
pad_mask = pad_mask[:, cache_len:]
if att_mask is not None:
att_mask = att_mask[:, cache_len:]
+        # Convert caches from the tensor to list
+        cache_last_time_next = []
+        cache_last_channel_next = []

for lth, (drop_prob, layer) in enumerate(zip(self.layer_drop_probs, self.layers)):
original_signal = audio_signal
+            if cache_last_channel is not None:
+                cache_last_channel_cur = cache_last_channel[lth]
+                cache_last_time_cur = cache_last_time[lth]
+            else:
+                cache_last_channel_cur = None
+                cache_last_time_cur = None
audio_signal = layer(
x=audio_signal,
att_mask=att_mask,
pos_emb=pos_emb,
pad_mask=pad_mask,
-                cache_last_channel=cache_last_channel,
-                cache_last_time=cache_last_time,
-                cache_last_channel_next=cache_last_channel_next,
-                cache_last_time_next=cache_last_time_next,
+                cache_last_channel=cache_last_channel_cur,
+                cache_last_time=cache_last_time_cur,
)

+            if cache_last_channel_cur is not None:
+                (audio_signal, cache_last_channel_cur, cache_last_time_cur) = audio_signal
+                cache_last_channel_next.append(cache_last_channel_cur)
+                cache_last_time_next.append(cache_last_time_cur)

# applying stochastic depth logic from https://arxiv.org/abs/2102.03216
if self.training and drop_prob > 0.0:
should_drop = torch.rand(1) < drop_prob
@@ -626,6 +633,8 @@ def forward_internal(
length = length.to(dtype=torch.int64)

if cache_last_channel is not None:
+            cache_last_channel_next = torch.stack(cache_last_channel_next, dim=0)
+            cache_last_time_next = torch.stack(cache_last_time_next, dim=0)
return (
audio_signal,
length,
@@ -860,20 +869,12 @@ def setup_streaming_params(
else:
streaming_cfg.drop_extra_pre_encoded = streaming_cfg.pre_encode_cache_size // self.subsampling_factor

-        # counting the number of the layers need caching
-        streaming_cfg.last_channel_num = 0
-        streaming_cfg.last_time_num = 0
        for m in self.layers.modules():
            if hasattr(m, "_max_cache_len"):
                if isinstance(m, MultiHeadAttention):
-                    m._cache_id = streaming_cfg.last_channel_num
                    m.cache_drop_size = streaming_cfg.cache_drop_size
-                    streaming_cfg.last_channel_num += 1

                if isinstance(m, CausalConv1D):
-                    m._cache_id = streaming_cfg.last_time_num
                    m.cache_drop_size = streaming_cfg.cache_drop_size
-                    streaming_cfg.last_time_num += 1

self.streaming_cfg = streaming_cfg

@@ -886,19 +887,12 @@ def get_initial_cache_state(self, batch_size=1, dtype=torch.float32, device=None
create_tensor = torch.zeros
last_time_cache_size = self.conv_context_size[0]
cache_last_channel = create_tensor(
-            (
-                self.streaming_cfg.last_channel_num,
-                batch_size,
-                self.streaming_cfg.last_channel_cache_size,
-                self.d_model,
-            ),
+            (len(self.layers), batch_size, self.streaming_cfg.last_channel_cache_size, self.d_model,),
device=device,
dtype=dtype,
)
cache_last_time = create_tensor(
-            (self.streaming_cfg.last_time_num, batch_size, self.d_model, last_time_cache_size),
-            device=device,
-            dtype=dtype,
+            (len(self.layers), batch_size, self.d_model, last_time_cache_size), device=device, dtype=dtype,
)
if max_dim > 0:
cache_last_channel_len = torch.randint(
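With this change, both caches are allocated per layer, so their leading dimension is simply the number of Conformer layers rather than a separately counted `last_channel_num` / `last_time_num`. A sketch of the resulting shapes (`model` is assumed to be a cache-aware Conformer-based NeMo model):

    # cache_last_channel: [len(self.layers), B, last_channel_cache_size, d_model]
    # cache_last_time:    [len(self.layers), B, d_model, conv_context_size[0]]
    cache_ch, cache_t, cache_ch_len = model.encoder.get_initial_cache_state(batch_size=4)
    assert cache_ch.shape[0] == len(model.encoder.layers)  # one cache slice per layer
    assert cache_t.shape[0] == len(model.encoder.layers)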
@@ -147,26 +147,26 @@ def __init__(
# reset parameters for Q to be identity operation
self.reset_parameters()

-    def forward(self, query, key, value, mask, pos_emb=None, cache=None, cache_next=None):
+    def forward(self, query, key, value, mask, pos_emb=None, cache=None):
"""Compute 'Scaled Dot Product Attention'.
Args:
query (torch.Tensor): (batch, time1, size)
key (torch.Tensor): (batch, time2, size)
value(torch.Tensor): (batch, time2, size)
mask (torch.Tensor): (batch, time1, time2)
-            cache (torch.Tensor) : (cache_nums, batch, time_cache, size)
-            cache_next (torch.Tensor) : (cache_nums, batch, time_cache_next, size)
+            cache (torch.Tensor) : (batch, time_cache, size)
returns:
output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
+            cache (torch.Tensor) : (batch, time_cache_next, size)
"""
# Need to perform duplicate computations as at this point the tensors have been
# separated by the adapter forward
query = self.pre_norm(query)
key = self.pre_norm(key)
value = self.pre_norm(value)

-        return super().forward(query, key, value, mask, pos_emb, cache=cache, cache_next=cache_next)
+        return super().forward(query, key, value, mask, pos_emb, cache=cache)

def reset_parameters(self):
with torch.no_grad():
@@ -242,26 +242,26 @@ def __init__(
# reset parameters for Q to be identity operation
self.reset_parameters()

-    def forward(self, query, key, value, mask, pos_emb, cache=None, cache_next=None):
+    def forward(self, query, key, value, mask, pos_emb, cache=None):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): (batch, time1, size)
key (torch.Tensor): (batch, time2, size)
value(torch.Tensor): (batch, time2, size)
mask (torch.Tensor): (batch, time1, time2)
pos_emb (torch.Tensor) : (batch, time1, size)
-            cache (torch.Tensor) : (cache_nums, batch, time_cache, size)
-            cache_next (torch.Tensor) : (cache_nums, batch, time_cache_next, size)
+            cache (torch.Tensor) : (batch, time_cache, size)
Returns:
output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
+            cache_next (torch.Tensor) : (batch, time_cache_next, size)
"""
# Need to perform duplicate computations as at this point the tensors have been
# separated by the adapter forward
query = self.pre_norm(query)
key = self.pre_norm(key)
value = self.pre_norm(value)

-        return super().forward(query, key, value, mask, pos_emb, cache=cache, cache_next=cache_next)
+        return super().forward(query, key, value, mask, pos_emb, cache=cache)

def reset_parameters(self):
with torch.no_grad():
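After this change, each attention module consumes and returns its own per-layer cache slice instead of writing into a shared `cache_next` tensor indexed by `_cache_id`. A usage sketch (shapes follow the docstrings above; the module path, constructor arguments, and tensor sizes are assumptions, and the module is assumed to return `(output, updated_cache)` when a cache is passed, per the updated docstring):

    import torch
    from nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module import (
        MultiHeadAttentionAdapter,
    )

    # illustrative shapes: batch=2, time=8, time_cache=16, n_feat=256
    attn = MultiHeadAttentionAdapter(n_head=4, n_feat=256, dropout_rate=0.0)
    q = k = v = torch.randn(2, 8, 256)
    cache = torch.zeros(2, 16, 256)

    out, new_cache = attn(query=q, key=k, value=v, mask=None, pos_emb=None, cache=cache)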