Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Xlora cannot reload model from last checkpoint by using trainer.train(resume_from_checkpoint="checkpp") #2185

Open
3 of 4 tasks
SongHanKen opened this issue Oct 29, 2024 · 2 comments

Comments

@SongHanKen
Copy link

System Info

Peft v0.13.2
Transformers v4.44.0
Accelerate v0.33.0

Who can help?

Since this concerns the interaction between PEFT and X-LoRA, perhaps @BenjaminBossan or @EricLBuehler can help.

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

Hi there,
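I am trying to resume training of an X-LoRA model from its last checkpoint with trainer.train(resume_from_checkpoint=checkpoint_directory). The checkpoint directory is chosen either by passing a checkpoint number to the script, or by passing "YES" so that the script picks the latest checkpoint-N directory in the output directory. The base model is chatglm4-9b, and here is my training code: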

import os
import jieba
import dataclasses as dc
import functools
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Annotated, Any, Union
import numpy as np
import ruamel.yaml as yaml
import torch
import typer
from datasets import Dataset, Split
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from peft import PeftConfig, get_peft_config, get_peft_model
from rouge_chinese import Rouge
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    EvalPrediction,
    GenerationConfig,
    PreTrainedTokenizer,
    Seq2SeqTrainingArguments,
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
from datasets import load_dataset, DatasetDict, NamedSplit
from typing import Optional

# Typer CLI application; `main` is registered on it via @app.command().
# pretty_exceptions_show_locals=False keeps Typer's pretty tracebacks from
# printing the values of local variables (e.g. the model or datasets).
app = typer.Typer(pretty_exceptions_show_locals=False)
class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
    """Seq2seq collator that additionally right-pads an ``output_ids`` field.

    When the features carry ``output_ids`` (as produced for evaluation), every
    entry is padded with the tokenizer's pad token up to the longest one in the
    batch, rounded up to ``pad_to_multiple_of`` if that is set. Everything else
    is handled by the parent collator.
    """

    def __call__(self, features, return_tensors=None):
        if 'output_ids' in features[0].keys():
            target_len = max(len(feat['output_ids']) for feat in features)
            multiple = self.pad_to_multiple_of
            if multiple is not None:
                # Round the target length up to the next multiple of `multiple`.
                target_len = (target_len + multiple - 1) // multiple * multiple
            pad_id = self.tokenizer.pad_token_id
            for feat in features:
                ids = feat['output_ids']
                padding = [pad_id] * (target_len - len(ids))
                if isinstance(ids, list):
                    feat['output_ids'] = ids + padding
                else:
                    # Array-like ids: concatenate and force int64 token ids.
                    feat['output_ids'] = np.concatenate([ids, padding]).astype(np.int64)
        return super().__call__(features, return_tensors)

class Seq2SeqTrainer(_Seq2SeqTrainer):
    """Seq2SeqTrainer variant used by this script.

    * ``training_step`` releases cached CUDA memory after every step. Apex is
      not supported: backward always goes through ``self.accelerator.backward``.
    * ``prediction_step`` (with ``predict_with_generate``) strips the prompt
      from generated sequences and uses the batch's ``output_ids`` as labels.
    """

    # Apex is not supported.
    def training_step(self, model: nn.Module, inputs: dict[str, Any]) -> torch.Tensor:
        """Run forward and backward on one batch.

        Returns the detached loss divided by ``gradient_accumulation_steps``.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)
        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)
        if self.args.n_gpu > 1:
            # Multi-GPU (n_gpu > 1) yields one loss per device; average them.
            loss = loss.mean()
        self.accelerator.backward(loss)
        detached_loss = loss.detach() / self.args.gradient_accumulation_steps
        # Drop our reference to the prepared batch before releasing cached memory.
        del inputs
        torch.cuda.empty_cache()
        return detached_loss

    def prediction_step(
            self,
            model: nn.Module,
            inputs: dict[str, Any],
            prediction_loss_only: bool,
            ignore_keys=None,
            **gen_kwargs,
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Evaluate one batch.

        With ``predict_with_generate``, ``output_ids`` is removed from
        ``inputs`` before generation and returned as the labels, and the
        prompt prefix (``input_ids.size(1)`` tokens) is cut from the generated
        sequences. Without it, the parent's result is returned unchanged.
        Previously that path crashed on an unbound ``output_ids``.
        """
        with torch.no_grad():  # Ensure no gradient computation
            output_ids = None
            if self.args.predict_with_generate:
                # Must be removed before generation; it is not a model input.
                output_ids = inputs.pop('output_ids')
            prompt_len = inputs['input_ids'].size(1)
            loss, generated_tokens, labels = super().prediction_step(
                model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
            )
            if output_ids is not None:
                # generated_tokens is None when only the loss was requested.
                if generated_tokens is not None:
                    generated_tokens = generated_tokens[:, prompt_len:]
                labels = output_ids
            torch.cuda.empty_cache()
        return loss, generated_tokens, labels

@dc.dataclass
class DataConfig:
    """Paths of the per-split data files plus preprocessing parallelism.

    A split whose file is ``None`` is left out of ``data_files``.
    """

    train_file: Optional[str] = None
    val_file: Optional[str] = None
    test_file: Optional[str] = None
    # Passed through as `num_proc` to dataset loading and mapping.
    num_proc: Optional[int] = None

    @property
    def data_format(self) -> str:
        """Extension of the training file, e.g. ``'.jsonl'``."""
        return Path(self.train_file).suffix

    @property
    def data_files(self) -> dict[NamedSplit, str]:
        """Map each configured split to its file path, skipping unset splits."""
        files = {}
        pairs = (
            (Split.TRAIN, self.train_file),
            (Split.VALIDATION, self.val_file),
            (Split.TEST, self.test_file),
        )
        for split, path in pairs:
            if path is not None:
                files[split] = path
        return files


@dc.dataclass
class FinetuningConfig(object):
    """Top-level fine-tuning configuration, usually loaded from a YAML file.

    Fields mirror the YAML's top-level keys. ``training_args`` defaults to
    ``Seq2SeqTrainingArguments(output_dir='./output')`` and ``peft_config`` to
    ``None`` (full fine-tuning).
    """

    data_config: DataConfig
    max_input_length: int
    max_output_length: int
    combine: bool
    freezeV: bool
    training_args: Seq2SeqTrainingArguments = dc.field(
        default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
    )
    peft_config: Optional[PeftConfig] = None

    def __post_init__(self):
        """Disable evaluation when it is off or no validation file is set.

        Otherwise, default the eval batch size to the train batch size.
        """
        args = self.training_args
        if not args.do_eval or self.data_config.val_file is None:
            args.do_eval = False
            # transformers >= 4.41 renamed `evaluation_strategy` to `eval_strategy`
            # and the Trainer reads the new name. Setting only the old attribute
            # after construction leaves a configured eval_strategy active, so
            # the Trainer would expect an eval dataset. Set both names.
            args.evaluation_strategy = 'no'
            if hasattr(args, 'eval_strategy'):
                args.eval_strategy = 'no'
            self.data_config.val_file = None
        else:
            args.per_device_eval_batch_size = (
                    args.per_device_eval_batch_size
                    or args.per_device_train_batch_size
            )

    @classmethod
    def from_dict(cls, **kwargs) -> 'FinetuningConfig':
        """Build a config from plain dicts (e.g. parsed YAML).

        Dict values for ``training_args``, its ``generation_config``,
        ``data_config`` and ``peft_config`` are converted into their config
        objects; values that already have the right type are kept as-is.
        """
        training_args = kwargs.get('training_args', None)
        if training_args is not None and not isinstance(
                training_args, Seq2SeqTrainingArguments
        ):
            # Copy so the caller's dict is not mutated.
            training_args = dict(training_args)
            gen_config = training_args.get('generation_config')
            # An absent generation_config stays None instead of failing on
            # GenerationConfig(**None).
            if gen_config is not None and not isinstance(gen_config, GenerationConfig):
                training_args['generation_config'] = GenerationConfig(
                    **gen_config
                )
            kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args)

        data_config = kwargs.get('data_config')
        if not isinstance(data_config, DataConfig):
            kwargs['data_config'] = DataConfig(**data_config)

        peft_config = kwargs.get('peft_config', None)
        if peft_config is not None and not isinstance(peft_config, PeftConfig):
            kwargs['peft_config'] = get_peft_config(config_dict=peft_config)
        return cls(**kwargs)

    @classmethod
    def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig':
        """Load a config from a YAML file using ruamel's safe loader."""
        path = Path(path)
        parser = yaml.YAML(typ='safe', pure=True)
        parser.indent(mapping=2, offset=2, sequence=4)
        parser.default_flow_style = False
        kwargs = parser.load(path)
        return cls.from_dict(**kwargs)

def _load_datasets(
        data_dir: str,
        data_format: str,
        data_files: dict[NamedSplit, str],
        num_proc: Optional[int],
) -> DatasetDict:
    """Load every split listed in ``data_files`` from ``data_dir``.

    Only the ``'.jsonl'`` format is accepted; anything else raises
    ``NotImplementedError``. ``split=None`` requests all listed splits at once.
    """
    if data_format != '.jsonl':
        raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
    return load_dataset(
        data_dir,
        data_files=data_files,
        split=None,
        num_proc=num_proc,
    )

class DataManager:
    """Loads the raw dataset splits once and preprocesses them on request."""

    def __init__(self, data_dir: str, data_config: DataConfig):
        self._num_proc = data_config.num_proc
        self._dataset_dct = _load_datasets(
            data_dir,
            data_config.data_format,
            data_config.data_files,
            self._num_proc,
        )

    def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
        """Return the raw split, or None if it was not loaded."""
        return self._dataset_dct.get(split, None)

    def get_dataset(
            self,
            split: NamedSplit,
            process_fn: Callable[[dict[str, Any]], dict[str, Any]],
            batched: bool = True,
            remove_orig_columns: bool = True,
    ) -> Optional[Dataset]:
        """Map ``process_fn`` over ``split`` and return the result.

        Returns None when the split is missing. With ``remove_orig_columns``,
        only the columns produced by ``process_fn`` are kept.
        """
        raw = self._get_dataset(split)
        if raw is None:
            return None
        return raw.map(
            process_fn,
            batched=batched,
            remove_columns=raw.column_names if remove_orig_columns else None,
            num_proc=self._num_proc,
        )

def process_message(message):
    """Normalise one chat message in place and return the same dict.

    A system message with ``tools`` has ``None``-valued entries removed from
    each tool's ``function.parameters.properties``. Any other message that
    carries ``tools`` has that key deleted. Messages without ``tools`` are
    returned untouched.
    """
    if 'tools' not in message:
        return message
    if message['role'] != 'system':
        del message['tools']
        return message
    for tool in message['tools']:
        params = tool['function']['parameters']
        params['properties'] = {
            name: spec
            for name, spec in params['properties'].items()
            if spec is not None
        }
    return message

def process_batch(
        batch: Mapping[str, Sequence],
        tokenizer: PreTrainedTokenizer,
        max_input_length: int,
        max_output_length: int,
        combine: bool,
) -> dict[str, list]:
    """Tokenize a batch of chat conversations into training examples.

    Args:
        batch: batched rows whose ``'messages'`` column holds one conversation
            (a list of message dicts with a ``'role'`` key) per row.
        tokenizer: tokenizer whose ``apply_chat_template`` encodes messages.
        max_input_length: together with ``max_output_length``, every example is
            truncated to ``max_input_length + max_output_length + 1`` tokens.
        max_output_length: see ``max_input_length``.
        combine: if True, encode the whole conversation at once and mark as
            targets only the tokens after the last 151337 token. Otherwise
            encode message by message and mark every message whose role is not
            system/user/observation.

    Returns:
        ``{'input_ids': [...], 'labels': [...]}`` with one list per
        conversation. Non-target positions are labelled -100.
    """
    batched_input_ids = []
    batched_labels = []
    max_length = max_input_length + max_output_length + 1
    # The previous version ended with `del batched_conv, conv, ..., new_input_ids, ...`
    # plus torch.cuda.empty_cache(). That raised UnboundLocalError for an empty
    # batch or when no message was encoded, and freed no GPU memory (this is
    # CPU-side preprocessing). Both are removed.
    for conv in batch['messages']:
        # NOTE(review): presumably GLM-4's [gMASK]<sop> prefix, which is also
        # what `[2:]` strips from each per-message encoding below — TODO confirm.
        input_ids = [151331, 151333]
        loss_masks = [False, False]
        if combine:
            input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
            loss_masks = [False] * len(input_ids)
            last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
            for j in range(last_assistant_index + 1, len(input_ids)):
                loss_masks[j] = True
        else:
            for message in conv:
                message = process_message(message)
                loss_mask_val = message['role'] not in ('system', 'user', 'observation')
                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
                input_ids += new_input_ids
                loss_masks += [loss_mask_val] * len(new_input_ids)

        input_ids.append(151336)  # EOS for chat
        # Intentional one-token shift: loss_masks is one shorter than
        # input_ids here, and prepending False makes token i use the mask of
        # token i-1. So the first token of each masked span is NOT a target,
        # while the token right after each span (including the appended EOS)
        # is.
        loss_masks = [False, *loss_masks]
        labels = [
            token_id if mask else -100
            for token_id, mask in zip(input_ids, loss_masks)
        ]
        batched_input_ids.append(input_ids[:max_length])
        batched_labels.append(labels[:max_length])

    return {'input_ids': batched_input_ids, 'labels': batched_labels}


def process_batch_eval(
        batch: Mapping[str, Sequence],
        tokenizer: PreTrainedTokenizer,
        max_input_length: int,
        max_output_length: int,
        combine: bool,
) -> dict[str, list]:
    """Tokenize a batch of conversations into generation prompts and references.

    Non-combine mode emits one example per assistant turn:
    ``input_ids`` is the conversation so far (truncated to
    ``max_input_length``) plus the first token of that turn, and
    ``output_ids`` is the rest of the turn plus 151336 (truncated to
    ``max_output_length``). Encoding stops once the running prompt reaches
    ``max_input_length`` tokens.

    Combine mode emits one example per conversation, built from the whole
    encoded conversation (see the note below).

    Returns:
        ``{'input_ids': [...], 'output_ids': [...]}``.
    """
    batched_input_ids = []
    batched_output_ids = []
    # The previous trailing `del ..., output_prompt, output_ids` raised
    # UnboundLocalError whenever the batch contained no assistant turn (e.g.
    # the first turn already reached max_input_length). It and the pointless
    # torch.cuda.empty_cache() on CPU-side preprocessing are removed.
    for conv in batch['messages']:
        if combine:
            input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
            last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
            # NOTE(review): the prompt below is the full conversation (which
            # already contains the final assistant reply) plus its first token,
            # input_ids[0]. This looks unintended; confirm before relying on
            # combine-mode metrics.
            output_prompt, output_ids = (
                input_ids[:1],
                input_ids[last_assistant_index:],
            )
            output_ids.append(151336)
            batched_input_ids.append(
                input_ids[:max_input_length] + output_prompt[:1]
            )
            batched_output_ids.append(output_ids[:max_output_length])
        else:
            input_ids = [151331, 151333]
            for message in conv:
                if len(input_ids) >= max_input_length:
                    break
                message = process_message(message)
                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
                if message['role'] == 'assistant':
                    output_prompt, output_ids = (
                        new_input_ids[:1],
                        new_input_ids[1:],
                    )
                    output_ids.append(151336)
                    batched_input_ids.append(
                        input_ids[:max_input_length] + output_prompt[:1]
                    )
                    batched_output_ids.append(output_ids[:max_output_length])
                input_ids += new_input_ids
    return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}

def load_tokenizer_and_model(
        model_dir: str,
        peft_config: Optional[PeftConfig] = None,
):
    """Load the tokenizer and a bfloat16 causal LM from ``model_dir``.

    When ``peft_config`` is given, the model is wrapped with ``get_peft_model``
    and its trainable-parameter summary is printed.

    Returns:
        ``(tokenizer, model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
        # NOTE(review): presumably consumed by the model's remote code — confirm.
        empty_init=False,
        use_cache=False,
        torch_dtype=torch.bfloat16,  # Must use BFloat 16
    )
    if peft_config is None:
        return tokenizer, model
    peft_model = get_peft_model(model, peft_config)
    peft_model.print_trainable_parameters()
    return tokenizer, peft_model

def compute_metrics(eval_preds: EvalPrediction, tokenizer):
    """Score generated texts against references.

    Both sides are decoded, stripped and segmented with jieba. Then
    ROUGE-1/2/L F-measures (x100, rounded to 4 decimals) and sentence-level
    BLEU-4 (nltk, smoothing method3) are computed per example.

    Returns:
        Dict mapping ``'rouge-1'``, ``'rouge-2'``, ``'rouge-l'`` and
        ``'bleu-4'`` to their mean over all examples.
    """
    batched_pred_ids, batched_label_ids = eval_preds
    metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
    rouge_keys = ('rouge-1', 'rouge-2', 'rouge-l')
    # Build the scorers once instead of once per example.
    rouge = Rouge()
    smoothing = SmoothingFunction().method3
    for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
        # Sequences gathered across batches may be padded with -100 (the
        # Trainer's padding index) — TODO confirm. Negative ids are never
        # valid tokens, so drop them before decoding.
        pred_ids = [int(t) for t in pred_ids if t >= 0]
        label_ids = [int(t) for t in label_ids if t >= 0]
        pred_txt = tokenizer.decode(pred_ids).strip()
        label_txt = tokenizer.decode(label_ids).strip()
        pred_tokens = list(jieba.cut(pred_txt))
        label_tokens = list(jieba.cut(label_txt))
        try:
            scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))[0]
        except ValueError:
            # Rouge rejects an empty hypothesis/reference (e.g. a generation
            # that decodes to nothing); count such an example as zero overlap
            # instead of aborting the whole evaluation.
            scores = {k: {'f': 0.0} for k in rouge_keys}
        for k, v in scores.items():
            metrics_dct[k].append(round(v['f'] * 100, 4))
        metrics_dct['bleu-4'].append(
            sentence_bleu([label_tokens], pred_tokens, smoothing_function=smoothing))
    return {k: np.mean(v) for k, v in metrics_dct.items()}

def _latest_checkpoint_sn(output_dir: str) -> int:
    """Return the largest N for which ``output_dir/checkpoint-N`` is a directory, else 0."""
    if not os.path.isdir(output_dir):
        return 0
    latest = 0
    for name in os.listdir(output_dir):
        if not name.startswith('checkpoint-'):
            # Also skips temporary directories such as "tmp-checkpoint-N".
            continue
        number = name[len('checkpoint-'):]
        if number.isdecimal() and os.path.isdir(os.path.join(output_dir, name)):
            latest = max(latest, int(number))
    return latest


def _resolve_resume_checkpoint(
        auto_resume_from_checkpoint: Optional[str],
        output_dir: str,
) -> tuple[bool, Optional[str]]:
    """Interpret the CLI resume argument.

    Returns ``(should_train, checkpoint_dir)``:
      * '' / None / 'no'  -> train from scratch: ``(True, None)``
      * 'yes'             -> resume from the latest checkpoint, or train from
                             scratch if there is none
      * positive number N -> resume from ``checkpoint-N`` if that directory exists
      * anything else, or a missing checkpoint -> print a message and return
        ``(False, None)`` (training is skipped)
    """
    choice = (auto_resume_from_checkpoint or '').strip()
    if choice == '' or choice.upper() == 'NO':
        return True, None
    if choice.upper() == 'YES':
        sn = _latest_checkpoint_sn(output_dir)
        if sn > 0:
            return True, os.path.join(output_dir, "checkpoint-" + str(sn))
        return True, None
    if choice.isdecimal() and int(choice) > 0:
        directory = os.path.join(output_dir, "checkpoint-" + str(int(choice)))
        if os.path.isdir(directory):
            return True, directory
    print(auto_resume_from_checkpoint,
          "The specified checkpoint sn(" + auto_resume_from_checkpoint + ") has not been saved. Please search for the correct checkpoint in the model output directory")
    return False, None


@app.command()
def main(
        data_dir: Annotated[str, typer.Argument(help='')],
        model_dir: Annotated[
            str,
            typer.Argument(
                help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'
            ),
        ],
        config_file: Annotated[str, typer.Argument(help='')],
        auto_resume_from_checkpoint: str = typer.Argument(
            default='',
            help='If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training'
        ),
):
    """Fine-tune ``model_dir`` on ``data_dir`` using the YAML ``config_file``.

    Training starts fresh or resumes according to
    ``auto_resume_from_checkpoint``. Afterwards, prediction runs on the test
    split when one is configured.
    """
    ft_config = FinetuningConfig.from_file(config_file)
    tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
    data_manager = DataManager(data_dir, ft_config.data_config)

    def _processor(fn):
        # Bind the tokenizer and length/combine settings shared by all splits.
        return functools.partial(
            fn,
            tokenizer=tokenizer,
            combine=ft_config.combine,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        )

    train_dataset = data_manager.get_dataset(Split.TRAIN, _processor(process_batch), batched=True)
    print('train_dataset:', train_dataset)
    val_dataset = data_manager.get_dataset(Split.VALIDATION, _processor(process_batch_eval), batched=True)
    if val_dataset is not None:
        print('val_dataset:', val_dataset)
    test_dataset = data_manager.get_dataset(Split.TEST, _processor(process_batch_eval), batched=True)
    if test_dataset is not None:
        print('test_dataset:', test_dataset)

    # Gradient checkpointing is enabled only on the resume path below; fresh
    # runs train without it.
    model.enable_input_require_grads()

    gen_config = ft_config.training_args.generation_config
    if gen_config is None:
        # The YAML may omit generation_config; create one so the token ids
        # below can be set.
        gen_config = GenerationConfig()
        ft_config.training_args.generation_config = gen_config
    gen_config.pad_token_id = 151329
    gen_config.eos_token_id = [151329, 151336, 151338]

    trainer = Seq2SeqTrainer(
        model=model,
        args=ft_config.training_args,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            padding='longest',
            return_tensors='pt',
        ),
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
    )

    should_train, checkpoint_directory = _resolve_resume_checkpoint(
        auto_resume_from_checkpoint, ft_config.training_args.output_dir
    )
    if checkpoint_directory is not None:
        model.gradient_checkpointing_enable()
        model.enable_input_require_grads()
        print("resume checkpoint from " + os.path.basename(checkpoint_directory))
        # NOTE(review): with an X-LoRA model this call has been reported to fail
        # in Trainer._load_optimizer_and_scheduler ("loaded state dict contains
        # a parameter group that doesn't match the size of optimizer's group").
        # Presumably the set of trainable parameters after the Trainer reloads
        # the adapter from the checkpoint differs from the set that existed
        # when the optimizer state was saved — TODO confirm / track in PEFT.
        trainer.train(resume_from_checkpoint=checkpoint_directory)
    elif should_train:
        trainer.train()

    if test_dataset is not None:
        trainer.predict(test_dataset)


if __name__ == '__main__':
    app()

Expected behavior

I expect resume_from_checkpoint to continue training from the last checkpoint, restoring the saved model weights, optimizer state, and training progress instead of reinitializing them. With plain LoRA this works: I saved several checkpoints during training and could resume from them without any issues. With X-LoRA, checkpoints are also saved during training, but resuming from them fails. The Trainer loads the checkpoint and then raises a ValueError while restoring the optimizer state ("loaded state dict contains a parameter group that doesn't match the size of optimizer's group"), so training cannot continue from the saved state. Here is the full output:

Loading checkpoint shards:   0%|          | 0/10 [00:00<?, ?it/s]
Loading checkpoint shards:  10%|█         | 1/10 [00:00<00:01,  6.86it/s]
Loading checkpoint shards:  20%|██        | 2/10 [00:00<00:01,  6.89it/s]
Loading checkpoint shards:  30%|███       | 3/10 [00:00<00:01,  6.90it/s]
Loading checkpoint shards:  40%|████      | 4/10 [00:00<00:00,  6.43it/s]
Loading checkpoint shards:  50%|█████     | 5/10 [00:00<00:00,  6.61it/s]
Loading checkpoint shards:  60%|██████    | 6/10 [00:00<00:00,  6.72it/s]
Loading checkpoint shards:  70%|███████   | 7/10 [00:01<00:00,  6.80it/s]
Loading checkpoint shards:  80%|████████  | 8/10 [00:01<00:00,  6.85it/s]
Loading checkpoint shards:  90%|█████████ | 9/10 [00:01<00:00,  6.88it/s]
Loading checkpoint shards: 100%|██████████| 10/10 [00:01<00:00,  6.96it/s]
Loading checkpoint shards: 100%|██████████| 10/10 [00:01<00:00,  6.82it/s]

  0%|          | 0/2 [00:00<?, ?it/s]
 50%|█████     | 1/2 [00:06<00:06,  6.33s/it]
100%|██████████| 2/2 [00:06<00:00,  3.22s/it]
Froze 160 adapters.
LoRA -> xLoRA complete: Swapped 40 LoRA layers (out of 971 modules).
trainable params: 67,145,732 || all params: 9,472,667,652 || trainable%: 0.7088

Map:   0%|          | 0/14803 [00:00<?, ? examples/s]
Map:   7%|▋         | 1000/14803 [00:03<00:42, 327.62 examples/s]
Map:  14%|█▎        | 2000/14803 [00:05<00:36, 347.77 examples/s]
Map:  20%|██        | 3000/14803 [00:08<00:33, 356.42 examples/s]
Map:  27%|██▋       | 4000/14803 [00:11<00:29, 361.64 examples/s]
Map:  34%|███▍      | 5000/14803 [00:13<00:26, 363.21 examples/s]
Map:  41%|████      | 6000/14803 [00:16<00:24, 363.04 examples/s]
Map:  47%|████▋     | 7000/14803 [00:19<00:21, 363.56 examples/s]
Map:  54%|█████▍    | 8000/14803 [00:21<00:16, 413.31 examples/s]
Map:  61%|██████    | 9000/14803 [00:22<00:11, 504.23 examples/s]
Map:  68%|██████▊   | 10000/14803 [00:23<00:08, 597.65 examples/s]
Map:  74%|███████▍  | 11000/14803 [00:24<00:05, 681.07 examples/s]
Map:  81%|████████  | 12000/14803 [00:25<00:03, 754.44 examples/s]
Map:  88%|████████▊ | 13000/14803 [00:26<00:02, 809.81 examples/s]
Map:  95%|█████████▍| 14000/14803 [00:27<00:00, 856.43 examples/s]
Map: 100%|██████████| 14803/14803 [00:28<00:00, 890.57 examples/s]
Map: 100%|██████████| 14803/14803 [00:28<00:00, 528.34 examples/s]
train_dataset: Dataset({
    features: ['input_ids', 'labels'],
    num_rows: 14803
})

Map:   0%|          | 0/2 [00:00<?, ? examples/s]
Map: 100%|██████████| 2/2 [00:00<00:00, 187.78 examples/s]
val_dataset: Dataset({
    features: ['input_ids', 'output_ids'],
    num_rows: 2
})

Map:   0%|          | 0/2 [00:00<?, ? examples/s]
Map: 100%|██████████| 2/2 [00:00<00:00, 189.77 examples/s]
test_dataset: Dataset({
    features: ['input_ids', 'output_ids'],
    num_rows: 2
})
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
max_steps is given, it will override any value given in num_train_epochs
resume checkpoint from checkpoint-20
Loading model from ./output_new/checkpoint-20.
Multiple active adapters detected will only consider the first adapter
[2024-10-15 18:47:30,968] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
/home/zhangjunyi/anaconda3/lib/python3.11/site-packages/transformers/trainer.py:3098: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
[rank0]: ╭───────────────────── Traceback (most recent call last) ──────────────────────╮
[rank0]: │ /home/zhangjunyi/hs_test/finetune_demo/finetune.py:615 in main               │
[rank0]: │                                                                              │
[rank0]: │   612 │   │   │   │   model.enable_input_require_grads()                     │
[rank0]: │   613 │   │   │   │   checkpoint_directory = os.path.join(output_dir, "check │
[rank0]: │   614 │   │   │   │   print("resume checkpoint from checkpoint-" + str(check │
[rank0]: │ ❱ 615 │   │   │   │   trainer.train(resume_from_checkpoint=checkpoint_direct │
[rank0]: │   616 │   │   │   else:                                                      │
[rank0]: │   617 │   │   │   │   trainer.train()                                        │
[rank0]: │   618 │   │   else:                                                          │
[rank0]: │                                                                              │
[rank0]: │ /home/zhangjunyi/anaconda3/lib/python3.11/site-packages/transformers/trainer │
[rank0]: │ .py:1938 in train                                                            │
[rank0]: │                                                                              │
[rank0]: │   1935 │   │   │   finally:                                                  │
[rank0]: │   1936 │   │   │   │   hf_hub_utils.enable_progress_bars()                   │
[rank0]: │   1937 │   │   else:                                                         │
[rank0]: │ ❱ 1938 │   │   │   return inner_training_loop(                               │
[rank0]: │   1939 │   │   │   │   args=args,                                            │
[rank0]: │   1940 │   │   │   │   resume_from_checkpoint=resume_from_checkpoint,        │
[rank0]: │   1941 │   │   │   │   trial=trial,                                          │
[rank0]: │                                                                              │
[rank0]: │ /home/zhangjunyi/anaconda3/lib/python3.11/site-packages/transformers/trainer │
[rank0]: │ .py:2126 in _inner_training_loop                                             │
[rank0]: │                                                                              │
[rank0]: │   2123 │   │   │   │   self._load_from_checkpoint(resume_from_checkpoint, se │
[rank0]: │   2124 │   │                                                                 │
[rank0]: │   2125 │   │   # Check if saved optimizer or scheduler states exist          │
[rank0]: │ ❱ 2126 │   │   self._load_optimizer_and_scheduler(resume_from_checkpoint)    │
[rank0]: │   2127 │   │                                                                 │
[rank0]: │   2128 │   │   # important: at this point:                                   │
[rank0]: │   2129 │   │   # self.model         is the Transformers Model                │
[rank0]: │                                                                              │
[rank0]: │ /home/zhangjunyi/anaconda3/lib/python3.11/site-packages/transformers/trainer │
[rank0]: │ .py:3097 in _load_optimizer_and_scheduler                                    │
[rank0]: │                                                                              │
[rank0]: │   3094 │   │   │   │   │   │   │   **_get_fsdp_ckpt_kwargs(),                │
[rank0]: │   3095 │   │   │   │   │   │   )                                             │
[rank0]: │   3096 │   │   │   │   │   else:                                             │
[rank0]: │ ❱ 3097 │   │   │   │   │   │   self.optimizer.load_state_dict(               │
[rank0]: │   3098 │   │   │   │   │   │   │   torch.load(os.path.join(checkpoint, OPTIM │
[rank0]: │   3099 │   │   │   │   │   │   )                                             │
[rank0]: │   3100 │   │   │   │   with warnings.catch_warnings(record=True) as caught_w │
[rank0]: │                                                                              │
[rank0]: │ /home/zhangjunyi/anaconda3/lib/python3.11/site-packages/accelerate/optimizer │
[rank0]: │ .py:107 in load_state_dict                                                   │
[rank0]: │                                                                              │
[rank0]: │   104 │   def load_state_dict(self, state_dict):                             │
[rank0]: │   105 │   │   if self.accelerator_state.distributed_type == DistributedType. │
[rank0]: │   106 │   │   │   xm.send_cpu_data_to_device(state_dict, self.accelerator_st │
[rank0]: │ ❱ 107 │   │   self.optimizer.load_state_dict(state_dict)                     │
[rank0]: │   108 │                                                                      │
[rank0]: │   109 │   def state_dict(self):                                              │
[rank0]: │   110 │   │   return self.optimizer.state_dict()                             │
[rank0]: │                                                                              │
[rank0]: │ /home/zhangjunyi/anaconda3/lib/python3.11/site-packages/torch/_compile.py:31 │
[rank0]: │ in inner                                                                     │
[rank0]: │                                                                              │
[rank0]: │   28 │   │   │   │   disable_fn = torch._dynamo.disable(fn, recursive)       │
[rank0]: │   29 │   │   │   │   fn.__dynamo_disable = disable_fn                        │
[rank0]: │   30 │   │   │                                                               │
[rank0]: │ ❱ 31 │   │   │   return disable_fn(*args, **kwargs)                          │
[rank0]: │   32 │   │                                                                   │
[rank0]: │   33 │   │   return inner                                                    │
[rank0]: │   34 │   else:                                                               │
[rank0]: │                                                                              │
[rank0]: │ /home/zhangjunyi/anaconda3/lib/python3.11/site-packages/torch/_dynamo/eval_f │
[rank0]: │ rame.py:600 in _fn                                                           │
[rank0]: │                                                                              │
[rank0]: │    597 │   │   def _fn(*args, **kwargs):                                     │
[rank0]: │    598 │   │   │   prior = set_eval_frame(callback)                          │
[rank0]: │    599 │   │   │   try:                                                      │
[rank0]: │ ❱  600 │   │   │   │   return fn(*args, **kwargs)                            │
[rank0]: │    601 │   │   │   finally:                                                  │
[rank0]: │    602 │   │   │   │   set_eval_frame(prior)                                 │
[rank0]: │    603                                                                       │
[rank0]: │                                                                              │
[rank0]: │ /home/zhangjunyi/anaconda3/lib/python3.11/site-packages/torch/optim/optimize │
[rank0]: │ r.py:854 in load_state_dict                                                  │
[rank0]: │                                                                              │
[rank0]: │    851 │   │   param_lens = (len(g["params"]) for g in groups)               │
[rank0]: │    852 │   │   saved_lens = (len(g["params"]) for g in saved_groups)         │
[rank0]: │    853 │   │   if any(p_len != s_len for p_len, s_len in zip(param_lens, sav │
[rank0]: │ ❱  854 │   │   │   raise ValueError(                                         │
[rank0]: │    855 │   │   │   │   "loaded state dict contains a parameter group "       │
[rank0]: │    856 │   │   │   │   "that doesn't match the size of optimizer's group"    │
[rank0]: │    857 │   │   │   )                                                         │
[rank0]: ╰──────────────────────────────────────────────────────────────────────────────╯
[rank0]: ValueError: loaded state dict contains a parameter group that doesn't match the 
[rank0]: size of optimizer's group
E1015 18:47:35.719000 139827737793152 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: 1) local_rank: 0 (pid: 4075482) of binary: /home/zhangjunyi/anaconda3/bin/python
Traceback (most recent call last):
  File "/home/zhangjunyi/anaconda3/bin/torchrun", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/zhangjunyi/anaconda3/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 348, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/zhangjunyi/anaconda3/lib/python3.11/site-packages/torch/distributed/run.py", line 901, in main
    run(args)
  File "/home/zhangjunyi/anaconda3/lib/python3.11/site-packages/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/home/zhangjunyi/anaconda3/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zhangjunyi/anaconda3/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
finetune.py FAILED

I would greatly appreciate any guidance on resolving this issue with xlora checkpoint restoration. If anyone has encountered a similar problem or has insights into specific settings or steps to enable successful checkpoint recovery for xlora, your advice would be invaluable. Additionally, if any maintainers or community members familiar with xlora could offer support, that would be extremely helpful. Many thanks

Copy link

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan
Copy link
Member

not stale

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants