diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py
index b6ddffa56..8c829ae8b 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py
@@ -55,33 +55,10 @@ def unique(sequence):
     'mobilenet_v2': 71.898,
     'vit_b_32': 75.912,}

-OPTIONS = {
-    'model_name': TORCHVISION_TOP1_MAP.keys(),
-    'target_backend': ['fx', 'layerwise', 'flexml'],  # Target backend
-    'scale_factor_type': ['float', 'po2'],  # Scale factor type
-    'weight_bit_width': [8, 4],  # Weight Bit Width
-    'act_bit_width': [8, 4],  # Act bit width
-    'bias_bit_width': [None, 32, 16],  # Bias Bit-Width for Po2 scale
-    'weight_quant_granularity': ['per_tensor', 'per_channel'],  # Scaling Per Output Channel
-    'act_quant_type': ['asym', 'sym'],  # Act Quant Type
-    'weight_param_method': ['stats', 'mse'],  # Weight Quant Type
-    'act_param_method': ['stats', 'mse'],  # Act Param Method
-    'bias_corr': [True],  # Bias Correction
-    'graph_eq_iterations': [0, 20],  # Graph Equalization
-    'graph_eq_merge_bias': [False, True],  # Merge bias for Graph Equalization
-    'act_equalization': ['fx', 'layerwise', None],  # Perform Activation Equalization (Smoothquant)
-    'learned_round': [False, True],  # Enable/Disable Learned Round
-    'gptq': [False, True],  # Enable/Disable GPTQ
-    'gptq_act_order': [False, True],  # Use act_order euristics for GPTQ
-    'gpfq': [False, True],  # Enable/Disable GPFQ
-    'gpfq_p': [0.25, 0.75],  # GPFQ P
-    'act_quant_percentile': [99.9, 99.99, 99.999],  # Activation Quantization Percentile
-}
-
 OPTIONS_DEFAULT = {
     'model_name': list(TORCHVISION_TOP1_MAP.keys()),
     'quant_format': ['int'],  # Quantization type (INT vs Float)
-    'target_backend': ['fx'],  # Target backend
+    'target_backend': ['layerwise'],  # Target backend
     'scale_factor_type': ['float_scale'],  # Scale factor type
     'weight_mantissa_bit_width': [4],
     'weight_exponent_bit_width': [3],
@@ -164,6 +141,7 @@ def ptq_torchvision_models(args):
         return

     config_namespace = SimpleNamespace(**configs[args.idx])
+    print(config_namespace)
     fp_accuracy = TORCHVISION_TOP1_MAP[config_namespace.model_name]

     # Get model-specific configurations about input shapes and normalization
@@ -220,6 +198,10 @@ def ptq_torchvision_models(args):
         quant_format=config_namespace.quant_format,
         backend=config_namespace.target_backend,
         act_bit_width=config_namespace.act_bit_width,
+        weight_mantissa_bit_width=config_namespace.weight_mantissa_bit_width,
+        weight_exponent_bit_width=config_namespace.weight_exponent_bit_width,
+        act_mantissa_bit_width=config_namespace.act_mantissa_bit_width,
+        act_exponent_bit_width=config_namespace.act_exponent_bit_width,
         weight_bit_width=config_namespace.weight_bit_width,
         weight_param_method=config_namespace.weight_param_method,
         act_param_method=config_namespace.act_param_method,
@@ -328,6 +310,12 @@ def validate_config(config_namespace):
        config_namespace.act_quant_type = 'sym'
        config_namespace.weight_quant_type = 'sym'

+    if config_namespace.quant_format == 'float':
+        if config_namespace.weight_exponent_bit_width + config_namespace.weight_mantissa_bit_width != config_namespace.weight_bit_width - 1:
+            is_valid = False
+        if config_namespace.act_exponent_bit_width + config_namespace.act_mantissa_bit_width != config_namespace.act_bit_width - 1:
+            is_valid = False
+
     config_namespace.is_valid = is_valid

     return config_namespace
diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/single_command.sh b/src/brevitas_examples/imagenet_classification/ptq/benchmark/single_command.sh
index f662008a8..96daa49fb 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/single_command.sh
+++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/single_command.sh
@@ -1,4 +1,6 @@
 python ptq_benchmark_torchvision.py $1 --calibration-dir /scratch/datasets/imagenet_symlink/calibration --validation-dir /scratch/datasets/imagenet_symlink/val \
+--quant_format float \
+--target_backend layerwise \
 --graph_eq_iterations 50 \
 --act_param_method stats mse \
 --act_quant_percentile 99.9 99.99
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
index 1fc4508ab..d442f788c 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -18,6 +18,7 @@
 from brevitas.graph.quantize import layerwise_quantize
 from brevitas.graph.quantize import quantize
 from brevitas.graph.target.flexml import quantize_flexml
+from brevitas.inject import value
 import brevitas.nn as qnn
 from brevitas.quant.experimental.float import Fp8e4m3Act
 from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
@@ -142,51 +143,86 @@ def quantize_model(

     weight_quant_format = quant_format
     act_quant_format = quant_format
-    weight_quant_granularity = weight_quant_granularity
-
-    def bit_width_fn(module, first_last_bit_width, other_bit_width):
+    def layerwise_bit_width_fn(module, base_bit_width, first_last_bit_width):
         if isinstance(module, torch.nn.Conv2d) and module.in_channels == 3:
             return first_last_bit_width
         elif isinstance(module, torch.nn.Linear) and module.out_features == 1000:
             return first_last_bit_width
         else:
-            return other_bit_width
-
-    weight_bit_width_or_lambda = weight_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
-        module, layerwise_first_last_bit_width, weight_bit_width)
-    act_bit_width_or_lambda = act_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
-        module, layerwise_first_last_bit_width, act_bit_width)
+            return base_bit_width
+
+    @value
+    def layerwise_bit_width_fn_act_exponent(module):
+        return layerwise_bit_width_fn(
+            module, act_exponent_bit_width, layerwise_first_last_exponent_bit_width)
+
+    @value
+    def layerwise_bit_width_fn_act_mantissa(module):
+        return layerwise_bit_width_fn(
+            module, act_mantissa_bit_width, layerwise_first_last_mantissa_bit_width)
+
+    @value
+    def layerwise_bit_width_fn_weight_exponent(module):
+        return layerwise_bit_width_fn(
+            module, weight_exponent_bit_width, layerwise_first_last_exponent_bit_width)
+
+    @value
+    def layerwise_bit_width_fn_weight_mantissa(module):
+        return layerwise_bit_width_fn(
+            module, weight_mantissa_bit_width, layerwise_first_last_mantissa_bit_width)
+
+    @value
+    def layerwise_bit_width_fn_act(module):
+        return layerwise_bit_width_fn(module, act_bit_width, layerwise_first_last_bit_width)
+
+    @value
+    def layerwise_bit_width_fn_weight(module):
+        return layerwise_bit_width_fn(module, weight_bit_width, layerwise_first_last_bit_width)
+
+    # Missing fix for backend =! layerwise
+    # Missing fix for name_shadowing for all variables
+    weight_bit_width_dict = {}
+    act_bit_width_dict = {}
+    if weight_quant_format == 'int' and backend == 'layerwise':
+        weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight
+        act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act

-    weight_mantissa_bit_width_or_lambda = weight_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
-        module, layerwise_first_last_mantissa_bit_width, weight_mantissa_bit_width)
-    weight_bit_exponent_width_or_lambda = weight_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
-        module, layerwise_first_last_exponent_bit_width, weight_exponent_bit_width)
+    else:
+        weight_bit_width_dict['weight_bit_width'] = weight_bit_width
+        act_bit_width_dict['act_bit_width'] = act_bit_width
+
+    if weight_quant_format == 'float' and backend == 'layerwise':
+        weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight
+        act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act
+        weight_bit_width_dict['weight_mantissa_bit_width'] = layerwise_bit_width_fn_weight_mantissa
+        weight_bit_width_dict['weight_exponent_bit_width'] = layerwise_bit_width_fn_weight_exponent
+        act_bit_width_dict['act_mantissa_bit_width'] = layerwise_bit_width_fn_act_mantissa
+        act_bit_width_dict['act_exponent_bit_width'] = layerwise_bit_width_fn_act_exponent
+    else:
+        weight_bit_width_dict['weight_bit_width'] = weight_bit_width
+        act_bit_width_dict['act_bit_width'] = act_bit_width
+        weight_bit_width_dict['weight_mantissa_bit_width'] = weight_mantissa_bit_width
+        weight_bit_width_dict['weight_exponent_bit_width'] = weight_exponent_bit_width
+        act_bit_width_dict['act_mantissa_bit_width'] = act_mantissa_bit_width
+        act_bit_width_dict['act_exponent_bit_width'] = act_exponent_bit_width

-    act_bit_mantissa_width_or_lambda = act_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
-        module, layerwise_first_last_mantissa_bit_width, act_mantissa_bit_width)
-    act_bit_exponent_width_or_lambda = act_bit_width if backend != 'layerwise' else lambda module: bit_width_fn(
-        module, layerwise_first_last_exponent_bit_width, act_exponent_bit_width)

     quant_layer_map, quant_layerwise_layer_map, quant_act_map, quant_identity_map = create_quant_maps(dtype=dtype,
         bias_bit_width=bias_bit_width,
-        weight_bit_width=weight_bit_width_or_lambda,
         weight_param_method=weight_param_method,
         weight_scale_type=weight_scale_type,
         weight_quant_type=weight_quant_type,
         weight_quant_granularity=weight_quant_granularity,
         weight_narrow_range=weight_narrow_range,
         weight_quant_format=weight_quant_format,
-        weight_mantissa_bit_width=weight_mantissa_bit_width_or_lambda,
-        weight_exponent_bit_width=weight_bit_exponent_width_or_lambda,
-        act_mantissa_bit_width=act_bit_mantissa_width_or_lambda,
-        act_exponent_bit_width=act_bit_exponent_width_or_lambda,
         act_quant_format=act_quant_format,
-        act_bit_width=act_bit_width_or_lambda,
         act_scale_type=act_scale_type,
         act_param_method=act_param_method,
         act_quant_type=act_quant_type,
         act_quant_granularity=act_quant_granularity,
-        act_quant_percentile=act_quant_percentile)
+        act_quant_percentile=act_quant_percentile,
+        **weight_bit_width_dict,
+        **act_bit_width_dict)

     if backend != 'layerwise':
         # Fx and flexml backend requires three mappings for quantization
@@ -212,11 +248,11 @@ def create_quant_maps(
         weight_quant_granularity,
         weight_narrow_range,
         weight_quant_format,
-        weight_mantissa_bit_width,
-        weight_exponent_bit_width,
-        act_mantissa_bit_width,
-        act_exponent_bit_width,
         act_quant_format,
+        weight_mantissa_bit_width=None,
+        weight_exponent_bit_width=None,
+        act_mantissa_bit_width=None,
+        act_exponent_bit_width=None,
         act_bit_width=None,
         act_scale_type=None,
         act_param_method=None,
@@ -230,25 +266,24 @@ def create_quant_maps(
     def kwargs_prefix(prefix, weight_kwargs):
         return {prefix + k: v for k, v in weight_kwargs.items()}

+    weight_bit_width_dict = {'bit_width': weight_bit_width}
     if weight_quant_format == 'float':
-        weight_float_format = {
+        weight_bit_width_dict = {
             'exponent_bit_width': weight_exponent_bit_width,
             'mantissa_bit_width': weight_mantissa_bit_width}
-    else:
-        weight_float_format = {}

+    act_bit_width_dict = {'bit_width': act_bit_width}
     if act_quant_format == 'float':
-        act_float_format = {
+        act_bit_width_dict = {
             'exponent_bit_width': act_exponent_bit_width,
             'mantissa_bit_width': act_mantissa_bit_width}
-    else:
-        act_float_format = {}

     # Retrieve base input, weight, and bias quantizers
     bias_quant = BIAS_BIT_WIDTH_MAP[bias_bit_width]
     weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_type][weight_param_method][
         weight_quant_granularity][weight_quant_type]
-    weight_quant = weight_quant.let(**weight_float_format)
+    weight_quant = weight_quant.let(**weight_bit_width_dict)
+
     if act_bit_width is not None:
         act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
             act_quant_granularity][act_quant_type]
@@ -258,10 +293,9 @@ def kwargs_prefix(prefix, weight_kwargs):
         # Linear layers with 2d input should always be per tensor
         per_tensor_act_quant = INPUT_QUANT_MAP[act_quant_format][act_scale_type][act_param_method][
             'per_tensor'][act_quant_type]
-        act_quant = act_quant.let(**act_float_format)
-        sym_act_quant = sym_act_quant.let(**act_float_format)
-        per_tensor_act_quant = per_tensor_act_quant.let(**act_float_format)
-
+        act_quant = act_quant.let(**act_bit_width_dict)
+        sym_act_quant = sym_act_quant.let(**act_bit_width_dict)
+        per_tensor_act_quant = per_tensor_act_quant.let(**act_bit_width_dict)
     else:
         act_quant = None
         sym_act_quant = None
@@ -290,32 +324,23 @@ def kwargs_prefix(prefix, weight_kwargs):
         per_tensor_act_quant = per_tensor_act_quant.let(
             **{'low_percentile_q': 100 - act_quant_percentile})

-    weight_quant_and_bit_width = {
-        'weight_quant': weight_quant, 'weight_bit_width': weight_bit_width}
+    weight_quant_dict = {'weight_quant': weight_quant}

     quant_wbiol_kwargs = {
-        **weight_quant_and_bit_width,
-        'dtype': dtype,
-        'return_quant_tensor': False,
-        'bias_quant': bias_quant}
+        **weight_quant_dict, 'dtype': dtype, 'return_quant_tensor': False, 'bias_quant': bias_quant}

     # yapf: disable
     quant_mha_kwargs = {
-        **kwargs_prefix('in_proj_', weight_quant_and_bit_width),
-        **kwargs_prefix('out_proj_', weight_quant_and_bit_width),
+        **kwargs_prefix('in_proj_', weight_quant_dict),
+        **kwargs_prefix('out_proj_', weight_quant_dict),
         'in_proj_bias_quant': bias_quant,
         'softmax_input_quant': None,
         'attn_output_weights_quant': sym_act_quant,
-        'attn_output_weights_bit_width': act_bit_width,
         'attn_output_weights_signed': False,
         'q_scaled_quant': sym_act_quant,
-        'q_scaled_bit_width': act_bit_width,
         'k_transposed_quant': sym_act_quant,
-        'k_transposed_bit_width': act_bit_width,
         'v_quant': sym_act_quant,
-        'v_bit_width': act_bit_width,
         'out_proj_input_quant': act_quant,
-        'out_proj_input_bit_width': act_bit_width,
         'out_proj_bias_quant': bias_quant,
         'out_proj_output_quant': None,
         # activation equalization requires packed_in_proj
@@ -327,12 +352,10 @@ def kwargs_prefix(prefix, weight_kwargs):

     # Layerwise is basic quant kwargs + input_quant
     layerwise_quant_wbiol_kwargs = {
-        **quant_wbiol_kwargs, 'input_quant': per_tensor_act_quant, 'input_bit_width': act_bit_width}
+        **quant_wbiol_kwargs, 'input_quant': per_tensor_act_quant
+    }  #, **kwargs_prefix('input_', act_float_format)}#'input_mantissa_bit_width': act_float_format['mantissa_bit_width'], 'input_exponent_bit_width': act_float_format['exponent_bit_width']}

-    layerwise_quant_mha_kwargs = {
-        **quant_mha_kwargs,
-        'in_proj_input_quant': per_tensor_act_quant,
-        'in_proj_input_bit_width': act_bit_width}
+    layerwise_quant_mha_kwargs = {**quant_mha_kwargs, 'in_proj_input_quant': per_tensor_act_quant}

     quant_layer_map = {
         torch.nn.Linear: (qnn.QuantLinear, quant_wbiol_kwargs),