From df1a137aa089434e2bccb1cb87c9cb74c610d7a8 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Wed, 19 Jun 2024 11:23:38 +0200
Subject: [PATCH] Fix (core/float): add default for float_scaling_impl (#972)

---
 src/brevitas/core/quant/float.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py
index 71f518bb5..195d42a96 100644
--- a/src/brevitas/core/quant/float.py
+++ b/src/brevitas/core/quant/float.py
@@ -65,10 +65,13 @@ def __init__(
 
     @brevitas.jit.script_method
     def quantize(self, x: torch.Tensor):
-        scaling_impl_value = self.scaling_impl(x)
-        float_scaling_impl_value = self.float_scaling_impl(
-            self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
-        scale = scaling_impl_value / float_scaling_impl_value
+        scale = self.scaling_impl(x)
+
+        if self.float_scaling_impl is not None:
+            float_scaling_impl_value = self.float_scaling_impl(
+                self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
+            scale = scale / float_scaling_impl_value
+
         scaled_x = x / scale
         internal_scale = float_internal_scale(
             scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps)
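For context: `scaling_impl` produces the tensor-level scale (e.g. derived from the absolute max of `x`), while `float_scaling_impl` returns roughly the largest value representable by the target minifloat format, so dividing by it maps the tensor's range onto the format's range. After this patch, passing `float_scaling_impl=None` skips that rescaling and uses the tensor-level scale as-is. The sketch below mimics the patched control flow outside Brevitas; `max_float_value`, `compute_scale`, and `abs_max` are hypothetical names introduced here for illustration, not Brevitas APIs.

import torch


def max_float_value(exponent_bit_width: int, mantissa_bit_width: int,
                    exponent_bias: int) -> torch.Tensor:
    # Hypothetical stand-in for a float_scaling_impl: the largest magnitude
    # representable by a minifloat format with the given field widths and
    # bias, assuming all exponent codes are usable (no NaN/inf reservation).
    max_exponent = (2 ** exponent_bit_width - 1) - exponent_bias
    max_mantissa = 2.0 - 2.0 ** (-mantissa_bit_width)  # mantissa 1.111...1 in binary
    return torch.tensor(max_mantissa * 2.0 ** max_exponent)


def compute_scale(x, scaling_impl, float_scaling_impl=None):
    # Mirrors the patched control flow: the tensor-level scale is always
    # computed; the format-level rescaling is only applied when a
    # float_scaling_impl is provided.
    scale = scaling_impl(x)
    if float_scaling_impl is not None:
        scale = scale / float_scaling_impl(4, 3, 7)  # e.g. an FP8 E4M3-like format
    return scale


x = torch.randn(16)
abs_max = lambda t: t.abs().max()
print(compute_scale(x, abs_max))                   # tensor-level scale only
print(compute_scale(x, abs_max, max_float_value))  # rescaled by the format's max value

With a `float_scaling_impl` supplied, `x / scale` lands inside the format's representable range (up to 480.0 for the unreserved E4M3-like example above); with `None`, the caller is responsible for that normalization.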