diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 71f518bb5..195d42a96 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -65,10 +65,13 @@ def __init__( @brevitas.jit.script_method def quantize(self, x: torch.Tensor): - scaling_impl_value = self.scaling_impl(x) - float_scaling_impl_value = self.float_scaling_impl( - self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) - scale = scaling_impl_value / float_scaling_impl_value + scale = self.scaling_impl(x) + + if self.float_scaling_impl is not None: + float_scaling_impl_value = self.float_scaling_impl( + self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()) + scale = scale / float_scaling_impl_value + scaled_x = x / scale internal_scale = float_internal_scale( scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps)