diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py
index 8cafef2a4..58af032ef 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py
@@ -50,7 +50,7 @@
     'scale_factor_type': ['float', 'po2'],  # Scale factor type
     'weight_bit_width': [8, 4],  # Weight Bit Width
     'act_bit_width': [8, 4],  # Act bit width
-    'bias_bit_width': ['float', 32, 16],  # Bias Bit-Width for Po2 scale
+    'bias_bit_width': [None, 32, 16],  # Bias Bit-Width for Po2 scale
     'weight_quant_granularity': ['per_tensor', 'per_channel'],  # Scaling Per Output Channel
     'act_quant_type': ['asym', 'sym'],  # Act Quant Type
     'weight_param_method': ['stats', 'mse'],  # Weight Quant Type
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
index 935c8f741..42e976134 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -46,7 +46,7 @@
 QUANTIZE_MAP = {'layerwise': layerwise_quantize, 'fx': quantize, 'flexml': quantize_flexml}
 
-BIAS_BIT_WIDTH_MAP = {32: Int32Bias, 16: Int16Bias, 'float': None}
+BIAS_BIT_WIDTH_MAP = {32: Int32Bias, 16: Int16Bias, None: None}
 
 WEIGHT_QUANT_MAP = {
     'float': {
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
index 7a83c8410..319b429b7 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -92,7 +92,7 @@
     type=int,
     help='Input and weights bit width for first and last layer w/ layerwise backend (default: 8)')
 parser.add_argument(
-    '--bias-bit-width', default=32, choices=[32, 16, 'float'], help='Bias bit width (default: 32)')
+    '--bias-bit-width', default=32, choices=[32, 16, None], help='Bias bit width (default: 32)')
 parser.add_argument(
     '--act-quant-type',
     default='sym',
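Note (not part of the diff): the change replaces the 'float' string sentinel with None to mean "leave the bias unquantized" across the benchmark options, BIAS_BIT_WIDTH_MAP, and the CLI choices. The snippet below is a minimal, hypothetical sketch of that behavior; Int32Bias / Int16Bias here are stand-in classes, not the Brevitas quantizers imported in ptq_common.py, and resolve_bias_quant is an illustrative helper that does not exist in the codebase.

# Stand-in classes so the sketch runs without Brevitas installed.
class Int32Bias: ...
class Int16Bias: ...

BIAS_BIT_WIDTH_MAP = {32: Int32Bias, 16: Int16Bias, None: None}

def resolve_bias_quant(bias_bit_width):
    """Return the bias quantizer class, or None to keep the bias in float."""
    return BIAS_BIT_WIDTH_MAP[bias_bit_width]

assert resolve_bias_quant(32) is Int32Bias
assert resolve_bias_quant(None) is None  # bias stays unquantized (float)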