Skip to content

Commit

Permalink
Fix (examples/imagenet): Assert all bit widths are positive integers (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
OscarSavolainenDR authored Apr 10, 2024
1 parent 3014537 commit a8159f9
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from copy import deepcopy
import math

import torch
import torch.backends.cudnn as cudnn
Expand Down Expand Up @@ -153,6 +154,20 @@ def quantize_model(
weight_scale_type = scale_factor_type
act_scale_type = scale_factor_type

# We check all of the provided values are positive integers
check_positive_int(
weight_bit_width,
act_bit_width,
bias_bit_width,
layerwise_first_last_bit_width,
layerwise_first_last_mantissa_bit_width,
layerwise_first_last_exponent_bit_width,
weight_mantissa_bit_width,
weight_exponent_bit_width,
act_mantissa_bit_width,
act_exponent_bit_width,
)

weight_quant_format = quant_format
act_quant_format = quant_format

Expand Down Expand Up @@ -535,3 +550,12 @@ def apply_learned_round_learning(
pbar.set_description(
"loss = {:.4f}, rec_loss = {:.4f}, round_loss = {:.4f}, b = {:.4f}".format(
loss, rec_loss, round_loss, b))


def check_positive_int(*args):
"""
We check that every inputted value is positive, and an integer.
"""
for arg in args:
assert arg > 0.0
assert math.isclose(arg % 1, 0.0)

0 comments on commit a8159f9

Please sign in to comment.