diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 08480ac983e805..88535b44e9c479 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -397,6 +397,8 @@ def generate( "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" ) generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + else: # by default let's always generate 10 new tokens + generation_config.max_length = generation_config.max_length + input_ids_seq_length if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: raise ValueError( diff --git a/tests/generation/test_flax_utils.py b/tests/generation/test_flax_utils.py index 647482b88cd83f..bb0c1828763bb6 100644 --- a/tests/generation/test_flax_utils.py +++ b/tests/generation/test_flax_utils.py @@ -101,6 +101,10 @@ def test_greedy_generate_pt_fx(self): pt_model = pt_model_class(config).eval() pt_model = load_flax_weights_in_pytorch_model(pt_model, flax_model.params) + # Generate max 5 tokens only otherwise seems to be numerical error accumulation + pt_model.generation_config.max_length = 5 + flax_model.generation_config.max_length = 5 + flax_generation_outputs = flax_model.generate(input_ids).sequences pt_generation_outputs = pt_model.generate(torch.tensor(input_ids, dtype=torch.long)) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 51d51dfcc2825c..d88b0dc5f02f83 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3002,7 +3002,7 @@ def test_inputs_embeds_matches_input_ids(self): def test_inputs_embeds_matches_input_ids_with_generate(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: + for model_class in self.all_generative_model_classes: if model_class.__name__ not in [ *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES), *get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),