
accommodate deploy
Jasonqi146 committed Dec 27, 2023
1 parent 68b01d4 commit d1c11b0
Showing 5 changed files with 13 additions and 7 deletions.
4 changes: 4 additions & 0 deletions llm_rl/src/llmtuner/hparams/finetuning_args.py
@@ -104,6 +104,10 @@ class FinetuningArguments:
         default=None,
         metadata={"help": "The path to the checkpoint saved queue file."}
     )
+    improve_step: Optional[int] = field(
+        default=0,
+        metadata={"help": "The n-th improve step in ReST."}
+    )
 
     def __post_init__(self):
         if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
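The new improve_step field rides along with the other finetuning hyperparameters, so it can be set from the command line when train_bash.py is launched. Below is a minimal sketch of how the flag would be parsed, assuming (as elsewhere in llmtuner) that the dataclass goes through transformers.HfArgumentParser; the stripped-down dataclass and the value 2 are illustrative, not the real FinetuningArguments.

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


# Illustrative stand-in for llmtuner's FinetuningArguments; only the new field is shown.
@dataclass
class FinetuningArgumentsSketch:
    improve_step: Optional[int] = field(
        default=0,
        metadata={"help": "The n-th improve step in ReST."}
    )


parser = HfArgumentParser(FinetuningArgumentsSketch)
# The real flag is forwarded by the deepspeed launcher (see run_train.py below).
(finetuning_args,) = parser.parse_args_into_dataclasses(["--improve_step", "2"])
print(finetuning_args.improve_step)  # 2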
5 changes: 3 additions & 2 deletions llm_rl/src/llmtuner/tuner/sft/custom_callback.py
@@ -3,10 +3,11 @@
 from llmtuner.tuner.core.utils import is_first_node
 
 class SaveModelCallback(TrainerCallback):
-    def __init__(self, save_epochs, output_dir, checkpoint_saved_queue):
+    def __init__(self, save_epochs, output_dir, checkpoint_saved_queue, improve_step):
         self.save_epochs = save_epochs
         self.output_dir = output_dir
         self.checkpoint_saved_queue = checkpoint_saved_queue
+        self.curr_improve_step = improve_step
         self.curr_epoch = 0
 
     def on_epoch_end(self, args, state, control, model=None, **kwargs):
@@ -22,7 +23,7 @@ def on_save(self, args, state, control, model, **kwargs):
             return
 
         # Customized checkpoint name
-        custom_checkpoint_name = f"checkpoint_epoch_{int(self.curr_epoch)}"
+        custom_checkpoint_name = f"checkpoint_improve-{self.curr_improve_step}_epoch-{int(self.curr_epoch)}"
 
         # Original auto-saved checkpoint directory
         auto_checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
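The only behavioral change in the callback is the checkpoint directory name, which now encodes the ReST improve step alongside the epoch. A small, self-contained illustration of the resulting names; the output directory, global step, and step values are made up for the example.

import os

# Purely illustrative values a run might produce.
curr_improve_step = 1
curr_epoch = 3.0
global_step = 1200
output_dir = "saves/sotopia-sft"  # hypothetical output directory

custom_checkpoint_name = f"checkpoint_improve-{curr_improve_step}_epoch-{int(curr_epoch)}"
auto_checkpoint_dir = os.path.join(output_dir, f"checkpoint-{global_step}")
custom_checkpoint_dir = os.path.join(output_dir, custom_checkpoint_name)

print(auto_checkpoint_dir)    # saves/sotopia-sft/checkpoint-1200
print(custom_checkpoint_dir)  # saves/sotopia-sft/checkpoint_improve-1_epoch-3
# The callback presumably renames/copies the former to the latter; that part is outside this hunk.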
2 changes: 1 addition & 1 deletion llm_rl/src/llmtuner/tuner/sft/workflow.py
@@ -56,7 +56,7 @@ def run_sft(
         training_args.report_to = ["wandb"]
 
     if model_args.use_custom_callback:
-        callbacks.append(SaveModelCallback(model_args.call_back_save_epochs, training_args.output_dir, finetuning_args.checkpoint_saved_queue))
+        callbacks.append(SaveModelCallback(model_args.call_back_save_epochs, training_args.output_dir, finetuning_args.checkpoint_saved_queue, finetuning_args.improve_step))
 
     # Initialize our Trainer
     trainer = CustomSeq2SeqTrainer(
3 changes: 2 additions & 1 deletion llm_self_train/pipelines/run_train.py
@@ -6,7 +6,7 @@
 with open('./resources/train_args.yml', 'r') as file:
     train_args = yaml.safe_load(file)
 
-def run_sft(output_dir):
+def run_sft(output_dir, improve_step, ):
     args = ["deepspeed", f"--num_gpus={config['num_gpus']}", "../llm_rl/src/train_bash.py"]
     for key, value in train_args.items():
         if key in config:
@@ -20,5 +20,6 @@ def run_sft(output_dir):
         else:
             args.append(f"--{key}")
             args.append(str(value))
+    args.append(f"--improve_step={improve_step}")
 
     subprocess.run(args)
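run_sft now forwards the improve step to the training entrypoint as one extra flag on the deepspeed command line. A rough sketch of the command it ends up assembling; num_gpus, output_dir, and the yml-driven flags are made-up stand-ins for what config and train_args.yml actually contain.

# Illustrative only: the real values come from config and train_args.yml.
num_gpus = 4
improve_step = 2
output_dir = "data/self-train-exp"  # hypothetical

args = ["deepspeed", f"--num_gpus={num_gpus}", "../llm_rl/src/train_bash.py"]
args += ["--output_dir", output_dir, "--num_train_epochs", "10"]  # stand-ins for the yml-driven flags
args.append(f"--improve_step={improve_step}")

print(" ".join(args))
# deepspeed --num_gpus=4 ../llm_rl/src/train_bash.py --output_dir data/self-train-exp --num_train_epochs 10 --improve_step=2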
6 changes: 3 additions & 3 deletions llm_self_train/train.py
@@ -19,10 +19,10 @@ def main():
     if not os.path.exists("../llm_rl/data/sotopia_custom_training_sft.json"):
         preprocess_episodes_with_tag()
 
-    for i in range(config["num_improve_steps"]):
+    for improve_step in range(config["num_improve_steps"]):
         run_sft_completed = multiprocessing.Value('b', False)
-        curr_improve_dir = os.path.join(config['data_dir'], config["experiment_name"])
-        sft_process = multiprocessing.Process(target=run_sft, args=(curr_improve_step, ))
+        output_dir = os.path.join(config['data_dir'], config["experiment_name"])
+        sft_process = multiprocessing.Process(target=run_sft, args=(output_dir, improve_step, ))
         sft_process.start()
         sft_process.join()
 
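Putting the pieces together, the driver now passes both the output directory and the step index into each SFT run, which is what lets the callback tag checkpoints per improve step. A hedged sketch of the loop, with a stub in place of pipelines.run_train.run_sft and made-up config values (only the keys are taken from the diff).

import multiprocessing
import os


def run_sft(output_dir, improve_step):
    # Stub standing in for pipelines.run_train.run_sft (which launches deepspeed).
    print(f"improve step {improve_step}: training into {output_dir}")


config = {  # made-up values mirroring the keys used in train.py
    "num_improve_steps": 3,
    "data_dir": "data",
    "experiment_name": "self-train-exp",
}

if __name__ == "__main__":
    for improve_step in range(config["num_improve_steps"]):
        output_dir = os.path.join(config["data_dir"], config["experiment_name"])
        sft_process = multiprocessing.Process(target=run_sft, args=(output_dir, improve_step))
        sft_process.start()
        sft_process.join()  # one improve step finishes before the next starts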
