From 4320b1c5bbd1944fc3ffb7e7f8a0920cb1099431 Mon Sep 17 00:00:00 2001 From: Niels Date: Thu, 28 Dec 2023 16:26:50 +0100 Subject: [PATCH] Fix test --- tests/test_tokenization_common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 9b60b2f186738e..aa32cb34437afc 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -2624,7 +2624,8 @@ def test_prepare_seq2seq_batch(self): src_texts=src_text, max_length=3, max_target_length=10, return_tensors="pt" ) self.assertEqual(batch_encoder_only.input_ids.shape[1], 3) - self.assertEqual(batch_encoder_only.attention_mask.shape[1], 3) + if "attention_mask" in tokenizer.model_input_names: + self.assertEqual(batch_encoder_only.attention_mask.shape[1], 3) self.assertNotIn("decoder_input_ids", batch_encoder_only) def test_is_fast(self):