PTQ via NeMo-Run CLI (#10984)
* PTQ support in nemo CLI

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

* Naming engine vs checkpoint

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

---------

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>
janekl authored Nov 18, 2024
1 parent bca6b09 commit 956b54d
Showing 3 changed files with 79 additions and 13 deletions.
3 changes: 2 additions & 1 deletion nemo/collections/llm/__init__.py
@@ -214,7 +214,7 @@
try:
import nemo_run as run

from nemo.collections.llm.api import export_ckpt, finetune, generate, import_ckpt, pretrain, train, validate
from nemo.collections.llm.api import export_ckpt, finetune, generate, import_ckpt, pretrain, ptq, train, validate
from nemo.collections.llm.recipes import * # noqa

__all__.extend(
@@ -226,6 +226,7 @@
"validate",
"finetune",
"generate",
"ptq",
]
)
except ImportError as error:
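With this change, the new entrypoint ships with the rest of the public llm API whenever nemo_run is installed. A minimal sketch of what the export enables (assuming a working NeMo 2.0 installation with nemo_run available):

```python
# Sketch: the try-import above adds "ptq" to llm.__all__ when nemo_run is present;
# otherwise the ImportError branch leaves the entrypoint unavailable.
from nemo.collections import llm

assert "ptq" in llm.__all__  # exported alongside train, finetune, validate, ...
```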
67 changes: 62 additions & 5 deletions nemo/collections/llm/api.py
@@ -24,6 +24,7 @@
from typing_extensions import Annotated

import nemo.lightning as nl
from nemo.collections.llm.quantization import ExportConfig, QuantizationConfig
from nemo.lightning import (
AutoResume,
NeMoLogger,
@@ -67,7 +68,8 @@ def train(
resume (Optional[Union[AutoResume, Resume]]): Resume training from a checkpoint.
optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer
from the model will be used.
tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec.
tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model'
or an instance of TokenizerSpec.
export (Optional[str]): Filename to save the exported checkpoint after training.
model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied.
@@ -83,7 +85,7 @@ def train(
>>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
>>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
>>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
>>> train(model, data, trainer, tokenizer="data")
>>> llm.train(model, data, trainer, tokenizer="data")
PosixPath('/path/to/log_dir')
"""
app_state = _setup(
@@ -185,7 +187,7 @@ def finetune(
>>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
>>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
>>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
>>> finetune(model, data, trainer, peft=llm.peft.LoRA()])
>>> llm.finetune(model, data, trainer, peft=llm.peft.LoRA())
PosixPath('/path/to/log_dir')
"""

@@ -223,7 +225,8 @@ def validate(
resume (Optional[AutoResume]): Resume from a checkpoint for validation.
optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer
from the model will be used.
tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec.
tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model'
or an instance of TokenizerSpec.
model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied.
Returns:
@@ -236,7 +239,7 @@ def validate(
>>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
>>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed")
>>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision)
>>> validate(model, data, trainer, tokenizer="data")
>>> llm.validate(model, data, trainer, tokenizer="data")
PosixPath('/path/to/log_dir')
"""
app_state = _setup(
@@ -255,6 +258,60 @@ def validate(
return app_state.exp_dir


@run.cli.entrypoint(name="ptq", namespace="llm")
def ptq(
nemo_checkpoint: str,
calib_tp: int = 1,
calib_pp: int = 1,
quantization_config: Annotated[Optional[QuantizationConfig], run.Config[QuantizationConfig]] = None,
export_config: Optional[Union[ExportConfig, run.Config[ExportConfig]]] = None,
) -> Path:
# TODO: Fix "nemo_run.cli.cli_parser.CLIException: An unexpected error occurred (Argument: , Context: {})"
"""
Applies Post-Training Quantization (PTQ) to a model using the specified quantization and export configs. It runs
calibration on a small dataset to collect scaling factors for the low-precision GEMMs used by the desired quantization method.
This function produces a TensorRT-LLM checkpoint ready for deployment using the nemo.export and nemo.deploy modules
or directly using the TensorRT-LLM library.
The function can be used through the NeMo CLI in the following way:
```bash
# Run calibration with tensor parallelism 8 and export the quantized checkpoint with tensor parallelism 2
nemo llm ptq nemo_checkpoint=/models/Llama-3-70B \
export_config.path=/models/Llama-3-70B-FP8 \
calib_tp=8 \
export_config.inference_tensor_parallel=2
# Choose a different quantization method, for example INT8 SmoothQuant
nemo llm ptq nemo_checkpoint=/models/Llama-3-8B \
export_config.path=/models/Llama-3-8B-INT8_SQ \
quantization_config.algorithm=int8_sq
```
Args:
nemo_checkpoint (str): The path to the model to be quantized.
calib_tp (int): Calibration tensor parallelism.
calib_pp (int): Calibration pipeline parallelism.
quantization_config (QuantizationConfig): Configuration for the quantization algorithm.
export_config (ExportConfig): Export configuration for the TensorRT-LLM checkpoint.
Returns:
Path: The path where the quantized checkpoint has been saved after calibration.
"""
if export_config is None or export_config.path is None:
raise ValueError("The export_config.path needs to be specified, got None.")

from nemo.collections.llm import quantization

quantizer = quantization.Quantizer(quantization_config, export_config)

model = quantization.load_with_modelopt_layer_spec(nemo_checkpoint, calib_tp, calib_pp)

model = quantizer.quantize(model)

quantizer.export(model, nemo_checkpoint)

console = Console()
console.print(f"[green]✓ PTQ succeded, quantized checkpoint exported to {export_config.path}[/green]")

return export_config.path
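
Besides the CLI, the entrypoint can also be called programmatically. A hedged sketch mirroring the CLI examples above (the `algorithm` and `path` field names are taken from those examples; the concrete paths are placeholders):

```python
from nemo.collections import llm
from nemo.collections.llm.quantization import ExportConfig, QuantizationConfig

# Quantize a NeMo 2.0 checkpoint with INT8 SmoothQuant and export a
# TensorRT-LLM checkpoint; returns export_config.path on success.
result = llm.ptq(
    nemo_checkpoint="/models/Llama-3-8B",
    calib_tp=1,
    calib_pp=1,
    quantization_config=QuantizationConfig(algorithm="int8_sq"),
    export_config=ExportConfig(path="/models/Llama-3-8B-INT8_SQ"),
)
print(result)
```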


@run.cli.entrypoint(namespace="llm")
def deploy(
nemo_checkpoint: Path = None,
22 changes: 15 additions & 7 deletions nemo/collections/llm/quantization/quantizer.py
@@ -15,6 +15,7 @@
import os
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

import torch
@@ -75,17 +76,20 @@ class QuantizationConfig:

@dataclass
class ExportConfig:
"""Inference configuration for the quantized TensorRT-LLM engine"""
"""Inference configuration for the quantized TensorRT-LLM checkpoint."""

path: str
path: Union[Path, str]
dtype: Union[str, int] = "bf16"
decoder_type: Optional[str] = None
inference_tensor_parallel: int = 1
inference_pipeline_parallel: int = 1

def __post_init__(self):
self.path = Path(self.path)
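
The new `__post_init__` lets callers pass `path` either as a string or as a `pathlib.Path`; both end up normalized. A quick sketch:

```python
from pathlib import Path

from nemo.collections.llm.quantization import ExportConfig

cfg = ExportConfig(path="/models/Llama-3-8B-FP8")  # plain str is accepted...
assert isinstance(cfg.path, Path)                  # ...and coerced to Path on construction
```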


def get_modelopt_decoder_type(config: llm.GPTConfig) -> str:
"""Infers the modelopt decoder type from GPTConfig class"""
"""Infers the modelopt decoder type from GPTConfig class."""
mapping = [
(llm.Baichuan2Config, "baichuan"),
(llm.ChatGLMConfig, "chatglm"),
@@ -109,17 +113,17 @@ def get_modelopt_decoder_type(config: llm.GPTConfig) -> str:


class Quantizer:
"""Post-training quantization (PTQ) and TRT-LLM export of NeMo 2.0 checkpoints.
"""Post-training quantization (PTQ) and TensorRT-LLM export of NeMo 2.0 checkpoints.
PTQ converts selected model layers to low-precision format (e.g., INT4, FP8) for efficient serving.
The process consists of several steps:
1. Loading a NeMo model from disk using an appropriate parallelism strategy
2. Calibrating the model to obtain appropriate algorithm-specific scaling factors
3. Producing output directory
3. Producing an output directory with a quantized checkpoint and a tokenizer
The output directory produced is intended to be consumed by the TensorRT-LLM toolbox
for efficient inference. This can be achieved using NeMo inference containers.
for efficient inference. This can be achieved using the nemo.export.tensorrt_llm module.
"""

def __init__(self, quantization_config: QuantizationConfig, export_config: ExportConfig):
@@ -231,6 +235,7 @@ def quantize(self, model: llm.GPTModel, forward_loop=None):
def create_megatron_forward_loop(
self, get_dataloader, num_batches, seq_length=None, micro_batch_size=None, decoder_seq_length=None
):
"""Create a forward loop for over a given data iterator."""
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func

forward_backward_func = get_forward_backward_func()
@@ -262,6 +267,7 @@ def loop(model):
return loop

def export(self, model: llm.GPTModel, model_dir: str) -> None:
"""Export model to a TensorRT-LLM checkpoint."""
assert self.export_config is not None, "Export config is not set"
# TODO: Add sample generate
# TODO: Support megatron_amp_O2
@@ -294,7 +300,7 @@ def export(self, model: llm.GPTModel, model_dir: str) -> None:
def get_calib_data_iter(
data: str = "cnn_dailymail", batch_size: int = 64, calib_size: int = 512, max_sequence_length: int = 512
):
"""Creates a sample data iterator for calibration"""
"""Creates a sample data iterator for calibration."""
if data == "wikitext":
dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
text_column = "text"
Expand All @@ -314,6 +320,8 @@ def get_calib_data_iter(


def create_data_iterator_getter(model, dataset, seq_len, batch_size, calibration_size):
"""Create a function that provides iterator over a given dataset."""

def _iterator():
CHARACTERS_PER_TOKEN = 4

