diff --git a/eole/utils/loss.py b/eole/utils/loss.py index 7683bd39..7b2b79d2 100644 --- a/eole/utils/loss.py +++ b/eole/utils/loss.py @@ -251,9 +251,8 @@ def ignore_prompt(self, batch): batch: The current batch. """ # Create a mask with zeros at prompt positions and ones at answer postions. - mask = batch["src"].squeeze(dim=2) == self.padding_idx + mask = batch["src"].squeeze(dim=-1) == self.padding_idx mask = torch.cumsum(mask.int(), 1) - mask = mask.unsqueeze(-1) # Apply the mask on the target side. batch["tgt"] *= mask.int() # Put the padding token index at the prompt positions.