diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index 67b57df6a..dddb60b83 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -193,12 +193,16 @@ def max_mantissa_func(val): return torch.sum((2. ** torch.arange(0, -1. * val - 1., -1.))) -MAX_MANTISSA_DICT = {x: max_mantissa_func(x) for x in range(0, 16)} - - +@brevitas.jit.ignore def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor): max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias - max_mantissa = MAX_MANTISSA_DICT[mantissa_bit_width.item()] + max_mantissa = torch.sum(( + 2. ** torch.arange( + 0, + -1. * mantissa_bit_width - 1., + -1., + dtype=mantissa_bit_width.dtype, + device=mantissa_bit_width.device))) max_val = max_mantissa * (2 ** max_exponent) return max_val