Commit
move args
horheynm committed Jan 15, 2025
1 parent 59d5d63 commit b5f75d5
Showing 14 changed files with 82 additions and 51 deletions.
23 changes: 7 additions & 16 deletions src/llmcompressor/transformers/calibration/oneshot.py
@@ -5,22 +5,21 @@
from torch.utils.data import DataLoader

from llmcompressor.core.session_functions import active_session
from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
from llmcompressor.transformers.finetune.data.data_helpers import (
get_calibration_dataloader,
)
from llmcompressor.transformers.finetune.model_args import ModelArguments
from llmcompressor.transformers.finetune.text_generation import (
initialize_model_from_path,
initialize_processor_from_path,
parse_args,
)
from llmcompressor.transformers.finetune.training_args import DEFAULT_OUTPUT_DIR
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_save_pretrained,
patch_tied_tensors_bug,
)
from llmcompressor.transformers.utils.recipe_args import RecipeArguments
from llmcompressor.transformers.utils.arg_parser.training_arguments import (
DEFAULT_OUTPUT_DIR,
)

__all__ = ["Oneshot"]

@@ -49,22 +48,14 @@ class Oneshot:

def __init__(
self,
model_args: Optional["ModelArguments"] = None,
data_args: Optional["DataTrainingArguments"] = None,
recipe_args: Optional["RecipeArguments"] = None,
output_dir: Optional[str] = None,
**kwargs,
):
if any(arg is not None for arg in [model_args, data_args, recipe_args]):
self.model_args = model_args
self.data_args = self.data_args
self.recipe_args = self.recipe_args
else:
self.model_args, self.data_args, self.recipe_args, _, output_dir = (
parse_args(**kwargs)
)
self.model_args, self.data_args, self.recipe_args, _, output_dir_parser = (
parse_args(**kwargs)
)

self.output_dir = output_dir
self.output_dir = output_dir or output_dir_parser

# Preprocess the model and tokenizer/processor
self._pre_process()
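For context, a minimal sketch of how the reworked Oneshot constructor might now be called: everything arrives as plain keyword arguments and is split into the argument dataclasses by parse_args(), with an explicit output_dir taking precedence over the parsed one. The kwarg names used below (model, dataset, recipe) are assumptions for illustration, not taken from this diff.

# Hedged sketch; kwarg names are assumed, not confirmed by this commit.
from llmcompressor.transformers.calibration.oneshot import Oneshot

oneshot = Oneshot(
    model="path/to/model",          # assumed ModelArguments field
    dataset="open_platypus",        # assumed DatasetArguments field
    recipe="recipe.yaml",           # assumed RecipeArguments field
    output_dir="./oneshot_output",  # overrides the dir returned by parse_args()
)
oneshot.run()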
4 changes: 1 addition & 3 deletions src/llmcompressor/transformers/finetune/__init__.py
@@ -1,7 +1,5 @@
# flake8: noqa

from .data import DataTrainingArguments, TextGenerationDataset
from .model_args import ModelArguments
from .data import TextGenerationDataset
from .session_mixin import SessionManagerMixIn
from .text_generation import apply, compress, eval, oneshot, train
from .training_args import TrainingArguments
1 change: 0 additions & 1 deletion src/llmcompressor/transformers/finetune/data/__init__.py
@@ -4,7 +4,6 @@
from .c4 import C4Dataset
from .cnn_dailymail import CNNDailyMailDataset
from .custom import CustomDataset
from .data_args import DataTrainingArguments
from .evolcodealpaca import EvolCodeAlpacaDataset
from .flickr_30k import Flickr30K
from .gsm8k import GSM8KDataset
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/base.py
@@ -8,12 +8,12 @@
from datasets.formatting.formatting import LazyRow
from loguru import logger

from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
from llmcompressor.transformers.finetune.data.data_helpers import (
LABELS_MASK_VALUE,
get_custom_datasets_from_path,
get_raw_dataset,
)
from llmcompressor.transformers.utils.arg_parser import DatasetArguments
from llmcompressor.transformers.utils.preprocessing_functions import (
PreprocessingFunctionRegistry,
)
@@ -41,7 +41,7 @@ class TextGenerationDataset(RegistryMixin):

def __init__(
self,
data_args: DataTrainingArguments,
data_args: DatasetArguments,
split: str,
processor: Processor,
):
33 changes: 19 additions & 14 deletions src/llmcompressor/transformers/finetune/runner.py
@@ -1,6 +1,7 @@
import math
import os
import re
from dataclasses import asdict
from typing import List, Optional

import torch
@@ -16,14 +17,19 @@
from llmcompressor.pytorch.utils import tensors_to_device
from llmcompressor.recipe import Recipe, StageRunType
from llmcompressor.transformers.finetune.data import TextGenerationDataset
from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
from llmcompressor.transformers.finetune.data.data_helpers import (
format_calibration_data,
make_dataset_splits,
)
from llmcompressor.transformers.finetune.model_args import ModelArguments
from llmcompressor.transformers.finetune.training_args import TrainingArguments
from llmcompressor.transformers.utils.recipe_args import RecipeArguments
from llmcompressor.transformers.utils.arg_parser import (
DatasetArguments,
ModelArguments,
RecipeArguments,
TrainingArguments,
)
from llmcompressor.transformers.utils.arg_parser.training_arguments import (
DEFAULT_OUTPUT_DIR,
)
from llmcompressor.typing import Processor
from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_model_and_recipe

@@ -47,7 +53,7 @@ class StageRunner:

def __init__(
self,
data_args: "DataTrainingArguments",
data_args: "DatasetArguments",
model_args: "ModelArguments",
training_args: "TrainingArguments",
recipe_args: "RecipeArguments",
@@ -260,22 +266,21 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
self._model_args.model = model

oneshot = Oneshot(
# lifecycle=active_session()._lifecycle,
model_args=self._model_args,
data_args=self._data_args,
recipe_args=self._recipe_args,
# training_args=self._training_args,
# model_args=self._model_args,
# data_args=self._data_args,
# recipe_args=self._recipe_args,
# output_dir=self._training_args.output_dir,
**asdict(self._model_args),
**asdict(self._data_args),
**asdict(self._recipe_args),
output_dir=self._training_args.output_dir,
)
oneshot.run(stage_name=stage_name)
elif run_type is StageRunType.TRAIN:
self.train(checkpoint=checkpoint, stage=stage_name)
checkpoint = None

if (
self._training_args.output_dir
!= TrainingArguments.__dataclass_fields__["output_dir"].default
):
if self._training_args.output_dir != DEFAULT_OUTPUT_DIR:
save_model_and_recipe(
model=self.trainer.model,
save_path=self._output_dir,
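The asdict(...) forwarding above flattens each argument dataclass back into plain keyword arguments so Oneshot can re-parse them. A self-contained illustration of the pattern, using hypothetical stub dataclasses rather than the real argument classes:

from dataclasses import asdict, dataclass

@dataclass
class ModelArgsStub:       # hypothetical stand-in for ModelArguments
    model: str = "stub-model"

@dataclass
class DatasetArgsStub:     # hypothetical stand-in for DatasetArguments
    dataset: str = "stub-dataset"

def oneshot_like(**kwargs):
    # stands in for Oneshot(**kwargs); just shows what it receives
    return kwargs

merged = oneshot_like(**asdict(ModelArgsStub()), **asdict(DatasetArgsStub()))
print(merged)  # {'model': 'stub-model', 'dataset': 'stub-dataset'}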
9 changes: 5 additions & 4 deletions src/llmcompressor/transformers/finetune/session_mixin.py
@@ -31,14 +31,15 @@
DisableHalfPrecisionCallback,
TrainingLoopCallbacks,
)
from llmcompressor.transformers.finetune.model_args import ModelArguments
from llmcompressor.utils.fsdp.context import summon_full_params_context
from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_pretrained_fsdp
from llmcompressor.utils.pytorch import qat_active

if TYPE_CHECKING:
from llmcompressor.transformers import DataTrainingArguments

from llmcompressor.transformers.utils.arg_parser import (
DatasetArguments,
ModelArguments,
)

__all__ = [
"SessionManagerMixIn",
@@ -69,7 +70,7 @@ def __init__(
self,
recipe: Optional[str] = None,
recipe_args: Optional[Union[Dict[str, Any], str]] = None,
data_args: Optional["DataTrainingArguments"] = None,
data_args: Optional["DatasetArguments"] = None,
model_args: Optional["ModelArguments"] = None,
teacher: Optional[Union[Module, str]] = None,
**kwargs,
16 changes: 9 additions & 7 deletions src/llmcompressor/transformers/finetune/text_generation.py
@@ -40,11 +40,8 @@
parse_dtype,
)
from llmcompressor.recipe import Recipe, StageRunType
from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
from llmcompressor.transformers.finetune.model_args import ModelArguments
from llmcompressor.transformers.finetune.runner import StageRunner
from llmcompressor.transformers.finetune.trainer import Trainer
from llmcompressor.transformers.finetune.training_args import TrainingArguments
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_fsdp_model_save_pretrained,
modify_save_pretrained,
@@ -53,8 +50,13 @@
from llmcompressor.transformers.sparsification.sparse_model import (
get_processor_from_model,
)
from llmcompressor.transformers.utils.arg_parser import (
DatasetArguments,
ModelArguments,
RecipeArguments,
TrainingArguments,
)
from llmcompressor.transformers.utils.helpers import detect_last_checkpoint
from llmcompressor.transformers.utils.recipe_args import RecipeArguments
from llmcompressor.typing import Processor
from llmcompressor.utils.fsdp.helpers import is_fsdp_model

@@ -111,7 +113,7 @@ def compress(**kwargs):

def load_dataset(dataset_name: str, **kwargs):
parser = HfArgumentParser(
(ModelArguments, DataTrainingArguments, RecipeArguments, TrainingArguments)
(ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments)
)
_, data_args, _, _ = parser.parse_dict(kwargs)
data_args["dataset_name"] = dataset_name
@@ -130,7 +132,7 @@ def parse_args(**kwargs):
"""

parser = HfArgumentParser(
(ModelArguments, DataTrainingArguments, RecipeArguments, TrainingArguments)
(ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments)
)

if not kwargs:
@@ -297,7 +299,7 @@ def initialize_processor_from_path(

def main(
model_args: ModelArguments,
data_args: DataTrainingArguments,
data_args: DatasetArguments,
recipe_args: RecipeArguments,
training_args: TrainingArguments,
):
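As a rough sketch of what HfArgumentParser does with the four argument groups wired up above (the dict keys passed to parse_dict are assumed example values, not confirmed by this diff):

from transformers import HfArgumentParser

from llmcompressor.transformers.utils.arg_parser import (
    DatasetArguments,
    ModelArguments,
    RecipeArguments,
    TrainingArguments,
)

parser = HfArgumentParser(
    (ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments)
)
# parse_dict splits one flat kwargs dict into the four dataclass instances;
# field names such as "model" and "dataset" are assumed for illustration.
model_args, data_args, recipe_args, training_args = parser.parse_dict(
    {"model": "path/to/model", "dataset": "open_platypus", "output_dir": "./out"},
    allow_extra_keys=True,
)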
6 changes: 6 additions & 0 deletions src/llmcompressor/transformers/utils/arg_parser/__init__.py
@@ -0,0 +1,6 @@
# flake8: noqa

from .data_arguments import DatasetArguments
from .model_arguments import ModelArguments
from .recipe_arguments import RecipeArguments
from .training_arguments import TrainingArguments
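With this package in place, callers can swap the previous per-module imports for a single consolidated one (mirroring the import added to text_generation.py above):

from llmcompressor.transformers.utils.arg_parser import (
    DatasetArguments,
    ModelArguments,
    RecipeArguments,
    TrainingArguments,
)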
src/llmcompressor/transformers/utils/arg_parser/data_arguments.py
@@ -5,7 +5,7 @@


@dataclass
class DVCDatasetTrainingArguments:
class DVCDatasetArguments:
"""
Arguments for training using DVC
"""
@@ -17,7 +17,7 @@ class DVCDatasetTrainingArguments:


@dataclass
class CustomDataTrainingArguments(DVCDatasetTrainingArguments):
class CustomDatasetArguments(DVCDatasetArguments):
"""
Arguments for training using custom datasets
"""
@@ -67,7 +67,7 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments):


@dataclass
class DataTrainingArguments(CustomDataTrainingArguments):
class DatasetArguments(CustomDatasetArguments):
"""
Arguments pertaining to what data we are going to input our model for
training and eval
26 changes: 26 additions & 0 deletions src/llmcompressor/transformers/utils/arg_parser/utils.py
@@ -0,0 +1,26 @@
from dataclasses import fields
from typing import Any, Dict, Union

from .data_arguments import DatasetArguments
from .model_arguments import ModelArguments
from .recipe_arguments import RecipeArguments
from .training_arguments import TrainingArguments


def get_dataclass_as_dict(
dataclass_instance: Union[
"ModelArguments", "RecipeArguments", "DatasetArguments", "TrainingArguments"
],
dataclass_class: Union[
"ModelArguments", "RecipeArguments", "DatasetArguments", "TrainingArguments"
],
) -> Dict[str, Any]:
"""
Get the dataclass instance attributes as a dict, neglecting the inherited class.
Ex. dataclass_class=TrainingArguments will ignore HFTrainingArguments
"""
return {
field.name: getattr(dataclass_instance, field.name)
for field in fields(dataclass_class)
}
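A minimal, self-contained illustration of what get_dataclass_as_dict returns, using a hypothetical stub in place of the real argument classes:

from dataclasses import dataclass, fields
from typing import Any, Dict

@dataclass
class RecipeArgsStub:  # hypothetical stand-in for RecipeArguments
    recipe: str = "recipe.yaml"
    clear_sparse_session: bool = False

def get_dataclass_as_dict(instance, cls) -> Dict[str, Any]:
    # same comprehension as the helper introduced above
    return {f.name: getattr(instance, f.name) for f in fields(cls)}

print(get_dataclass_as_dict(RecipeArgsStub(), RecipeArgsStub))
# -> {'recipe': 'recipe.yaml', 'clear_sparse_session': False}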
5 changes: 4 additions & 1 deletion src/llmcompressor/transformers/utils/helpers.py
@@ -10,7 +10,10 @@
from transformers.trainer_utils import get_last_checkpoint

if TYPE_CHECKING:
from llmcompressor.transformers import ModelArguments, TrainingArguments
from llmcompressor.transformers.utils.arg_parser import (
ModelArguments,
TrainingArguments,
)

__all__ = [
"RECIPE_FILE_NAME",
Expand Down
