diff --git a/llm_rl/src/llmtuner/hparams/finetuning_args.py b/llm_rl/src/llmtuner/hparams/finetuning_args.py
index 48d94177..982a15a8 100644
--- a/llm_rl/src/llmtuner/hparams/finetuning_args.py
+++ b/llm_rl/src/llmtuner/hparams/finetuning_args.py
@@ -104,6 +104,10 @@ class FinetuningArguments:
         default=None,
         metadata={"help": "The path to the checkpoint saved queue file."}
     )
+    improve_step: Optional[int] = field(
+        default=0,
+        metadata={"help": "The n-th improve step in ReST."}
+    )
 
     def __post_init__(self):
         if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
diff --git a/llm_rl/src/llmtuner/tuner/sft/custom_callback.py b/llm_rl/src/llmtuner/tuner/sft/custom_callback.py
index eff18c1e..a852f570 100644
--- a/llm_rl/src/llmtuner/tuner/sft/custom_callback.py
+++ b/llm_rl/src/llmtuner/tuner/sft/custom_callback.py
@@ -3,10 +3,11 @@ from llmtuner.tuner.core.utils import is_first_node
 
 
 class SaveModelCallback(TrainerCallback):
-    def __init__(self, save_epochs, output_dir, checkpoint_saved_queue):
+    def __init__(self, save_epochs, output_dir, checkpoint_saved_queue, improve_step):
         self.save_epochs = save_epochs
         self.output_dir = output_dir
         self.checkpoint_saved_queue = checkpoint_saved_queue
+        self.curr_improve_step = improve_step
         self.curr_epoch = 0
 
     def on_epoch_end(self, args, state, control, model=None, **kwargs):
@@ -22,7 +23,7 @@ def on_save(self, args, state, control, model, **kwargs):
             return
 
         # Customized checkpoint name
-        custom_checkpoint_name = f"checkpoint_epoch_{int(self.curr_epoch)}"
+        custom_checkpoint_name = f"checkpoint_improve-{self.curr_improve_step}_epoch-{int(self.curr_epoch)}"
 
         # Original auto-saved checkpoint directory
         auto_checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
diff --git a/llm_rl/src/llmtuner/tuner/sft/workflow.py b/llm_rl/src/llmtuner/tuner/sft/workflow.py
index 546c18bb..17ed740e 100644
--- a/llm_rl/src/llmtuner/tuner/sft/workflow.py
+++ b/llm_rl/src/llmtuner/tuner/sft/workflow.py
@@ -56,7 +56,7 @@ def run_sft(
         training_args.report_to = ["wandb"]
 
     if model_args.use_custom_callback:
-        callbacks.append(SaveModelCallback(model_args.call_back_save_epochs, training_args.output_dir, finetuning_args.checkpoint_saved_queue))
+        callbacks.append(SaveModelCallback(model_args.call_back_save_epochs, training_args.output_dir, finetuning_args.checkpoint_saved_queue, finetuning_args.improve_step))
 
     # Initialize our Trainer
     trainer = CustomSeq2SeqTrainer(
diff --git a/llm_self_train/pipelines/run_train.py b/llm_self_train/pipelines/run_train.py
index 0deaed15..1b6b301b 100644
--- a/llm_self_train/pipelines/run_train.py
+++ b/llm_self_train/pipelines/run_train.py
@@ -6,7 +6,7 @@ with open('./resources/train_args.yml', 'r') as file:
     train_args = yaml.safe_load(file)
 
 
-def run_sft(output_dir):
+def run_sft(output_dir, improve_step, ):
     args = ["deepspeed", f"--num_gpus={config['num_gpus']}", "../llm_rl/src/train_bash.py"]
     for key, value in train_args.items():
         if key in config:
@@ -20,5 +20,6 @@ def run_sft(output_dir):
         else:
             args.append(f"--{key}")
             args.append(str(value))
+    args.append(f"--improve_step={improve_step}")
 
     subprocess.run(args)
\ No newline at end of file
diff --git a/llm_self_train/train.py b/llm_self_train/train.py
index 9ab55b80..d2f3e55e 100644
--- a/llm_self_train/train.py
+++ b/llm_self_train/train.py
@@ -19,10 +19,10 @@ def main():
     if not os.path.exists("../llm_rl/data/sotopia_custom_training_sft.json"):
         preprocess_episodes_with_tag()
 
-    for i in range(config["num_improve_steps"]):
+    for improve_step in range(config["num_improve_steps"]):
         run_sft_completed = multiprocessing.Value('b', False)
-        curr_improve_dir = os.path.join(config['data_dir'], config["experiment_name"])
-        sft_process = multiprocessing.Process(target=run_sft, args=(curr_improve_step, ))
+        output_dir = os.path.join(config['data_dir'], config["experiment_name"])
+        sft_process = multiprocessing.Process(target=run_sft, args=(output_dir, improve_step, ))
         sft_process.start()
         sft_process.join()
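Note (not part of the diff): a minimal sketch of the checkpoint-naming scheme introduced above, assuming the callback fields shown in custom_callback.py. The helper name below is hypothetical and for illustration only; the actual code builds the name inline inside SaveModelCallback.on_save.

    # Hypothetical helper mirroring the f-string added in custom_callback.py:
    # each ReST improve step now gets its own checkpoint prefix, so checkpoints
    # from different improve steps no longer collide on the epoch number alone.
    def checkpoint_name(improve_step: int, curr_epoch: float) -> str:
        return f"checkpoint_improve-{improve_step}_epoch-{int(curr_epoch)}"

    # e.g. improve step 1, end of epoch 3 -> "checkpoint_improve-1_epoch-3"
    assert checkpoint_name(1, 3.0) == "checkpoint_improve-1_epoch-3"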