adding fix for wandb tags and distributed ranks
Jasonqi146 committed Nov 8, 2023
1 parent cd4e0c2 commit 7b73354
Showing 6 changed files with 17 additions and 13 deletions.
1 change: 1 addition & 0 deletions llm_rl/requirements.txt
@@ -17,6 +17,7 @@ uvicorn
 pydantic
 fastapi
 sse-starlette
+packaging
 matplotlib
 py-cpuinfo
 deepspeed
14 changes: 8 additions & 6 deletions llm_rl/reward_model.sh
@@ -1,21 +1,23 @@
-python src/train_bash.py \
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage rm \
-    --model_name_or_path meta-llama/Llama-2-13b \
+    --model_name_or_path meta-llama/Llama-2-13b-hf \
     --do_train \
     --dataset comparison_gpt4_en \
     --template default \
     --finetuning_type lora \
     --lora_target q_proj,v_proj \
     --resume_lora_training False \
-    --checkpoint_dir ./llama-2-13b-rm \
     --output_dir ./llama-2-13b-rm \
-    --per_device_train_batch_size 2 \
-    --gradient_accumulation_steps 4 \
+    --per_device_train_batch_size 8 \
+    --gradient_accumulation_steps 8 \
     --lr_scheduler_type cosine \
     --logging_steps 10 \
     --save_steps 1000 \
     --learning_rate 1e-6 \
     --num_train_epochs 1.0 \
     --plot_loss \
     --fp16 \
-    --hf_auth_token "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"
+    --use_auth_token True \
+    --wandb_token "99caa13ec9552adf0e92e5c30021307ce3cf7fa4" \
+    --hf_auth_token "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" \
+    --deepspeed ./deepspeed_config_s2.json
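
For context, a rough sanity check of the effective batch size implied by the new settings, assuming a single visible GPU (as CUDA_VISIBLE_DEVICES=0 suggests) and no other data parallelism:

# Back-of-the-envelope check in Python; the numbers come from the flags above.
per_device_train_batch_size = 8
gradient_accumulation_steps = 8
num_gpus = 1  # CUDA_VISIBLE_DEVICES=0 exposes a single device
print(per_device_train_batch_size * gradient_accumulation_steps * num_gpus)  # 64 samples per optimizer step
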
7 changes: 4 additions & 3 deletions llm_rl/src/llmtuner/hparams/finetuning_args.py
@@ -1,3 +1,4 @@
+from typing import List
 import json
 from typing import Literal, Optional
 from dataclasses import asdict, dataclass, field
@@ -83,15 +84,15 @@ class FinetuningArguments:
         default=0,
         metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
     )
-    wandb_token: Optional[str|None] = field(
+    wandb_token: Optional[str] = field(
         default=None,
         metadata={"help": "The login api token for wandb."}
     )
-    wandb_project: Optional[str|None] = field(
+    wandb_project: Optional[str] = field(
         default=None,
         metadata={"help": "The project name for the current wandb log."}
     )
-    wandb_tags: Optional[list[str]|None] = field(
+    wandb_tags: Optional[List[str]] = field(
         default=None,
         metadata={"help": "The tag for the current wandb run."}
     )
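
A minimal sketch, not part of the commit, of why the annotation change matters and how such fields parse with transformers' HfArgumentParser: the old PEP 604 forms like Optional[str|None] and list[str]|None are evaluated when the dataclass is defined and can fail on older Python versions, while the typing-module forms above are the ones HfArgumentParser has long supported, with List[str] exposed as a space-separated multi-value flag. The WandbArgs class below is a hypothetical trimmed-down stand-in for FinetuningArguments.

from dataclasses import dataclass, field
from typing import List, Optional

from transformers import HfArgumentParser


@dataclass
class WandbArgs:  # hypothetical stand-in for FinetuningArguments
    wandb_token: Optional[str] = field(default=None)
    wandb_project: Optional[str] = field(default=None)
    wandb_tags: Optional[List[str]] = field(default=None)  # parsed as a space-separated list


parser = HfArgumentParser(WandbArgs)
(args,) = parser.parse_args_into_dataclasses(
    ["--wandb_project", "llm_rl", "--wandb_tags", "rm", "llama-2-13b"]
)
print(args.wandb_tags)  # ['rm', 'llama-2-13b']
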
4 changes: 2 additions & 2 deletions llm_rl/src/llmtuner/tuner/core/utils.py
@@ -15,8 +15,8 @@
 logger = get_logger(__name__)

 def is_first_node():
-    world_rank = dist.get_rank()
-    local_rank = int(os.environ['LOCAL_RANK'])
+    world_rank = dist.get_rank() if torch.distributed.is_initialized() else 0
+    local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0
     return world_rank == local_rank == 0

 def find_all_linear_modules(
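
A small standalone sketch of the behavior the new guards provide, assuming torch is installed and the script is run as a plain single process: dist.get_rank() raises when no default process group has been initialized, and LOCAL_RANK is only set by distributed launchers such as torchrun, so the unguarded version would crash outside torchrun-style launches, while the fixed version falls back to rank 0.

import os

import torch
import torch.distributed as dist


def is_first_node() -> bool:
    # Fall back to rank 0 when torch.distributed has not been initialized.
    world_rank = dist.get_rank() if torch.distributed.is_initialized() else 0
    local_rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else 0
    return world_rank == local_rank == 0


print(is_first_node())  # True in a plain single-process run
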
2 changes: 1 addition & 1 deletion llm_rl/src/llmtuner/tuner/rm/trainer.py
@@ -38,7 +38,7 @@ def compute_loss(
         See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
         """
         # Compute rewards
-        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
+        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)  # (lm_logits, loss, value)
         if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2
             values = torch.transpose(values, 0, 1)

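
For reference, a self-contained sketch, with illustrative shapes rather than project code, of the adaptation around the changed line: the value-head model call returns a (lm_logits, loss, value) tuple, and chatglm2-style models emit the value tensor as (seq_len, batch), so the trainer transposes it whenever the leading dimension does not match the batch size.

import torch

batch_size, seq_len = 4, 16
input_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
values = torch.randn(seq_len, batch_size)  # chatglm2-style (seq_len, batch) layout

if values.size(0) != input_ids.size(0):  # adapt to chatglm2, mirroring the trainer
    values = torch.transpose(values, 0, 1)

print(values.shape)  # torch.Size([4, 16])
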
2 changes: 1 addition & 1 deletion llm_rl/src/llmtuner/tuner/tune.py
@@ -24,7 +24,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
     callbacks = [LogCallback()] if callbacks is None else callbacks
     if is_first_node():
         wandb.login(key=finetuning_args.wandb_token)
-        wandb.init(project=finetuning_args.wandb_project, tags=[*finetuning_args.wandb_tags])
+        wandb.init(project=finetuning_args.wandb_project, tags=[*finetuning_args.wandb_tags] if finetuning_args.wandb_tags else None)

     if finetuning_args.stage == "pt":
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
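
A minimal illustration, not project code, of the failure the new guard avoids: unpacking a None wandb_tags with [*...] raises a TypeError, so the fixed call hands wandb.init tags=None whenever no tags were configured. The helper name below is hypothetical.

def tags_for_wandb(wandb_tags):
    # Mirrors the fixed call: only unpack when tags were actually provided.
    return [*wandb_tags] if wandb_tags else None


print(tags_for_wandb(None))                   # None
print(tags_for_wandb(["rm", "llama-2-13b"]))  # ['rm', 'llama-2-13b']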
