diff --git a/eole/trainer.py b/eole/trainer.py index 81236b83..16f5cf16 100644 --- a/eole/trainer.py +++ b/eole/trainer.py @@ -437,11 +437,12 @@ def validate(self, valid_iter, moving_average=None): if len(self.valid_scorers) > 0: computed_metrics = {} start = time.time() - preds, texts_ref = self.scoring_preparator.translate( - model=self.model, - gpu_rank=self.gpu_rank, - step=self.optim.training_step, - ) + with get_autocast(enabled=self.optim.amp): + preds, texts_ref = self.scoring_preparator.translate( + model=self.model, + gpu_rank=self.gpu_rank, + step=self.optim.training_step, + ) logger.info( """The translation of the valid dataset for dynamic scoring took : {} s.""".format(