From 3c7981d58a9a8164f627bb0589e3ed0f56f727d3 Mon Sep 17 00:00:00 2001
From: Alexandra Antonova
Date: Tue, 21 Nov 2023 13:03:51 +0300
Subject: [PATCH] migrate to PTL 2.0

Signed-off-by: Alexandra Antonova
---
 .../spellchecking_asr_customization_train.py |  4 ++++
 .../spellchecking_model.py                   | 13 +++++++------
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_train.py b/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_train.py
index 7ea9314d196d..ac50b4121f15 100644
--- a/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_train.py
+++ b/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_train.py
@@ -47,6 +47,10 @@
 
 @hydra_runner(config_path="conf", config_name="spellchecking_asr_customization_config")
 def main(cfg: DictConfig) -> None:
+    # PTL 2.0 sets find_unused_parameters to False by default, so it must be set to True
+    # when there are unused parameters, as is the case here
+    if cfg.trainer.strategy == 'ddp':
+        cfg.trainer.strategy = "ddp_find_unused_parameters_true"
     logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')
 
     # Train the model
diff --git a/nemo/collections/nlp/models/spellchecking_asr_customization/spellchecking_model.py b/nemo/collections/nlp/models/spellchecking_asr_customization/spellchecking_model.py
index 15ffb2dd1bcd..1df35cb6c2f6 100644
--- a/nemo/collections/nlp/models/spellchecking_asr_customization/spellchecking_model.py
+++ b/nemo/collections/nlp/models/spellchecking_asr_customization/spellchecking_model.py
@@ -272,14 +272,15 @@ def validation_step(self, batch, batch_idx):
         )
 
         val_loss = self.loss_fn(logits=logits, labels=labels, loss_mask=labels_mask)
+        self.validation_step_outputs.append({'val_loss': val_loss})
         return {'val_loss': val_loss}
 
-    def validation_epoch_end(self, outputs):
+    def on_validation_epoch_end(self):
         """
         Called at the end of validation to aggregate outputs.
-        :param outputs: list of individual outputs of each validation step.
+        Aggregates self.validation_step_outputs collected by validation_step.
         """
-        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
+        avg_loss = torch.stack([x['val_loss'] for x in self.validation_step_outputs]).mean()
 
         # Calculate metrics and classification report
         # Note that in our task recall = accuracy, and the recall column is the per class accuracy
@@ -300,12 +301,12 @@ def test_step(self, batch, batch_idx):
         """
         return self.validation_step(batch, batch_idx)
 
-    def test_epoch_end(self, outputs):
+    def on_test_epoch_end(self):
         """
         Called at the end of test to aggregate outputs.
-        :param outputs: list of individual outputs of each test step.
+        Delegates aggregation to on_validation_epoch_end.
         """
-        return self.validation_epoch_end(outputs)
+        return self.on_validation_epoch_end()
 
 
 # Functions for inference
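
Note on the trainer strategy change (not part of the patch): PTL 2.0 registers
"ddp_find_unused_parameters_true" as a string alias for DDP with unused-parameter
detection enabled. A minimal sketch of the equivalent explicit form, assuming a
plain PyTorch Lightning Trainer rather than NeMo's trainer setup:

    # Sketch only: same effect as cfg.trainer.strategy = "ddp_find_unused_parameters_true".
    # DDPStrategy is pytorch_lightning's strategy class; this code is illustrative,
    # not taken from the patched files.
    from pytorch_lightning import Trainer
    from pytorch_lightning.strategies import DDPStrategy

    trainer = Trainer(
        devices=2,
        accelerator="gpu",
        strategy=DDPStrategy(find_unused_parameters=True),  # explicit form of the string alias
    )

The patch uses the string form because the strategy arrives as plain text from the
Hydra config and is rewritten in place before the Trainer is built.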
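
Note on the epoch-end hooks (not part of the patch): PTL 2.0 replaced
validation_epoch_end(outputs) with on_validation_epoch_end(), which receives no
outputs argument, so each validation_step must stash its result manually. The
patch appends to self.validation_step_outputs, which is assumed to be initialized
elsewhere (e.g. in the base model class). The hunk shown does not include a
clear() call; if the full method never clears the list, it grows across epochs.
A minimal self-contained sketch of the usual PTL 2.0 accumulate-and-clear
pattern (ToyModel and its layer are illustrative, not the NeMo model):

    import torch
    import pytorch_lightning as pl


    class ToyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)
            # Replaces the `outputs` argument removed from the epoch-end hook.
            self.validation_step_outputs = []

        def validation_step(self, batch, batch_idx):
            x, y = batch
            val_loss = torch.nn.functional.mse_loss(self.layer(x), y)
            self.validation_step_outputs.append({'val_loss': val_loss})
            return {'val_loss': val_loss}

        def on_validation_epoch_end(self):
            # Aggregate manually: the hook no longer receives `outputs`.
            avg_loss = torch.stack([x['val_loss'] for x in self.validation_step_outputs]).mean()
            self.log('val_loss', avg_loss)
            # Clear so the buffer does not keep growing across epochs.
            self.validation_step_outputs.clear()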