Skip to content

Commit

Permalink
clean up **kwargs of Oneshot
Browse files Browse the repository at this point in the history
  • Loading branch information
horheynm committed Jan 16, 2025
1 parent bd1385e commit d52dbf3
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 10 deletions.
6 changes: 2 additions & 4 deletions src/llmcompressor/transformers/calibration/oneshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class Oneshot:
"""
Class responsible for carrying out oneshot calibration.
Usage:
```python
Expand All @@ -48,15 +49,12 @@ class Oneshot:

def __init__(
self,
output_dir: Optional[str] = None,
**kwargs,
):
self.model_args, self.data_args, self.recipe_args, _, output_dir_parser = (
self.model_args, self.data_args, self.recipe_args, _, self.output_dir = (
parse_args(**kwargs)
)

self.output_dir = output_dir or output_dir_parser

# Preprocess the model and tokenizer/processor
self._pre_process()

Expand Down
11 changes: 6 additions & 5 deletions src/llmcompressor/transformers/finetune/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
get_processor_from_model,
)
from llmcompressor.transformers.utils.arg_parser import (
DEFAULT_OUTPUT_DIR,
DatasetArguments,
ModelArguments,
RecipeArguments,
Expand All @@ -65,7 +66,7 @@ def train(**kwargs):
"""
CLI entrypoint for running training
"""
model_args, data_args, recipe_args, training_args = parse_args(
model_args, data_args, recipe_args, training_args, _ = parse_args(
include_training_args=True, **kwargs
)
training_args.do_train = True
Expand All @@ -76,7 +77,7 @@ def eval(**kwargs):
"""
CLI entrypoint for running evaluation
"""
model_args, data_args, recipe_args, training_args = parse_args(
model_args, data_args, recipe_args, training_args, _ = parse_args(
include_training_args=True, **kwargs
)
training_args.do_eval = True
Expand Down Expand Up @@ -142,6 +143,7 @@ def parse_args(include_training_args: bool = False, **kwargs):
conflict with accelerate library's accelerator.
"""
output_dir = kwargs.pop("output_dir", DEFAULT_OUTPUT_DIR)

if include_training_args:
parser = HfArgumentParser(
Expand All @@ -158,6 +160,8 @@ def parse_args(include_training_args: bool = False, **kwargs):
# Unpack parsed arguments based on the presence of training arguments
if include_training_args:
model_args, data_args, recipe_args, training_args = parsed_args
if output_dir is not None:
training_args.output_dir = output_dir
else:
model_args, data_args, recipe_args = parsed_args
training_args = None
Expand Down Expand Up @@ -185,9 +189,6 @@ def parse_args(include_training_args: bool = False, **kwargs):
model_args.processor = model_args.tokenizer
model_args.tokenizer = None

# Handle output_dir only if training arguments are included
output_dir = training_args.output_dir if training_args else None

return model_args, data_args, recipe_args, training_args, output_dir


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from .data_arguments import DatasetArguments
from .model_arguments import ModelArguments
from .recipe_arguments import RecipeArguments
from .training_arguments import TrainingArguments
from .training_arguments import DEFAULT_OUTPUT_DIR, TrainingArguments

0 comments on commit d52dbf3

Please sign in to comment.