Skip to content

Commit

Permalink
Ignore jit flag
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored and nickfraser committed Sep 4, 2024
1 parent 568420f commit 498899f
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/brevitas/function/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def min_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor:
return value


@brevitas.jit.script
@brevitas.jit.ignore
def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor):
max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias
max_mantissa = torch.sum((
Expand Down
2 changes: 2 additions & 0 deletions src/brevitas/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ def _disabled(fn):

script_method = torch.jit.script_method
script = torch.jit.script
ignore = torch.jit.ignore
ScriptModule = torch.jit.ScriptModule
Attribute = torch.jit.Attribute

else:

script_method = _disabled
script = _disabled
ignore = _disabled
ScriptModule = torch.nn.Module
Attribute = lambda val, type: val
1 change: 1 addition & 0 deletions src/brevitas/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def float_internal_scale(
return internal_scale


@brevitas.jit.ignore
def padding(x, group_size, group_dim):
# Given a tensor X, compute the padding along group_dim so that groupwise shaping is possible
padding = [0, 0] * len(x.shape)
Expand Down

0 comments on commit 498899f

Please sign in to comment.