Fix (quant/float): restore fix to avoid log(0) (#968)
Giuseppe5 authored Jun 11, 2024 · 1 parent 02f5b6b · commit 6bdb1f8
Showing 4 changed files with 17 additions and 8 deletions.
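
Context for the fix (an explanatory sketch, not part of the commit): torch.log2 returns -inf for a zero input, so any exactly-zero value reaching float_internal_scale poisoned the internal scale. Adding torch.finfo(dtype).tiny, the smallest positive normal number of the dtype, keeps the logarithm finite while leaving nonzero values effectively unchanged:

    import torch

    x = torch.tensor([0.0, 0.5, 3.0])

    # Before the fix: a zero input makes log2 return -inf, which then
    # propagates through floor and exp2 into the quantized output.
    print(torch.log2(torch.abs(x)))        # tensor([  -inf, -1.0000, 1.5850])

    # After the fix: eps is the smallest positive normal of the dtype,
    # so the zero entry maps to a large negative but finite exponent
    # that the downstream clamp_min handles as intended.
    eps = torch.finfo(x.dtype).tiny        # 2**-126 for float32
    print(torch.log2(torch.abs(x) + eps))  # tensor([-126.0000, -1.0000, 1.5850])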
4 changes: 2 additions & 2 deletions src/brevitas/core/quant/float.py

@@ -14,7 +14,7 @@


 class FloatQuant(brevitas.jit.ScriptModule):
-    __constants__ = ['signed']
+    __constants__ = ['signed', 'eps']

     def __init__(
             self,
@@ -71,7 +71,7 @@ def quantize(self, x: torch.Tensor):
         scale = scaling_impl_value / float_scaling_impl_value
         scaled_x = x / scale
         internal_scale = float_internal_scale(
-            scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min())
+            scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps)
         val_fp_quant = internal_scale * self.float_to_int_impl(scaled_x / internal_scale)
         return val_fp_quant, scale

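Aside on the __constants__ change (a minimal sketch, not Brevitas code): TorchScript treats attributes listed in __constants__ as compile-time constants, which is what allows the scripted quantize method to read self.eps. For example:

    import torch

    class Demo(torch.nn.Module):
        __constants__ = ['eps']  # tells TorchScript that eps is a constant

        def __init__(self, eps: float):
            super().__init__()
            self.eps = eps

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.log2(torch.abs(x) + self.eps)

    scripted = torch.jit.script(Demo(torch.finfo(torch.float32).tiny))
    print(scripted(torch.zeros(3)))  # finite values (~ -126), no -inf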
10 changes: 8 additions & 2 deletions src/brevitas/quant_tensor/float_quant_tensor.py

@@ -74,6 +74,10 @@ def training(self):
     def saturating(self):
         return self.saturating_t.item()

+    @property
+    def eps(self):
+        return torch.finfo(self.scale.dtype).tiny
+
     def __torch_function__(self, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
@@ -99,7 +103,8 @@ def _pre_round_float_value(self):
         scale = self.scale.type(torch.float32)
         minifloat_value = value / scale
         fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-        int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
+        int_scale = float_internal_scale(
+            self.value, self.mantissa_bit_width, fp_internal_scale, self.eps)
         minifloat_value = minifloat_value / int_scale
         return minifloat_value
@@ -135,7 +140,8 @@ def minifloat(self, float_datatype=True):

         if self.is_valid:
             fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-            int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
+            int_scale = float_internal_scale(
+                self.value, self.mantissa_bit_width, fp_internal_scale, self.eps)
             float_value = torch.round(self._pre_round_float_value) * int_scale
             return float_value.type(self.scale.dtype)
         else:
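For reference, torch.finfo(dtype).tiny, used by the new eps property, is the smallest positive normal number of the dtype, so the guard automatically adapts to the precision of self.scale:

    import torch

    print(torch.finfo(torch.float32).tiny)  # 1.1754943508222875e-38 == 2**-126
    print(torch.finfo(torch.float16).tiny)  # 6.103515625e-05 == 2**-14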
8 changes: 5 additions & 3 deletions src/brevitas/utils/torch_utils.py

@@ -93,10 +93,12 @@ def compute_channel_view_shape(tensor: torch.Tensor, channel_dim: int):

 @brevitas.jit.script
 def float_internal_scale(
-        x: torch.Tensor, mantissa_bit_width: torch.Tensor,
-        fp_internal_scale_min: torch.Tensor) -> torch.Tensor:
+        x: torch.Tensor,
+        mantissa_bit_width: torch.Tensor,
+        fp_internal_scale_min: torch.Tensor,
+        eps: float) -> torch.Tensor:

-    internal_scale = floor_ste(torch.log2(torch.abs(x))) - mantissa_bit_width
+    internal_scale = floor_ste(torch.log2(torch.abs(x) + eps)) - mantissa_bit_width
     internal_scale = torch.clamp_min(internal_scale, fp_internal_scale_min)
     internal_scale = torch.exp2(internal_scale)
     return internal_scale
3 changes: 2 additions & 1 deletion tests/brevitas/core/test_float_quant.py

@@ -193,8 +193,9 @@ def test_inner_scale(inp, minifloat_format, scale):
     max_value = max_val if max_available_float is None else torch.min(
         max_value, max_available_float)
     # call internal scale
+    eps = torch.finfo(inp.dtype).tiny
     internal_scale = float_internal_scale(
-        scaled_inp, float_quant.mantissa_bit_width(), float_quant.fp_internal_scale_min())
+        scaled_inp, float_quant.mantissa_bit_width(), float_quant.fp_internal_scale_min(), eps)
     val_fp_quant = internal_scale * float_quant.float_to_int_impl(scaled_inp / internal_scale)
     if signed:
         val_fp_quant = torch.clip(val_fp_quant, -1. * max_val, max_val)
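To exercise just the updated test after this change, a standard pytest invocation such as "pytest tests/brevitas/core/test_float_quant.py -k test_inner_scale" (run from a development checkout with Brevitas installed) covers the new eps argument.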
