
Commit

Update
Giuseppe5 committed Sep 28, 2023
1 parent a46568c commit 58d7a2a
Showing 3 changed files with 94 additions and 81 deletions.
@@ -55,33 +55,10 @@ def unique(sequence):
'mobilenet_v2': 71.898,
'vit_b_32': 75.912,}

OPTIONS = {
'model_name': TORCHVISION_TOP1_MAP.keys(),
'target_backend': ['fx', 'layerwise', 'flexml'], # Target backend
'scale_factor_type': ['float', 'po2'], # Scale factor type
'weight_bit_width': [8, 4], # Weight Bit Width
'act_bit_width': [8, 4], # Act bit width
'bias_bit_width': [None, 32, 16], # Bias Bit-Width for Po2 scale
'weight_quant_granularity': ['per_tensor', 'per_channel'], # Scaling Per Output Channel
'act_quant_type': ['asym', 'sym'], # Act Quant Type
'weight_param_method': ['stats', 'mse'], # Weight Quant Type
'act_param_method': ['stats', 'mse'], # Act Param Method
'bias_corr': [True], # Bias Correction
'graph_eq_iterations': [0, 20], # Graph Equalization
'graph_eq_merge_bias': [False, True], # Merge bias for Graph Equalization
'act_equalization': ['fx', 'layerwise', None], # Perform Activation Equalization (Smoothquant)
'learned_round': [False, True], # Enable/Disable Learned Round
'gptq': [False, True], # Enable/Disable GPTQ
'gptq_act_order': [False, True],  # Use act_order heuristics for GPTQ
'gpfq': [False, True], # Enable/Disable GPFQ
'gpfq_p': [0.25, 0.75], # GPFQ P
'act_quant_percentile': [99.9, 99.99, 99.999], # Activation Quantization Percentile
}

OPTIONS_DEFAULT = {
'model_name': list(TORCHVISION_TOP1_MAP.keys()),
'quant_format': ['int'], # Quantization type (INT vs Float)
'target_backend': ['fx'], # Target backend
'target_backend': ['layerwise'], # Target backend
'scale_factor_type': ['float_scale'], # Scale factor type
'weight_mantissa_bit_width': [4],
'weight_exponent_bit_width': [3],
@@ -164,6 +141,7 @@ def ptq_torchvision_models(args):
return

config_namespace = SimpleNamespace(**configs[args.idx])
print(config_namespace)

fp_accuracy = TORCHVISION_TOP1_MAP[config_namespace.model_name]
# Get model-specific configurations about input shapes and normalization
@@ -220,6 +198,10 @@ def ptq_torchvision_models(args):
quant_format=config_namespace.quant_format,
backend=config_namespace.target_backend,
act_bit_width=config_namespace.act_bit_width,
weight_mantissa_bit_width=config_namespace.weight_mantissa_bit_width,
weight_exponent_bit_width=config_namespace.weight_exponent_bit_width,
act_mantissa_bit_width=config_namespace.act_mantissa_bit_width,
act_exponent_bit_width=config_namespace.act_exponent_bit_width,
weight_bit_width=config_namespace.weight_bit_width,
weight_param_method=config_namespace.weight_param_method,
act_param_method=config_namespace.act_param_method,
@@ -328,6 +310,12 @@ def validate_config(config_namespace):
config_namespace.act_quant_type = 'sym'
config_namespace.weight_quant_type = 'sym'

if config_namespace.quant_format == 'float':
if config_namespace.weight_exponent_bit_width + config_namespace.weight_mantissa_bit_width != config_namespace.weight_bit_width - 1:
is_valid = False
if config_namespace.act_exponent_bit_width + config_namespace.act_mantissa_bit_width != config_namespace.act_bit_width - 1:
is_valid = False

config_namespace.is_valid = is_valid
return config_namespace
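
The float branch added to validate_config above encodes the usual minifloat layout: one bit is reserved for the sign, so a configuration is only kept when the exponent and mantissa widths sum to the total bit width minus one. A minimal, self-contained sketch of that rule (the helper name and the example widths are illustrative, not part of this commit):

# Sign-bit accounting behind the new validate_config check: a minifloat value
# is sign (1 bit) + exponent + mantissa, so the component widths must sum to
# the total bit width minus one.
def is_valid_float_format(bit_width, exponent_bit_width, mantissa_bit_width):
    return exponent_bit_width + mantissa_bit_width == bit_width - 1

assert is_valid_float_format(8, 4, 3)        # FP8 E4M3, as in Fp8e4m3Act: 4 + 3 == 8 - 1
assert not is_valid_float_format(8, 5, 3)    # an 8-bit E5M3 config is rejected: 5 + 3 != 7
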

@@ -1,4 +1,6 @@
python ptq_benchmark_torchvision.py $1 --calibration-dir /scratch/datasets/imagenet_symlink/calibration --validation-dir /scratch/datasets/imagenet_symlink/val \
--quant_format float \
--target_backend layerwise \
--graph_eq_iterations 50 \
--act_param_method stats mse \
--act_quant_percentile 99.9 99.99
137 changes: 80 additions & 57 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -18,6 +18,7 @@
from brevitas.graph.quantize import layerwise_quantize
from brevitas.graph.quantize import quantize
from brevitas.graph.target.flexml import quantize_flexml
from brevitas.inject import value
import brevitas.nn as qnn
from brevitas.quant.experimental.float import Fp8e4m3Act
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
@@ -142,51 +143,86 @@ def quantize_model(
weight_quant_format = quant_format
act_quant_format = quant_format

weight_quant_granularity = weight_quant_granularity

def bit_width_fn(module, first_last_bit_width, other_bit_width):
def layerwise_bit_width_fn(module, base_bit_width, first_last_bit_width):
if isinstance(module, torch.nn.Conv2d) and module.in_channels == 3:
return first_last_bit_width
elif isinstance(module, torch.nn.Linear) and module.out_features == 1000:
return first_last_bit_width
else:
return other_bit_width

weight_bit_width_or_lambda = weight_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
module, layerwise_first_last_bit_width, weight_bit_width)
act_bit_width_or_lambda = act_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
module, layerwise_first_last_bit_width, act_bit_width)
return base_bit_width

@value
def layerwise_bit_width_fn_act_exponent(module):
return layerwise_bit_width_fn(
module, act_exponent_bit_width, layerwise_first_last_exponent_bit_width)

@value
def layerwise_bit_width_fn_act_mantissa(module):
return layerwise_bit_width_fn(
module, act_mantissa_bit_width, layerwise_first_last_mantissa_bit_width)

@value
def layerwise_bit_width_fn_weight_exponent(module):
return layerwise_bit_width_fn(
module, weight_exponent_bit_width, layerwise_first_last_exponent_bit_width)

@value
def layerwise_bit_width_fn_weight_mantissa(module):
return layerwise_bit_width_fn(
module, weight_mantissa_bit_width, layerwise_first_last_mantissa_bit_width)

@value
def layerwise_bit_width_fn_act(module):
return layerwise_bit_width_fn(module, act_bit_width, layerwise_first_last_bit_width)

@value
def layerwise_bit_width_fn_weight(module):
return layerwise_bit_width_fn(module, weight_bit_width, layerwise_first_last_bit_width)
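
The @value-decorated helpers above let the layerwise backend resolve a bit width per module instead of using a single integer: the first convolution (3 input channels) and the final classifier (1000 output features) receive the layerwise_first_last_* widths, every other layer the base width. A toy, self-contained restatement of that rule (the widths 8 and 4 below are assumptions chosen for illustration, not values from this commit):

import torch

# Same first/last-layer rule as layerwise_bit_width_fn, restated standalone
# with hypothetical widths: 8 bits for first/last layers, 4 bits elsewhere.
def toy_bit_width(module, base_bit_width=4, first_last_bit_width=8):
    if isinstance(module, torch.nn.Conv2d) and module.in_channels == 3:
        return first_last_bit_width
    if isinstance(module, torch.nn.Linear) and module.out_features == 1000:
        return first_last_bit_width
    return base_bit_width

assert toy_bit_width(torch.nn.Conv2d(3, 64, kernel_size=3)) == 8   # first conv kept wider
assert toy_bit_width(torch.nn.Linear(512, 512)) == 4               # inner layers use the base width
assert toy_bit_width(torch.nn.Linear(512, 1000)) == 8              # final classifier kept wider
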

# Missing fix for backend != layerwise
# Missing fix for name_shadowing for all variables
weight_bit_width_dict = {}
act_bit_width_dict = {}
if weight_quant_format == 'int' and backend == 'layerwise':
weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight
act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act

weight_mantissa_bit_width_or_lambda = weight_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
module, layerwise_first_last_mantissa_bit_width, weight_mantissa_bit_width)
weight_bit_exponent_width_or_lambda = weight_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
module, layerwise_first_last_exponent_bit_width, weight_exponent_bit_width)
else:
weight_bit_width_dict['weight_bit_width'] = weight_bit_width
act_bit_width_dict['act_bit_width'] = act_bit_width

if weight_quant_format == 'float' and backend == 'layerwise':
weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight
act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act
weight_bit_width_dict['weight_mantissa_bit_width'] = layerwise_bit_width_fn_weight_mantissa
weight_bit_width_dict['weight_exponent_bit_width'] = layerwise_bit_width_fn_weight_exponent
act_bit_width_dict['act_mantissa_bit_width'] = layerwise_bit_width_fn_act_mantissa
act_bit_width_dict['act_exponent_bit_width'] = layerwise_bit_width_fn_act_exponent
else:
weight_bit_width_dict['weight_bit_width'] = weight_bit_width
act_bit_width_dict['act_bit_width'] = act_bit_width
weight_bit_width_dict['weight_mantissa_bit_width'] = weight_mantissa_bit_width
weight_bit_width_dict['weight_exponent_bit_width'] = weight_exponent_bit_width
act_bit_width_dict['act_mantissa_bit_width'] = act_mantissa_bit_width
act_bit_width_dict['act_exponent_bit_width'] = act_exponent_bit_width

act_bit_mantissa_width_or_lambda = act_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
module, layerwise_first_last_mantissa_bit_width, act_mantissa_bit_width)
act_bit_exponent_width_or_lambda = act_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
module, layerwise_first_last_exponent_bit_width, act_exponent_bit_width)

quant_layer_map, quant_layerwise_layer_map, quant_act_map, quant_identity_map = create_quant_maps(dtype=dtype,
bias_bit_width=bias_bit_width,
weight_bit_width=weight_bit_width_or_lambda,
weight_param_method=weight_param_method,
weight_scale_type=weight_scale_type,
weight_quant_type=weight_quant_type,
weight_quant_granularity=weight_quant_granularity,
weight_narrow_range=weight_narrow_range,
weight_quant_format=weight_quant_format,
weight_mantissa_bit_width=weight_mantissa_bit_width_or_lambda,
weight_exponent_bit_width=weight_bit_exponent_width_or_lambda,
act_mantissa_bit_width=act_bit_mantissa_width_or_lambda,
act_exponent_bit_width=act_bit_exponent_width_or_lambda,
act_quant_format=act_quant_format,
act_bit_width=act_bit_width_or_lambda,
act_scale_type=act_scale_type,
act_param_method=act_param_method,
act_quant_type=act_quant_type,
act_quant_granularity=act_quant_granularity,
act_quant_percentile=act_quant_percentile)
act_quant_percentile=act_quant_percentile,
**weight_bit_width_dict,
**act_bit_width_dict)

if backend != 'layerwise':
# Fx and flexml backend requires three mappings for quantization
@@ -212,11 +248,11 @@ def create_quant_maps(
weight_quant_granularity,
weight_narrow_range,
weight_quant_format,
weight_mantissa_bit_width,
weight_exponent_bit_width,
act_mantissa_bit_width,
act_exponent_bit_width,
act_quant_format,
weight_mantissa_bit_width=None,
weight_exponent_bit_width=None,
act_mantissa_bit_width=None,
act_exponent_bit_width=None,
act_bit_width=None,
act_scale_type=None,
act_param_method=None,
@@ -230,25 +266,24 @@
def kwargs_prefix(prefix, weight_kwargs):
return {prefix + k: v for k, v in weight_kwargs.items()}

weight_bit_width_dict = {'bit_width': weight_bit_width}
if weight_quant_format == 'float':
weight_float_format = {
weight_bit_width_dict = {
'exponent_bit_width': weight_exponent_bit_width,
'mantissa_bit_width': weight_mantissa_bit_width}
else:
weight_float_format = {}

act_bit_width_dict = {'bit_width': act_bit_width}
if act_quant_format == 'float':
act_float_format = {
act_bit_width_dict = {
'exponent_bit_width': act_exponent_bit_width,
'mantissa_bit_width': act_mantissa_bit_width}
else:
act_float_format = {}
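
With this refactor the separate *_float_format dicts are gone: each quantizer receives either a plain bit_width (int formats) or an exponent/mantissa pair (float formats) through the single .let() call further down. A hedged sketch of what that binding amounts to for the FP8 activation quantizer imported at the top of this file; the override widths here are illustrative assumptions, not values taken from the commit:

from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat

# Int path: only the total bit width is overridden on the quantizer.
int_act_overrides = {'bit_width': 8}

# Float path: exponent and mantissa widths are overridden instead, mirroring
# act_bit_width_dict when act_quant_format == 'float'.
float_act_overrides = {'exponent_bit_width': 4, 'mantissa_bit_width': 3}
custom_fp8_act_quant = Fp8e4m3ActPerTensorFloat.let(**float_act_overrides)
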

# Retrieve base input, weight, and bias quantizers
bias_quant = BIAS_BIT_WIDTH_MAP[bias_bit_width]
weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_type][weight_param_method][
weight_quant_granularity][weight_quant_type]
weight_quant = weight_quant.let(**weight_float_format)
weight_quant = weight_quant.let(**weight_bit_width_dict)

if act_bit_width is not None:
act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
act_quant_granularity][act_quant_type]
@@ -258,10 +293,9 @@ def kwargs_prefix(prefix, weight_kwargs):
# Linear layers with 2d input should always be per tensor
per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
'per_tensor'][act_quant_type]
act_quant = act_quant.let(**act_float_format)
sym_act_quant = sym_act_quant.let(**act_float_format)
per_tensor_act_quant = per_tensor_act_quant.let(**act_float_format)

act_quant = act_quant.let(**act_bit_width_dict)
sym_act_quant = sym_act_quant.let(**act_bit_width_dict)
per_tensor_act_quant = per_tensor_act_quant.let(**act_bit_width_dict)
else:
act_quant = None
sym_act_quant = None
@@ -290,32 +324,23 @@ def kwargs_prefix(prefix, weight_kwargs):
per_tensor_act_quant = per_tensor_act_quant.let(
**{'low_percentile_q': 100 - act_quant_percentile})

weight_quant_and_bit_width = {
'weight_quant': weight_quant, 'weight_bit_width': weight_bit_width}
weight_quant_dict = {'weight_quant': weight_quant}

quant_wbiol_kwargs = {
**weight_quant_and_bit_width,
'dtype': dtype,
'return_quant_tensor': False,
'bias_quant': bias_quant}
**weight_quant_dict, 'dtype': dtype, 'return_quant_tensor': False, 'bias_quant': bias_quant}

# yapf: disable
quant_mha_kwargs = {
**kwargs_prefix('in_proj_', weight_quant_and_bit_width),
**kwargs_prefix('out_proj_', weight_quant_and_bit_width),
**kwargs_prefix('in_proj_', weight_quant_dict),
**kwargs_prefix('out_proj_', weight_quant_dict),
'in_proj_bias_quant': bias_quant,
'softmax_input_quant': None,
'attn_output_weights_quant': sym_act_quant,
'attn_output_weights_bit_width': act_bit_width,
'attn_output_weights_signed': False,
'q_scaled_quant': sym_act_quant,
'q_scaled_bit_width': act_bit_width,
'k_transposed_quant': sym_act_quant,
'k_transposed_bit_width': act_bit_width,
'v_quant': sym_act_quant,
'v_bit_width': act_bit_width,
'out_proj_input_quant': act_quant,
'out_proj_input_bit_width': act_bit_width,
'out_proj_bias_quant': bias_quant,
'out_proj_output_quant': None,
# activation equalization requires packed_in_proj
@@ -327,12 +352,10 @@ def kwargs_prefix(prefix, weight_kwargs):

# Layerwise is basic quant kwargs + input_quant
layerwise_quant_wbiol_kwargs = {
**quant_wbiol_kwargs, 'input_quant': per_tensor_act_quant, 'input_bit_width': act_bit_width}
**quant_wbiol_kwargs, 'input_quant': per_tensor_act_quant
} #, **kwargs_prefix('input_', act_float_format)}#'input_mantissa_bit_width': act_float_format['mantissa_bit_width'], 'input_exponent_bit_width': act_float_format['exponent_bit_width']}

layerwise_quant_mha_kwargs = {
**quant_mha_kwargs,
'in_proj_input_quant': per_tensor_act_quant,
'in_proj_input_bit_width': act_bit_width}
layerwise_quant_mha_kwargs = {**quant_mha_kwargs, 'in_proj_input_quant': per_tensor_act_quant}

quant_layer_map = {
torch.nn.Linear: (qnn.QuantLinear, quant_wbiol_kwargs),
