Skip to content

Commit

Permalink
Merge branch 'main' into jiemingz/first_val_step
Browse files Browse the repository at this point in the history
  • Loading branch information
ericharper authored Jan 25, 2024
2 parents a5f5d44 + f10d694 commit 2b47365
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 9 deletions.
7 changes: 7 additions & 0 deletions nemo/collections/asr/models/rnnt_bpe_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,13 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig):
decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get("loss", None))

if loss_name == 'tdt':
decoding_cfg.durations = loss_kwargs.durations
elif loss_name == 'multiblank_rnnt':
decoding_cfg.big_blank_durations = loss_kwargs.big_blank_durations

self.decoding = RNNTBPEDecoding(
decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
)
Expand Down
3 changes: 3 additions & 0 deletions nemo/collections/asr/models/rnnt_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

if loss_name == 'tdt':
num_classes = num_classes - self.joint.num_extra_outputs
self.cfg.decoding.durations = loss_kwargs.durations
elif loss_name == 'multiblank_rnnt':
self.cfg.decoding.big_blank_durations = loss_kwargs.big_blank_durations

self.loss = RNNTLoss(
num_classes=num_classes,
Expand Down
26 changes: 17 additions & 9 deletions nemo/collections/asr/parts/submodules/rnnt_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,15 +208,17 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
self.compute_timestamps = self.cfg.get('compute_timestamps', None)
self.word_seperator = self.cfg.get('word_seperator', ' ')

if self.durations is not None: # this means it's a TDT model.
if self.durations is not None and self.durations != []: # this means it's a TDT model.
if blank_id == 0:
raise ValueError("blank_id must equal len(non_blank_vocabs) for TDT models")
if self.big_blank_durations is not None:
if self.big_blank_durations is not None and self.big_blank_durations != []:
raise ValueError("duration and big_blank_durations can't both be not None")
if self.cfg.strategy not in ['greedy', 'greedy_batch']:
raise ValueError("currently only greedy and greedy_batch inference is supported for TDT models")

if self.big_blank_durations is not None: # this means it's a multi-blank model.
if (
self.big_blank_durations is not None and self.big_blank_durations != []
): # this means it's a multi-blank model.
if blank_id == 0:
raise ValueError("blank_id must equal len(vocabs) for multi-blank RNN-T models")
if self.cfg.strategy not in ['greedy', 'greedy_batch']:
Expand Down Expand Up @@ -260,8 +262,8 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
raise NotImplementedError(f"Confidence calculation is not supported for strategy `{self.cfg.strategy}`")

if self.cfg.strategy == 'greedy':
if self.big_blank_durations is None:
if self.durations is None:
if self.big_blank_durations is None or self.big_blank_durations == []:
if self.durations is None or self.durations == []:
self.decoding = rnnt_greedy_decoding.GreedyRNNTInfer(
decoder_model=decoder,
joint_model=joint,
Expand Down Expand Up @@ -303,8 +305,8 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
)

elif self.cfg.strategy == 'greedy_batch':
if self.big_blank_durations is None:
if self.durations is None:
if self.big_blank_durations is None or self.big_blank_durations == []:
if self.durations is None or self.durations == []:
self.decoding = rnnt_greedy_decoding.GreedyBatchedRNNTInfer(
decoder_model=decoder,
joint_model=joint,
Expand Down Expand Up @@ -522,10 +524,10 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp

# RNN-T sample level is already preprocessed by implicit RNNT decoding
# Simply remove any blank and possibly big blank tokens
if self.big_blank_durations is not None: # multi-blank RNNT
if self.big_blank_durations is not None and self.big_blank_durations != []: # multi-blank RNNT
num_extra_outputs = len(self.big_blank_durations)
prediction = [p for p in prediction if p < self.blank_id - num_extra_outputs]
elif self.durations is not None: # TDT model.
elif self.durations is not None and self.durations != []: # TDT model.
prediction = [p for p in prediction if p < self.blank_id]
else: # standard RNN-T
prediction = [p for p in prediction if p != self.blank_id]
Expand Down Expand Up @@ -1508,6 +1510,12 @@ class RNNTDecodingConfig:
# can be used to change temperature for decoding
temperature: float = 1.0

# config for TDT decoding.
durations: Optional[List[int]] = field(default_factory=list)

# config for multiblank decoding.
big_blank_durations: Optional[List[int]] = field(default_factory=list)


@dataclass
class RNNTBPEDecodingConfig(RNNTDecodingConfig):
Expand Down

0 comments on commit 2b47365

Please sign in to comment.