From 14e20802b78fe322906d875c5f5eb5ab72fa8dfc Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 28 May 2024 09:52:50 +0200 Subject: [PATCH] Fix (core/quant/float): use eps to avoid log(0) (#957) --- src/brevitas/core/quant/float.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 371c5551c..65af11e61 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -59,9 +59,14 @@ def __init__( self.scaling_impl = scaling_impl self.float_clamp_impl = float_clamp_impl + # To avoid log(0), we add small a small value based on the used dtype + if dtype is None: + dtype = torch.get_default_dtype() + self.eps = torch.finfo(dtype).tiny + @brevitas.jit.script_method def internal_scale(self, x): - internal_scale = floor_ste(torch.log2(torch.abs(x))) - self.mantissa_bit_width() + internal_scale = floor_ste(torch.log2(torch.abs(x) + self.eps)) - self.mantissa_bit_width() internal_scale = torch.clamp_min(internal_scale, self.fp_internal_scale_min()) internal_scale = torch.exp2(internal_scale) return internal_scale