
When I used galore, the learning rate was set to 8e-6, but the training rate was 0.001 #31707

Closed · Minami-su opened this issue Jun 29, 2024 · 7 comments · Fixed by #31710

@Minami-su

import os
import sys
from typing import List

import fire
import torch
import transformers
from datasets import load_dataset

# os.environ["NCCL_P2P_DISABLE"] = "1"
# os.environ["NCCL_IB_DISABLE"] = "1"

"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""


from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    set_peft_model_state_dict,
)
#from utils.prompter import Prompter
import signal
os.environ["WANDB_DISABLED"] = "true"
def train(
    # model/data params
    base_model: str = "",  # the only required argument
    data_path: str = "yahma/alpaca-cleaned",
    output_dir: str = "./lora-alpaca",
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 3e-4,
    cutoff_len: int = 256,
    val_set_size: int = 2000,
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
    prompt_template_name: str = "alpaca2",  # The prompt template to use, will default to alpaca.
):
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training Alpaca-LoRA model with params:\n"
            f"base_model: {base_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
            f"batch_size: {batch_size}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"val_set_size: {val_set_size}\n"
            f"lora_r: {lora_r}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"add_eos_token: {add_eos_token}\n"
            f"group_by_length: {group_by_length}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
        )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
    gradient_accumulation_steps = batch_size // micro_batch_size

    #prompter = Prompter(prompt_template_name)

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size

    # Check if parameter passed or if set within environ
    use_wandb = len(wandb_project) > 0 or (
        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    # Only overwrite environ if wandb param passed
    if len(wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = wandb_project
    if len(wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = wandb_watch
    if len(wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model

    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    
    if base_model.find("qwen") != -1 or base_model.find("Qwen") != -1:
        tokenizer.add_special_tokens({"bos_token": "<|im_start|>"})
        tokenizer.add_special_tokens({"eos_token": "<|im_end|>"})
        tokenizer.add_special_tokens({"pad_token": "<|endoftext|>"})
    tokenizer.padding_side = "left"  # Allow batched inference
    def save_model(signal, frame):
        print("\nSaving the model...")
        model.save_pretrained(output_dir)
        sys.exit(0)
    def tokenize(prompt, add_eos_token=True):
        # there's probably a way to do this with the tokenizer settings
        # but again, gotta move fast
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result

    def generate_and_tokenize_prompt(data_point):
        full_prompt = data_point["instruction"] + data_point["input"] + data_point["output"]
        tokenized_full_prompt = tokenize(full_prompt)

        return tokenized_full_prompt
    
    print(tokenizer.pad_token_id)
    print(tokenizer.pad_token)
    print(tokenizer.bos_token_id)
    print(tokenizer.bos_token)
    print(tokenizer.eos_token_id)
    print(tokenizer.eos_token)
    if data_path.endswith(".json") or data_path.endswith(".jsonl"):
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)
    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = (
            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        )
        val_data = (
            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
        )
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None

    model = AutoModelForCausalLM.from_pretrained(base_model,
                                             trust_remote_code=True,
                                             attn_implementation="flash_attention_2",
                                             torch_dtype=torch.bfloat16,
                                             #device_map=device_map,
                                             )


    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = (
                False  # So the trainer won't try loading its state
            )
        # The two files above have a different name depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")


    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=0,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            bf16=True,
            logging_steps=50,
            lr_scheduler_type="cosine",
            #optim="adamw_torch",
            optim = "galore_adamw_8bit_layerwise",
            optim_target_modules=[r".*attn.*", r".*mlp.*"],
            
            optim_args="rank=1024, update_proj_gap=500, scale=0.25",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=100 if val_set_size > 0 else None,
            save_steps=200,
            output_dir=output_dir,
            save_total_limit=2,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            report_to="wandb" if use_wandb else None,
            run_name=wandb_run_name if use_wandb else None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    
    
    signal.signal(signal.SIGINT, save_model)
    
    trainer.train()
    model.save_pretrained(output_dir)

    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )


if __name__ == "__main__":
    fire.Fire(train)

result:

Map: 100%|█████████████████████████████████████████████████████████████████████████████████████| 2760/2760 [00:20<00:00, 133.85 examples/s]
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  3.29it/s]
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
/work/jcxy/anaconda3/envs/haolu/lib/python3.10/site-packages/accelerate/accelerator.py:444: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: 
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
  warnings.warn(
Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !
{'loss': 0.8829, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.02}                                                                  
{'loss': 0.7547, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.04}                                                                  
{'loss': 0.7595, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.05}                                                                  
{'loss': 0.7547, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.07}       
@Minami-su changed the title from "When I used galore on orpo, the learning rate was set to 8e-6, but the training rate was 0.01" to "When I used galore on orpo, the learning rate was set to 8e-6, but the training rate was 0.001" on Jun 29, 2024
@Minami-su changed the title from "When I used galore on orpo, the learning rate was set to 8e-6, but the training rate was 0.001" to "When I used galore, the learning rate was set to 8e-6, but the training rate was 0.001" on Jun 29, 2024
@vasqu
Contributor

vasqu commented Jun 29, 2024

Hey @Minami-su, on which version of transformers are you? It reminds me of an older issue #30082 very similar to this which should have been fixed by #30085 (>= v4.40.0). Still pretty sure that it is more of a display issue.
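
For reference, a quick way to check which version is installed:

import transformers
print(transformers.__version__)  # anything >= 4.40.0 should already include the fix from #30085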

@Minami-su
Author

Hey @Minami-su, on which version of transformers are you? It reminds me of an older issue #30082 very similar to this which should have been fixed by #30085 (>= v4.40.0). Still pretty sure that it is more of a display issue.

4.42.3. However, I found that the lr and grad_norm shown in the logs are wrong; the actual values do in fact change.
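
A rough way to check this, for example, is to snapshot a GaLore-targeted weight before trainer.train() and compare it afterwards (the "mlp" filter below is just illustrative and depends on the model):

name, param = next((n, p) for n, p in model.named_parameters() if "mlp" in n and p.requires_grad)
before = param.detach().clone()   # snapshot a GaLore-targeted weight before training
trainer.train()
delta = (param.detach() - before).abs().max().item()
print(f"{name}: max |Δw| = {delta}")  # > 0 means updates are really applied despite the logged grad_norm of 0.0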

@vasqu
Contributor

vasqu commented Jun 29, 2024

When you say there were changes, do you mean the displayed lr/grad values changed? I might look into it when I have time. GaLore currently uses a lot of dummy objects for display purposes, which might be causing an issue here again (just my first intuition).

@Minami-su
Author

Minami-su commented Jun 29, 2024

The lr shown is not changing, but the actual training lr clearly is: setting the lr to 1e-5 versus 1e-2 gives very different results.

lr = 1e-5
{'loss': 1.7991, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.02}                                                                  
{'loss': 1.3706, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.04}                                                                  
{'loss': 0.9335, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.05}                                                                  
{'loss': 0.7765, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.07}                                                                  
{'loss': 0.7417, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.09}                                                                  
{'loss': 0.7413, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.11}                                                                  
{'loss': 0.7234, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.13}                                                                  
{'loss': 0.7438, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.14}                                                                  
{'loss': 0.7257, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.16}                                                                  
{'loss': 0.7101, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.18}     
lr = 1e-2
{'loss': 743335.52, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.02}                                                               
{'loss': 151415.77, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.04}                                                               
{'loss': 202281.64, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.05}                                                               
{'loss': 34941.99, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.07}                                                                
{'loss': 111826.13, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.09}                                                               
{'loss': 167749.48, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.11}                                                               
{'loss': 125194.02, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.13}                                                               
{'loss': 161781.74, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.14}                                                               
{'loss': 128028.8, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.16}                                                                
{'loss': 79324.51, 'grad_norm': 0.0, 'learning_rate': 0.001, 'epoch': 0.18}     

@vasqu
Contributor

vasqu commented Jun 29, 2024

@Minami-su Small update on my side: I could reproduce the issue with a stripped-down variant of your script:

import os

import torch
import transformers
from datasets import load_dataset

from transformers import AutoModelForCausalLM, AutoTokenizer, logging

logging.set_verbosity(logging.DEBUG)
os.environ["WANDB_DISABLED"] = "true"


# model/data params
base_model: str = "gpt2"  # the only required argument
data_path: str = "yahma/alpaca-cleaned"
output_dir: str = "./lora-alpaca"
# training hyperparams
batch_size: int = 32
num_epochs: int = 3
learning_rate: float = 3e-4
cutoff_len: int = 256
val_set_size: int = 2000
# llm hyperparams
train_on_inputs: bool = True  # if False, masks out inputs in loss
add_eos_token: bool = False
group_by_length: bool = False # faster, but produces an odd training loss curve
# wandb params
wandb_project: str = ""
wandb_run_name: str = ""
wandb_watch: str = ""  # options: false | gradients | all
wandb_log_model: str = ""  # options: false | true
resume_from_checkpoint: str = None  # either training checkpoint or final adapter
prompt_template_name: str = "alpaca2"  # The prompt template to use, will default to alpaca.


if int(os.environ.get("LOCAL_RANK", 0)) == 0:
    print(
        f"Training Alpaca-LoRA model with params:\n"
        f"base_model: {base_model}\n"
        f"data_path: {data_path}\n"
        f"output_dir: {output_dir}\n"
        f"batch_size: {batch_size}\n"
        f"num_epochs: {num_epochs}\n"
        f"learning_rate: {learning_rate}\n"
        f"cutoff_len: {cutoff_len}\n"
        f"val_set_size: {val_set_size}\n"
        f"train_on_inputs: {train_on_inputs}\n"
        f"add_eos_token: {add_eos_token}\n"
        f"group_by_length: {group_by_length}\n"
        f"wandb_project: {wandb_project}\n"
        f"wandb_run_name: {wandb_run_name}\n"
        f"wandb_watch: {wandb_watch}\n"
        f"wandb_log_model: {wandb_log_model}\n"
        f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
        f"prompt template: {prompt_template_name}\n"
    )
assert (
    base_model
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"

# Check if parameter passed or if set within environ
use_wandb = len(wandb_project) > 0 or (
        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
)
# Only overwrite environ if wandb param passed
if len(wandb_project) > 0:
    os.environ["WANDB_PROJECT"] = wandb_project
if len(wandb_watch) > 0:
    os.environ["WANDB_WATCH"] = wandb_watch
if len(wandb_log_model) > 0:
    os.environ["WANDB_LOG_MODEL"] = wandb_log_model

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
# needed for models like gpt2
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

if base_model.find("qwen") != -1 or base_model.find("Qwen") != -1:
    tokenizer.add_special_tokens({"bos_token": "<|im_start|>"})
    tokenizer.add_special_tokens({"eos_token": "<|im_end|>"})
    tokenizer.add_special_tokens({"pad_token": "<|endoftext|>"})
tokenizer.padding_side = "left"  # Allow batched inference

def tokenize(prompt, add_eos_token=True):
    # there's probably a way to do this with the tokenizer settings
    # but again, gotta move fast
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=cutoff_len,
        padding=False,
        return_tensors=None,
    )
    if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
    ):
        result["input_ids"].append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)

    result["labels"] = result["input_ids"].copy()

    return result

def generate_and_tokenize_prompt(data_point):
    full_prompt = data_point["instruction"] + data_point["input"] + data_point["output"]
    tokenized_full_prompt = tokenize(full_prompt)

    return tokenized_full_prompt

print(tokenizer.pad_token_id)
print(tokenizer.pad_token)
print(tokenizer.bos_token_id)
print(tokenizer.bos_token)
print(tokenizer.eos_token_id)
print(tokenizer.eos_token)
if data_path.endswith(".json") or data_path.endswith(".jsonl"):
    data = load_dataset("json", data_files=data_path)
else:
    data = load_dataset(data_path)
if val_set_size > 0:
    train_val = data["train"].train_test_split(
        test_size=val_set_size, shuffle=True, seed=42
    )
    train_data = (
        train_val["train"].shuffle().map(generate_and_tokenize_prompt)
    )
    val_data = (
        train_val["test"].shuffle().map(generate_and_tokenize_prompt)
    )
else:
    train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
    val_data = None

model = AutoModelForCausalLM.from_pretrained(base_model,
                                             trust_remote_code=True,
                                             torch_dtype=torch.bfloat16,
                                             )


trainer = transformers.Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=1,
        warmup_steps=0,
        num_train_epochs=num_epochs,
        learning_rate=learning_rate,
        bf16=True,
        lr_scheduler_type="cosine",
        optim = "galore_adamw_8bit_layerwise",
        optim_target_modules=[
            'q_proj', 'k_proj', 'down_proj', 'up_proj',
            'gate_proj', 'v_proj', 'o_proj', 'lm_head'
        ],
        optim_args="rank=1024, update_proj_gap=500, scale=0.25",
        eval_strategy="steps" if val_set_size > 0 else "no",
        save_strategy="steps",
        logging_strategy="steps",
        logging_steps=10,
        eval_steps=100 if val_set_size > 0 else None,
        save_steps=200,
        output_dir=output_dir,
        save_total_limit=2,
        load_best_model_at_end=True if val_set_size > 0 else False,
        group_by_length=group_by_length,
        report_to="wandb" if use_wandb else None,
        run_name=wandb_run_name if use_wandb else None,
    ),
    data_collator=transformers.DataCollatorForSeq2Seq(
        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
    ),
)

trainer.train()
model.save_pretrained(output_dir)

It is a display issue due to the (cosine) lr scheduler. Working on a fix that I'll submit in a PR.

If you're interested in why this happens: in short, the layerwise GaLore variant works param-wise, updating each parameter individually, and to support that without interrupting it, dummy schedulers and optimizers are used as a global overhead so that they don't interfere with the param-wise updates. In this case the scheduler was the problem: it did not follow the actual schedule, and the param-wise learning rates were discarded in the process. To be clear though, it's entirely a display issue.
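
For illustration only, a minimal sketch of that layer-wise pattern (not the actual transformers/galore_torch code; names and values here are made up): each parameter gets its own real optimizer and scheduler, stepped from a gradient hook, while the Trainer only holds dummies, and the dummy scheduler is what ended up in the logs.

import torch

def attach_layerwise_optimizers(model, lr=1e-5, total_steps=1000):
    # One real optimizer + scheduler per parameter; the Trainer never sees these.
    schedulers = []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        opt = torch.optim.AdamW([param], lr=lr)
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps)

        def step_hook(p, opt=opt, sched=sched):
            # Runs right after this parameter's gradient has been accumulated.
            opt.step()
            opt.zero_grad()
            sched.step()  # the *real* lr decays here, per parameter

        param.register_post_accumulate_grad_hook(step_hook)
        schedulers.append(sched)
    # A logging fix has to read the lr from these per-parameter schedulers
    # rather than from the dummy global one.
    return schedulers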

@Minami-su
Author

@vasqu Thank you for your explanation, I understand now.

@vasqu
Contributor

vasqu commented Jun 29, 2024

@Minami-su PR is up, and no problem!

Small edit: You should also see the changes in the lr when using warmup steps.
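
For example (values here are purely illustrative), adding warmup to the arguments from the scripts above should make the logged lr visibly ramp up:

import transformers

args = transformers.TrainingArguments(
    output_dir="./lora-alpaca",
    per_device_train_batch_size=4,
    learning_rate=8e-6,
    warmup_steps=100,            # lr climbs from 0 towards 8e-6 over the first 100 steps
    lr_scheduler_type="cosine",
    optim="galore_adamw_8bit_layerwise",
    optim_target_modules=[r".*attn.*", r".*mlp.*"],
    optim_args="rank=1024, update_proj_gap=500, scale=0.25",
    logging_steps=10,
)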
