diff --git a/examples/sft/train.py b/examples/sft/train.py index 88263e6707..baaa4bfc0c 100644 --- a/examples/sft/train.py +++ b/examples/sft/train.py @@ -137,7 +137,8 @@ def main(model_args, data_args, training_args): max_seq_length=data_args.max_seq_length, ) trainer.accelerator.print(f"{trainer.model}") - trainer.model.print_trainable_parameters() + if hasattr(trainer.model, "print_trainable_parameters()"): + trainer.model.print_trainable_parameters() # train checkpoint = None