diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py index 6751ab69c..74da08e19 100644 --- a/src/brevitas/function/ops.py +++ b/src/brevitas/function/ops.py @@ -189,7 +189,7 @@ def min_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor: return value -@brevitas.jit.script +@brevitas.jit.ignore def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor): max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias max_mantissa = torch.sum(( diff --git a/src/brevitas/jit.py b/src/brevitas/jit.py index 0719e1017..6acf43728 100644 --- a/src/brevitas/jit.py +++ b/src/brevitas/jit.py @@ -14,6 +14,7 @@ def _disabled(fn): script_method = torch.jit.script_method script = torch.jit.script + ignore = torch.jit.ignore ScriptModule = torch.jit.ScriptModule Attribute = torch.jit.Attribute @@ -21,5 +22,6 @@ def _disabled(fn): script_method = _disabled script = _disabled + ignore = _disabled ScriptModule = torch.nn.Module Attribute = lambda val, type: val diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index a2aed48d9..225cacaac 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -104,6 +104,7 @@ def float_internal_scale( return internal_scale +@brevitas.jit.ignore def padding(x, group_size, group_dim): # Given a tensor X, compute the padding aloing group_dim so that groupwise shaping is possible padding = [0, 0] * len(x.shape)