diff --git a/llm_rl/requirements.txt b/llm_rl/requirements.txt
index 61b38b0b..03ac2855 100644
--- a/llm_rl/requirements.txt
+++ b/llm_rl/requirements.txt
@@ -17,6 +17,7 @@ uvicorn
 pydantic
 fastapi
 sse-starlette
+packaging
 matplotlib
 py-cpuinfo
 deepspeed
diff --git a/llm_rl/reward_model.sh b/llm_rl/reward_model.sh
index 3068fb43..fa5424df 100644
--- a/llm_rl/reward_model.sh
+++ b/llm_rl/reward_model.sh
@@ -1,16 +1,15 @@
-python src/train_bash.py \
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage rm \
-    --model_name_or_path meta-llama/Llama-2-13b \
+    --model_name_or_path meta-llama/Llama-2-13b-hf \
     --do_train \
     --dataset comparison_gpt4_en \
     --template default \
     --finetuning_type lora \
     --lora_target q_proj,v_proj \
     --resume_lora_training False \
-    --checkpoint_dir ./llama-2-13b-rm \
     --output_dir ./llama-2-13b-rm \
-    --per_device_train_batch_size 2 \
-    --gradient_accumulation_steps 4 \
+    --per_device_train_batch_size 8 \
+    --gradient_accumulation_steps 8 \
     --lr_scheduler_type cosine \
     --logging_steps 10 \
     --save_steps 1000 \
@@ -18,4 +17,7 @@ python src/train_bash.py \
     --num_train_epochs 1.0 \
     --plot_loss \
     --fp16 \
-    --hf_auth_token "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"
\ No newline at end of file
+    --use_auth_token True \
+    --wandb_token "99caa13ec9552adf0e92e5c30021307ce3cf7fa4" \
+    --hf_auth_token "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" \
+    --deepspeed ./deepspeed_config_s2.json
diff --git a/llm_rl/src/llmtuner/hparams/finetuning_args.py b/llm_rl/src/llmtuner/hparams/finetuning_args.py
index 0915b701..d8f2d299 100644
--- a/llm_rl/src/llmtuner/hparams/finetuning_args.py
+++ b/llm_rl/src/llmtuner/hparams/finetuning_args.py
@@ -1,3 +1,4 @@
+from typing import List
 import json
 from typing import Literal, Optional
 from dataclasses import asdict, dataclass, field
@@ -83,15 +84,15 @@ class FinetuningArguments:
         default=0,
         metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
     )
-    wandb_token: Optional[str|None] = field(
+    wandb_token: Optional[str] = field(
         default=None,
         metadata={"help": "The login api token for wandb."}
     )
-    wandb_project: Optional[str|None] = field(
+    wandb_project: Optional[str] = field(
         default=None,
         metadata={"help": "The project name for the current wandb log."}
     )
-    wandb_tags: Optional[list[str]|None] = field(
+    wandb_tags: Optional[List[str]] = field(
         default=None,
         metadata={"help": "The tag for the current wandb run."}
     )
diff --git a/llm_rl/src/llmtuner/tuner/core/utils.py b/llm_rl/src/llmtuner/tuner/core/utils.py
index 77bdde94..03043e20 100644
--- a/llm_rl/src/llmtuner/tuner/core/utils.py
+++ b/llm_rl/src/llmtuner/tuner/core/utils.py
@@ -15,8 +15,8 @@ logger = get_logger(__name__)
 
 
 def is_first_node():
-    world_rank = dist.get_rank()
-    local_rank = int(os.environ['LOCAL_RANK'])
+    world_rank = dist.get_rank() if torch.distributed.is_initialized() else 0
+    local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0
     return world_rank == local_rank == 0
 
 def find_all_linear_modules(
diff --git a/llm_rl/src/llmtuner/tuner/rm/trainer.py b/llm_rl/src/llmtuner/tuner/rm/trainer.py
index 80502937..94549f18 100644
--- a/llm_rl/src/llmtuner/tuner/rm/trainer.py
+++ b/llm_rl/src/llmtuner/tuner/rm/trainer.py
@@ -38,7 +38,7 @@ def compute_loss(
         See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
         """
         # Compute rewards
-        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
+        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)  # (lm_logits, loss, value)
         if values.size(0) != inputs["input_ids"].size(0):  # adapt to chatglm2
             values = torch.transpose(values, 0, 1)
 
diff --git a/llm_rl/src/llmtuner/tuner/tune.py b/llm_rl/src/llmtuner/tuner/tune.py
index 2687a686..054a6b1c 100644
--- a/llm_rl/src/llmtuner/tuner/tune.py
+++ b/llm_rl/src/llmtuner/tuner/tune.py
@@ -24,7 +24,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
     callbacks = [LogCallback()] if callbacks is None else callbacks
     if is_first_node():
         wandb.login(key=finetuning_args.wandb_token)
-        wandb.init(project=finetuning_args.wandb_project, tags=[*finetuning_args.wandb_tags])
+        wandb.init(project=finetuning_args.wandb_project, tags=[*finetuning_args.wandb_tags] if finetuning_args.wandb_tags else None)
 
     if finetuning_args.stage == "pt":
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
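
For context, a standalone sketch of the two guards this patch introduces: the is_first_node() fallback in core/utils.py and the wandb_tags check in tune.py. It mirrors the patched logic rather than importing it, and the wandb_tags stand-in is illustrative only; nothing beyond torch and the standard library is assumed.

# sketch.py -- illustrative only, not part of the patch above
import os

import torch
import torch.distributed as dist


def is_first_node() -> bool:
    # Mirrors the patched helper in core/utils.py: fall back to rank 0 when no
    # process group is initialized (plain single-GPU launch) or when the
    # launcher did not export LOCAL_RANK.
    world_rank = dist.get_rank() if torch.distributed.is_initialized() else 0
    local_rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else 0
    return world_rank == local_rank == 0


if __name__ == "__main__":
    # Plain `python src/train_bash.py` run: the old code called dist.get_rank()
    # before any init_process_group and errored out; the guarded version
    # returns True so wandb is still set up on the single process.
    print(is_first_node())

    # The tags guard from tune.py: unpacking None with [*None] raises
    # TypeError, so pass None to wandb.init when no tags were configured.
    wandb_tags = None  # stand-in for finetuning_args.wandb_tags being unset
    tags = [*wandb_tags] if wandb_tags else None
    print(tags)  # -> None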