From 8a9196b4c48b64dd45d613b41215b1aa6f237a2e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 23 May 2024 12:15:06 +0100 Subject: [PATCH] Fix (core/quant/float): use eps to avoid log(0) --- src/brevitas/core/quant/float.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 371c5551c..12862f3fd 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -59,9 +59,12 @@ def __init__( self.scaling_impl = scaling_impl self.float_clamp_impl = float_clamp_impl + # To avoid log(0), we add small a small value based on the used dtype + self.eps = torch.finfo(dtype).tiny + @brevitas.jit.script_method def internal_scale(self, x): - internal_scale = floor_ste(torch.log2(torch.abs(x))) - self.mantissa_bit_width() + internal_scale = floor_ste(torch.log2(torch.abs(x) + self.eps)) - self.mantissa_bit_width() internal_scale = torch.clamp_min(internal_scale, self.fp_internal_scale_min()) internal_scale = torch.exp2(internal_scale) return internal_scale