From 27bc3054a3dd866d892e938c76ccf1ab6ddd9a16 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Wed, 26 Jun 2024 18:02:35 +0800 Subject: [PATCH] FIX sft script: only print trainable params if peft (#1888) --- examples/sft/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/sft/train.py b/examples/sft/train.py index 88263e6707..baaa4bfc0c 100644 --- a/examples/sft/train.py +++ b/examples/sft/train.py @@ -137,7 +137,8 @@ def main(model_args, data_args, training_args): max_seq_length=data_args.max_seq_length, ) trainer.accelerator.print(f"{trainer.model}") - trainer.model.print_trainable_parameters() + if hasattr(trainer.model, "print_trainable_parameters()"): + trainer.model.print_trainable_parameters() # train checkpoint = None