diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 62752f1d0..9f00f0a62 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from copy import deepcopy +import math import torch import torch.backends.cudnn as cudnn @@ -153,6 +154,20 @@ def quantize_model( weight_scale_type = scale_factor_type act_scale_type = scale_factor_type + # We check all of the provided values are positive integers + check_positive_int( + weight_bit_width, + act_bit_width, + bias_bit_width, + layerwise_first_last_bit_width, + layerwise_first_last_mantissa_bit_width, + layerwise_first_last_exponent_bit_width, + weight_mantissa_bit_width, + weight_exponent_bit_width, + act_mantissa_bit_width, + act_exponent_bit_width, + ) + weight_quant_format = quant_format act_quant_format = quant_format @@ -535,3 +550,12 @@ def apply_learned_round_learning( pbar.set_description( "loss = {:.4f}, rec_loss = {:.4f}, round_loss = {:.4f}, b = {:.4f}".format( loss, rec_loss, round_loss, b)) + + +def check_positive_int(*args): + """ + We check that every inputted value is positive, and an integer. + """ + for arg in args: + assert arg > 0.0 + assert math.isclose(arg % 1, 0.0)