Feat (mx): PTQ MX + Float support (#1010)
---------

Co-authored-by: Nick Fraser <icanlosh@gmail.com>
Giuseppe5 and nickfraser authored Sep 5, 2024
1 parent d4834bd commit b889bb2
Showing 5 changed files with 94 additions and 48 deletions.
5 changes: 1 addition & 4 deletions src/brevitas_examples/common/generative/quant_blocks.py
@@ -3,15 +3,12 @@
# SPDX-License-Identifier: BSD-3-Clause
"""

from typing import Callable, List, Optional, Tuple
from typing import Callable

import torch
from torch import Tensor
import torch.nn as nn

import brevitas
from brevitas.core.function_wrapper.shape import PermuteDims
from brevitas.core.utils import SliceTensor
from brevitas.core.zero_point import _ScaleShiftZeroPoint
from brevitas.function.ops_ste import abs_binary_sign_grad

26 changes: 13 additions & 13 deletions src/brevitas_examples/imagenet_classification/ptq/README.md
@@ -80,7 +80,8 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir
[--bias-bit-width {32,16,None}]
[--act-quant-type {sym,asym}]
[--weight-quant-type {sym,asym}]
[--weight-quant-granularity {per_tensor,per_channel}]
[--weight-quant-granularity {per_tensor,per_channel,per_group}]
[--act-quant-granularity {per_tensor,per_group}]
[--weight-quant-calibration-type {stats,mse}]
[--act-equalization {fx,layerwise,None}]
[--act-quant-calibration-type {stats,mse}]
@@ -90,11 +91,11 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir
[--learned-round-lr LEARNED_ROUND_LR]
[--act-quant-percentile ACT_QUANT_PERCENTILE]
[--export-onnx-qcdq] [--export-torch-qcdq]
[--scaling-per-output-channel | --no-scaling-per-output-channel]
[--bias-corr | --no-bias-corr]
[--graph-eq-merge-bias | --no-graph-eq-merge-bias]
[--weight-narrow-range | --no-weight-narrow-range]
[--gpfq-p GPFQ_P] [--quant-format {int,float}]
[--gpfq-p GPFQ_P]
[--quant-format {int,float,float_ocp}]
[--layerwise-first-last-mantissa-bit-width LAYERWISE_FIRST_LAST_MANTISSA_BIT_WIDTH]
[--layerwise-first-last-exponent-bit-width LAYERWISE_FIRST_LAST_EXPONENT_BIT_WIDTH]
[--weight-mantissa-bit-width WEIGHT_MANTISSA_BIT_WIDTH]
@@ -104,6 +105,7 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir
[--accumulator-bit-width ACCUMULATOR_BIT_WIDTH]
[--onnx-opset-version ONNX_OPSET_VERSION]
[--channel-splitting-ratio CHANNEL_SPLITTING_RATIO]
[--compression-rate COMPRESSION_RATE]
[--gptq | --no-gptq] [--gpfq | --no-gpfq]
[--gpfa2q | --no-gpfa2q]
[--gpxq-act-order | --no-gpxq-act-order]
@@ -115,7 +117,7 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir

PyTorch ImageNet PTQ Validation

options:
optional arguments:
-h, --help show this help message and exit
--calibration-dir CALIBRATION_DIR
Path to folder containing Imagenet calibration folder
@@ -176,7 +178,9 @@ options:
Activation quantization type (default: sym)
--weight-quant-type {sym,asym}
Weight quantization type (default: sym)
--weight-quant-granularity {per_tensor,per_channel}
--weight-quant-granularity {per_tensor,per_channel,per_group}
Weight quantization type (default: per_tensor)
--act-quant-granularity {per_tensor,per_group}
Activation quantization type (default: per_tensor)
--weight-quant-calibration-type {stats,mse}
Weight quantization calibration type (default: stats)
@@ -201,12 +205,6 @@ options:
(default: 99.999)
--export-onnx-qcdq If true, export the model in onnx qcdq format
--export-torch-qcdq If true, export the model in torch qcdq format
--scaling-per-output-channel
Enable Weight scaling per output channel (default:
enabled)
--no-scaling-per-output-channel
Disable Weight scaling per output channel (default:
enabled)
--bias-corr Enable Bias correction after calibration (default:
enabled)
--no-bias-corr Disable Bias correction after calibration (default:
@@ -224,7 +222,7 @@ options:
Disable Narrow range for weight quantization (default:
disabled)
--gpfq-p GPFQ_P P parameter for GPFQ (default: 1.0)
--quant-format {int,float}
--quant-format {int,float,float_ocp}
Quantization format to use for weights and activations
(default: int)
--layerwise-first-last-mantissa-bit-width LAYERWISE_FIRST_LAST_MANTISSA_BIT_WIDTH
@@ -252,6 +250,9 @@ options:
--channel-splitting-ratio CHANNEL_SPLITTING_RATIO
Split Ratio for Channel Splitting. When set to 0.0,
Channel Splitting will not be applied. (default: 0.0)
--compression-rate COMPRESSION_RATE
Specify compression rate < 1.0 for random projection.
Default is 0.0 and does not use RP.
--gptq Enable GPTQ (default: disabled)
--no-gptq Disable GPTQ (default: disabled)
--gpfq Enable GPFQ (default: disabled)
@@ -280,7 +281,6 @@ options:
--no-uint_sym_act_for_unsigned_values
Disable Use unsigned act quant when possible (default:
enabled)

```

The script requires you to specify the calibration folder (`--calibration-dir`), from which the calibration samples will be taken (configurable with the `--calibration-samples` argument), and a validation folder (`--validation-dir`).
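For example, the new per-group (MX) and OCP float options could be combined as follows. This is a hypothetical invocation: the directory paths are placeholders, and depending on the configuration additional flags (e.g. bit widths or scale factor settings) may be required:

```bash
python ptq_evaluate.py \
  --calibration-dir /path/to/imagenet/calibration \
  --validation-dir /path/to/imagenet/val \
  --quant-format float_ocp \
  --weight-quant-granularity per_group \
  --act-quant-granularity per_group
```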
src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py
@@ -89,7 +89,9 @@ def unique(sequence):
'act_bit_width': [8], # Act bit width
'bias_bit_width': [32], # Bias Bit-Width for Po2 scale
'weight_quant_granularity': ['per_channel'], # Scaling Per Output Channel
'act_quant_granularity': ['per_tensor'], # Scaling Per Output Channel
'act_quant_type': ['sym'], # Act Quant Type
'act_scale_computation_type': ['static'], # Act Quant Type
'act_param_method': ['stats'], # Act Param Method
'weight_param_method': ['mse'], # Weight Quant Type
'bias_corr': [True], # Bias Correction
@@ -240,7 +242,9 @@ def ptq_torchvision_models(args):
weight_param_method=config_namespace.weight_param_method,
act_param_method=config_namespace.act_param_method,
bias_bit_width=config_namespace.bias_bit_width,
act_scale_computation_type=config_namespace.act_scale_computation_type,
weight_quant_granularity=config_namespace.weight_quant_granularity,
act_quant_granularity=config_namespace.act_quant_granularity,
act_quant_percentile=config_namespace.act_quant_percentile,
act_quant_type=config_namespace.act_quant_type,
scale_factor_type=config_namespace.scale_factor_type,
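The option lists at the top of this file define a benchmarking grid that the script sweeps. A toy sketch of such a grid expansion (illustrative only, not the script's actual code):

```python
import itertools

# Illustrative subset of the option lists above; each list enumerates the
# values to sweep for one knob.
OPTIONS = {
    'weight_quant_granularity': ['per_channel', 'per_group'],
    'act_quant_granularity': ['per_tensor'],
    'act_scale_computation_type': ['static'],
}

# Expand the cross-product of all option lists into concrete configurations.
for values in itertools.product(*OPTIONS.values()):
    print(dict(zip(OPTIONS.keys(), values)))
```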
91 changes: 67 additions & 24 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -29,6 +29,20 @@
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloatMSE
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloatMSE
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloatMSE
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloatMSE
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloatMSE
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE
from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act
from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight
from brevitas.quant.experimental.mx_quant_ocp import MXInt8WeightMSE
from brevitas.quant.experimental.mx_quant_ocp import ShiftedMXUInt8Weight
from brevitas.quant.experimental.mx_quant_ocp import ShiftedMXUInt8WeightMSE
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE
from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint
@@ -96,12 +110,16 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
'per_tensor': {
'sym': Int8WeightPerTensorFixedPoint},
'per_channel': {
'sym': Int8WeightPerChannelFixedPoint},},
'sym': Int8WeightPerChannelFixedPoint},
'per_group': {
'sym': MXInt8Weight, 'asym': ShiftedMXUInt8Weight}},
'mse': {
'per_tensor': {
'sym': Int8WeightPerTensorFixedPointMSE},
'per_channel': {
'sym': Int8WeightPerChannelFixedPointMSE}},}},
'sym': Int8WeightPerChannelFixedPointMSE},
'per_group': {
'sym': MXInt8WeightMSE, 'asym': ShiftedMXUInt8WeightMSE}},}},
'float': {
'float_scale': {
'stats': {
Expand All @@ -113,7 +131,26 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
'per_tensor': {
'sym': Fp8e4m3WeightPerTensorFloatMSE},
'per_channel': {
'sym': Fp8e4m3WeightPerChannelFloatMSE}}}}}
'sym': Fp8e4m3WeightPerChannelFloatMSE}}}},
'float_ocp': {
'float_scale': {
'stats': {
'per_tensor': {
'sym': Fp8e4m3OCPWeightPerTensorFloat},
'per_channel': {
'sym': Fp8e4m3OCPWeightPerChannelFloat}},
'mse': {
'per_tensor': {
'sym': Fp8e4m3OCPWeightPerTensorFloatMSE},
'per_channel': {
'sym': Fp8e4m3OCPWeightPerChannelFloatMSE}}},
'po2_scale': {
'stats': {
'per_group': {
'sym': MXFloat8e4m3Weight}},
'mse': {
'per_group': {
'sym': MXFloat8e4m3WeightMSE}}}}}

INPUT_QUANT_MAP = {
'int': {
@@ -139,7 +176,10 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
'stats': {
'per_tensor': {
'sym': CNNInt8DynamicActPerTensorFloat,
'asym': CNNShiftedUint8DynamicActPerTensorFloat}}}}},
'asym': CNNShiftedUint8DynamicActPerTensorFloat}}},
'po2_scale': {
'stats': {
'per_group': MXInt8Act}}}},
'float': {
'static': {
'float_scale': {
@@ -148,7 +188,21 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
'sym': Fp8e4m3ActPerTensorFloat}},
'mse': {
'per_tensor': {
'sym': Fp8e4m3ActPerTensorFloatMSE}}}}}}
'sym': Fp8e4m3ActPerTensorFloatMSE}}}}},
'float_ocp': {
'static': {
'float_scale': {
'stats': {
'per_tensor': {
'sym': Fp8e4m3OCPActPerTensorFloat}},
'mse': {
'per_tensor': {
'sym': Fp8e4m3OCPActPerTensorFloatMSE}}}},
'dynamic': {
'po2_scale': {
'stats': {
'per_group': {
'sym': MXFloat8e4m3Act}}}}}}
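These nested maps are keyed, for weights, by format → scale factor type → parameter method → granularity → quantization type, with an extra scale computation level (static/dynamic) for activations; the lookups later in this file read e.g. `INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][act_scale_type][act_param_method][act_quant_granularity]`. A toy reproduction of the weight-side lookup, with strings standing in for the Brevitas quantizer classes and assuming the fixed-point int entries sit under a `'po2_scale'` branch:

```python
# Toy stand-in for WEIGHT_QUANT_MAP above; strings replace quantizer classes.
WEIGHT_QUANT_MAP = {
    'int': {
        'po2_scale': {
            'stats': {
                'per_group': {
                    'sym': 'MXInt8Weight', 'asym': 'ShiftedMXUInt8Weight'}}}},
    'float_ocp': {
        'po2_scale': {
            'stats': {
                'per_group': {'sym': 'MXFloat8e4m3Weight'}}}},
}

def resolve_weight_quant(fmt, scale_type, param_method, granularity, quant_type):
    # Same chained-indexing pattern quantize_model applies to the real maps.
    return WEIGHT_QUANT_MAP[fmt][scale_type][param_method][granularity][quant_type]

print(resolve_weight_quant('int', 'po2_scale', 'stats', 'per_group', 'sym'))
# -> MXInt8Weight
```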


def quantize_model(
@@ -252,14 +306,14 @@ def layerwise_bit_width_fn_weight(module):
weight_bit_width_dict['weight_bit_width'] = weight_bit_width
act_bit_width_dict['act_bit_width'] = act_bit_width

if quant_format == 'float' and backend == 'layerwise':
if 'float' in quant_format and backend == 'layerwise':
weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight
act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act
weight_bit_width_dict['weight_mantissa_bit_width'] = layerwise_bit_width_fn_weight_mantissa
weight_bit_width_dict['weight_exponent_bit_width'] = layerwise_bit_width_fn_weight_exponent
act_bit_width_dict['act_mantissa_bit_width'] = layerwise_bit_width_fn_act_mantissa
act_bit_width_dict['act_exponent_bit_width'] = layerwise_bit_width_fn_act_exponent
elif quant_format == 'float' and backend != 'layerwise':
elif 'float' in quant_format and backend != 'layerwise':
weight_bit_width_dict['weight_bit_width'] = weight_bit_width
act_bit_width_dict['act_bit_width'] = act_bit_width
weight_bit_width_dict['weight_mantissa_bit_width'] = weight_mantissa_bit_width
@@ -334,12 +388,12 @@ def kwargs_prefix(prefix, weight_kwargs):
return {prefix + k: v for k, v in weight_kwargs.items()}

weight_bit_width_dict = {'bit_width': weight_bit_width}
if weight_quant_format == 'float':
if 'float' in weight_quant_format:
weight_bit_width_dict['exponent_bit_width'] = weight_exponent_bit_width
weight_bit_width_dict['mantissa_bit_width'] = weight_mantissa_bit_width

act_bit_width_dict = {'bit_width': act_bit_width}
if act_quant_format == 'float':
if 'float' in act_quant_format:
act_bit_width_dict['exponent_bit_width'] = act_exponent_bit_width
act_bit_width_dict['mantissa_bit_width'] = act_mantissa_bit_width

@@ -355,16 +409,12 @@ def kwargs_prefix(prefix, weight_kwargs):
# Some activations in MHA should always be symmetric
sym_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][
act_scale_type][act_param_method][act_quant_granularity]['sym']
# Linear layers with 2d input should always be per tensor
per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_computation_type][
act_scale_type][act_param_method]['per_tensor'][act_quant_type]

act_quant = act_quant.let(**act_bit_width_dict)
sym_act_quant = sym_act_quant.let(**act_bit_width_dict)
per_tensor_act_quant = per_tensor_act_quant.let(**act_bit_width_dict)
else:
act_quant = None
sym_act_quant = None
per_tensor_act_quant = None

# Modify the weight quantizer based on the arguments passed in
weight_quant = weight_quant.let(
@@ -383,13 +433,6 @@ def kwargs_prefix(prefix, weight_kwargs):
sym_act_quant = sym_act_quant.let(
**{
'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device})
if per_tensor_act_quant is not None:
per_tensor_act_quant = per_tensor_act_quant.let(
**{
'high_percentile_q': act_quant_percentile, 'dtype': dtype, 'device': device})
if act_quant_type == 'asym' and act_quant_percentile is not None:
per_tensor_act_quant = per_tensor_act_quant.let(
**{'low_percentile_q': 100 - act_quant_percentile})

weight_quant_dict = {'weight_quant': weight_quant}

@@ -431,9 +474,9 @@ def kwargs_prefix(prefix, weight_kwargs):
unsigned_quant_act_kwargs['signed'] = False

# Layerwise is basic quant kwargs + input_quant
layerwise_quant_wbiol_kwargs = {**quant_wbiol_kwargs, 'input_quant': per_tensor_act_quant}
layerwise_quant_wbiol_kwargs = {**quant_wbiol_kwargs, 'input_quant': act_quant}

layerwise_quant_mha_kwargs = {**quant_mha_kwargs, 'in_proj_input_quant': per_tensor_act_quant}
layerwise_quant_mha_kwargs = {**quant_mha_kwargs, 'in_proj_input_quant': act_quant}

quant_layer_map = {
torch.nn.Linear: (qnn.QuantLinear, quant_wbiol_kwargs),
@@ -526,7 +569,7 @@ def apply_gptq(calib_loader, model, act_order=False):
dtype = next(model.parameters()).dtype
device = next(model.parameters()).device
with torch.no_grad():
with gptq_mode(model, act_order=act_order, use_quant_activations=False) as gptq:
with gptq_mode(model, act_order=act_order, use_quant_activations=True) as gptq:
gptq_model = gptq.model
for i in tqdm(range(gptq.num_layers)):
for i, (images, target) in enumerate(calib_loader):
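The final hunk above flips `use_quant_activations` to `True`, so GPTQ now optimizes against quantized rather than floating-point activations. A minimal sketch of the surrounding calibration loop; the import path and the `update()` call are assumptions based on the Brevitas GPTQ API, not part of this diff:

```python
import torch
from brevitas.graph.gptq import gptq_mode  # import path is an assumption

@torch.no_grad()
def run_gptq(calib_loader, model, act_order=False):
    device = next(model.parameters()).device
    # Keep quantized activations enabled so the reconstruction error GPTQ
    # minimizes matches what the model sees at inference time.
    with gptq_mode(model, act_order=act_order, use_quant_activations=True) as gptq:
        for _ in range(gptq.num_layers):
            for images, _ in calib_loader:  # one calibration pass per layer
                gptq.model(images.to(device))
            gptq.update()
```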
src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -120,7 +120,12 @@ def parse_type(v, default_type):
parser.add_argument(
'--weight-quant-granularity',
default='per_tensor',
choices=['per_tensor', 'per_channel'],
choices=['per_tensor', 'per_channel', 'per_group'],
help='Weight quantization type (default: per_tensor)')
parser.add_argument(
'--act-quant-granularity',
default='per_tensor',
choices=['per_tensor', 'per_group'],
help='Activation quantization type (default: per_tensor)')
parser.add_argument(
'--weight-quant-calibration-type',
@@ -168,11 +173,7 @@ def parse_type(v, default_type):
'--export-torch-qcdq',
action='store_true',
help='If true, export the model in torch qcdq format')
add_bool_arg(
parser,
'scaling-per-output-channel',
default=True,
help='Weight scaling per output channel (default: enabled)')

add_bool_arg(
parser, 'bias-corr', default=True, help='Bias correction after calibration (default: enabled)')
add_bool_arg(
@@ -189,7 +190,7 @@ def parse_type(v, default_type):
parser.add_argument(
'--quant-format',
default='int',
choices=['int', 'float'],
choices=['int', 'float', 'float_ocp'],
help='Quantization format to use for weights and activations (default: int)')
parser.add_argument(
'--layerwise-first-last-mantissa-bit-width',
@@ -409,6 +410,7 @@ def main():
weight_narrow_range=args.weight_narrow_range,
weight_param_method=args.weight_quant_calibration_type,
weight_quant_granularity=args.weight_quant_granularity,
act_quant_granularity=args.act_quant_granularity,
weight_quant_type=args.weight_quant_type,
layerwise_first_last_bit_width=args.layerwise_first_last_bit_width,
act_bit_width=args.act_bit_width,
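For reference, the `add_bool_arg` helper used above (e.g. for `--bias-corr`/`--no-bias-corr`) typically registers a mutually exclusive flag pair; a sketch under that assumption, not necessarily the repository's exact implementation:

```python
import argparse

def add_bool_arg(parser, name, default, help):
    # Registers paired --name / --no-name flags writing to a single dest.
    dest = name.replace('-', '_')
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument('--' + name, dest=dest, action='store_true',
                       help='Enable ' + help)
    group.add_argument('--no-' + name, dest=dest, action='store_false',
                       help='Disable ' + help)
    parser.set_defaults(**{dest: default})
```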
