Fix JIT compatibility
Giuseppe5 committed Apr 9, 2024
1 parent 839991e commit 02c7985
Showing 7 changed files with 131 additions and 87 deletions.
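In short: the old max_float() mixed tensor arithmetic with Python string manipulation (building binary exponent/mantissa codes while stepping around NaN/inf encodings), which TorchScript cannot compile. The commit splits the two concerns: max_float() in src/brevitas/function/ops.py becomes pure tensor math that assumes every exponent and mantissa code is usable, while the NaN/inf-aware cap moves to a plain Python helper, get_max_available_float(), evaluated once at construction time and stored as a tensor in a StatelessBuffer. A sketch of the shape of the change, using names from the diff below:

    # Before: string-based special-case handling ran inside the scripted path.
    max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias,
                          nan_values, inf_values, saturating)

    # After: the scripted path does tensor math only...
    max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias)
    # ...capped by a constant precomputed in Python when the format reserves
    # codes for NaN/inf (max_available_float is None when nothing is reserved).
    if max_available_float is not None:
        max_value = torch.min(max_value, max_available_float())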
25 changes: 15 additions & 10 deletions src/brevitas/core/function_wrapper/clamp.py
@@ -95,16 +95,25 @@ def __init__(
             signed: bool,
             inf_values: Optional[Tuple[str]] = None,
             nan_values: Optional[Tuple[str]] = None,
-            saturating: bool = True) -> None:
+            max_available_float: Optional[Tensor] = None,
+            saturating: bool = True,
+            device: Optional[str] = None,
+            dtype: Optional[torch.dtype] = None) -> None:
         super(FloatClamp, self).__init__()
 
         self.tensor_clamp_impl = tensor_clamp_impl
-        self.nan_values = nan_values
-        self.inf_values = inf_values
         self.saturating = saturating
+        self.inf_values = inf_values
+        self.nan_values = nan_values
         self.signed = signed
         self.has_inf_values = bool(inf_values)
+
+        if max_available_float:
+            max_available_float = torch.tensor(max_available_float, device=device, dtype=dtype)
+            self.max_available_float = StatelessBuffer(max_available_float)
+        else:
+            self.max_available_float = None
 
     @brevitas.jit.script_method
     def forward(
             self,
@@ -113,13 +122,9 @@ def forward(
             mantissa_bit_width: Tensor,
             exponent_bias: Tensor):
         inf_mask = x.isinf()
-        max_value = max_float(
-            exponent_bit_width,
-            mantissa_bit_width,
-            exponent_bias,
-            self.nan_values,
-            self.inf_values,
-            self.saturating)
+        max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias)
+        max_value = max_value if self.max_available_float is None else torch.min(
+            max_value, self.max_available_float())
         p_max_val_mask = x > max_value
         n_max_val_mask = -x > max_value
         min_float = torch.tensor(0.) if not self.signed else -max_value
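The JIT-relevant detail above: the format cap is no longer derived from NaN/inf code strings inside forward(); it is materialized as a tensor once in __init__ and wrapped in a StatelessBuffer, so the scripted method only calls a submodule that returns a tensor. A rough stand-in for that buffer pattern, inferred from its use here rather than from StatelessBuffer's actual implementation:

    import torch
    from torch import nn

    class StatelessBufferSketch(nn.Module):
        """Holds a constant tensor; calling the module returns it."""

        def __init__(self, value: torch.Tensor):
            super().__init__()
            self.register_buffer('value', value)

        def forward(self) -> torch.Tensor:
            return self.value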
22 changes: 14 additions & 8 deletions src/brevitas/core/scaling/float_scaling.py
@@ -15,22 +15,28 @@ class FloatScaling(brevitas.jit.ScriptModule):

     def __init__(
             self,
+            max_available_float: Optional[float] = None,
             inf_values: Optional[Tuple[str]] = None,
             nan_values: Optional[Tuple[str]] = None,
-            saturating: bool = True):
+            saturating: bool = True,
+            device: Optional[str] = None,
+            dtype: Optional[torch.dtype] = None):
         super(FloatScaling, self).__init__()
         self.inf_values = inf_values
         self.nan_values = nan_values
         self.saturating = saturating
 
+        if max_available_float:
+            max_available_float = torch.tensor(max_available_float, device=device, dtype=dtype)
+            self.max_available_float = StatelessBuffer(max_available_float)
+        else:
+            self.max_available_float = None
+
     @brevitas.jit.script_method
     def forward(
             self, exponent_bit_width: Tensor, mantissa_bit_width: Tensor,
             exponent_bias: Tensor) -> Tensor:
-        return max_float(
-            exponent_bit_width,
-            mantissa_bit_width,
-            exponent_bias,
-            self.nan_values,
-            self.inf_values,
-            self.saturating)
+        max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias)
+        max_value = max_value if self.max_available_float is None else torch.min(
+            max_value, self.max_available_float())
+        return max_value
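FloatScaling.forward() now returns the same capped maximum, which the surrounding quantizer uses to derive scales. Hand-worked numbers for OCP e5m2 (5 exponent bits, 2 mantissa bits, bias 15), assuming the 57344.0 cap that get_max_available_float yields further below:

    import torch

    grid_max = 1.75 * 2. ** (2 ** 5 - 1 - 15)   # all codes usable: 1.75 * 2^16 = 114688.0
    ocp_cap = torch.tensor(57344.)              # e5m2 reserves top codes for inf/NaN
    assert float(torch.min(torch.tensor(grid_max), ocp_cap)) == 57344.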
61 changes: 11 additions & 50 deletions src/brevitas/function/ops.py
@@ -192,56 +192,17 @@ def min_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor:


 @brevitas.jit.script
-def max_float(
-        exponent_bit_width: Tensor,
-        mantissa_bit_width: Tensor,
-        exponent_bias: Tensor,
-        nan_values: Tuple[str],
-        inf_values: Tuple[str],
-        saturating: bool):
-    # Idea: take the smallest NaN/inf value, set max_value to the next smaller one
-    # inf without NaN not possible
-    exponent_bit_width = int(exponent_bit_width.item())
-    mantissa_bit_width = int(mantissa_bit_width.item())
-    if inf_values is None and nan_values is None:
-        # saturating has to be True if no NaN/inf value are used
-        assert saturating, 'cannot be non-saturating without NaN/inf values'
-        # no special cases, max_value is using all bits for exponent and mantissa
-        exponent = '1' * exponent_bit_width
-        mantissa = '1' * mantissa_bit_width
-    elif nan_values is not None:
-        # we at least have values for NaN, so initiate MaxValInfNaN
-        special_values = nan_values + inf_values if inf_values is not None else nan_values
-
-        # check that NaN/inf values are all mantissa_bit_width long
-        if any(map(lambda x: len(x) > mantissa_bit_width, special_values)):
-            raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.')
-
-        # get the minimum special case, our max value is the next smaller value
-        min_special_case = min(map(lambda x: int(x, 2), special_values))
-
-        max_value_mantissa = min_special_case - 1
-
-        if max_value_mantissa < 0:
-            # all mantissa values are used, so we need to use decrease exponent values
-            exponent = '1' * (exponent_bit_width - 1)
-            # add trailing 0 to reach bit width
-            exponent += '0'
-            # since we decreased exponent, we can use full mantissa
-            mantissa = '1' * mantissa_bit_width
-        else:
-            # there is a free mantissa code, so use full exponent
-            exponent = '1' * exponent_bit_width
-            # get binary code for max_value_mantissa in the number of mantissa bits
-            mantissa = format(max_value_mantissa, f'0{mantissa_bit_width}b')
-    else:
-        # no NaN values but inf values
-        raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.')
-
-    # we don't need the sign since we're looking for the max value
-    max_value = get_minifloat_value(
-        exponent=exponent, mantissa=mantissa, exponent_bias=exponent_bias)
-    return max_value
+def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor):
+    max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias
+    max_mantissa = torch.sum((
+        2. ** torch.arange(
+            0,
+            -1. * mantissa_bit_width - 1.,
+            -1.,
+            dtype=mantissa_bit_width.dtype,
+            device=mantissa_bit_width.device)))
+    max_val = max_mantissa * (2 ** max_exponent)
+    return max_val
 
 
 def get_upper_bound_on_l1_norm(
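The new max_float() is just the largest value of an unrestricted minifloat grid: all exponent bits set (minus the bias) times a mantissa with every bit set. A hand-worked check for an e4m3 grid, mirroring the tensor math above (float32 tensors, as in the tests further down):

    import torch

    ebw = torch.tensor(4.)   # exponent bits
    mbw = torch.tensor(3.)   # mantissa bits
    bias = torch.tensor(7.)  # exponent bias

    max_exponent = (2. ** ebw) - 1. - bias  # 2^4 - 1 - 7 = 8
    # implicit leading one plus all mantissa bits: 2^0 + 2^-1 + 2^-2 + 2^-3 = 1.875
    max_mantissa = torch.sum(
        2. ** torch.arange(0, -1. * mbw - 1., -1., dtype=mbw.dtype, device=mbw.device))
    # 1.875 * 2^8 = 480; the OCP e4m3 cap of 448 is applied separately via torch.min
    assert float(max_mantissa * (2 ** max_exponent)) == 480.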
27 changes: 27 additions & 0 deletions src/brevitas/quant/experimental/float_quant_ocp.py
@@ -1,24 +1,51 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+from dependencies import value
+
 from brevitas.quant.base import MSESymmetricScale
 from brevitas.quant.experimental.float_base import FloatActBase
 from brevitas.quant.experimental.float_base import FloatWeightBase
 from brevitas.quant.experimental.float_base import Fp8e4m3Mixin
 from brevitas.quant.experimental.float_base import Fp8e5m2Mixin
 from brevitas.quant.experimental.float_base import ScaledFloatActBase
 from brevitas.quant.experimental.float_base import ScaledFloatWeightBase
+from brevitas.utils.float_quant_utils import get_max_available_float
 
 
 class Fp8e4m3OCPMixin(Fp8e4m3Mixin):
     nan_values = (('111',))
     inf_values = None
+
+    @value
+    def max_available_float(
+            exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values,
+            saturating):
+        return get_max_available_float(
+            exponent_bit_width,
+            mantissa_bit_width,
+            exponent_bias,
+            nan_values,
+            inf_values,
+            saturating)
 
 
 class Fp8e5m2OCPMixin(Fp8e5m2Mixin):
     nan_values = ('01', '11', '10')
     inf_values = (('00',))
+
+    @value
+    def max_available_float(
+            exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values,
+            saturating):
+        return get_max_available_float(
+            exponent_bit_width,
+            mantissa_bit_width,
+            exponent_bias,
+            nan_values,
+            inf_values,
+            saturating)
 
 
 class Fp8e4m3OCPWeight(Fp8e4m3OCPMixin, FloatWeightBase):
     """
53 changes: 53 additions & 0 deletions src/brevitas/utils/float_quant_utils.py
@@ -1,5 +1,8 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
+from typing import Tuple
+
+import torch
 
 
 def mantissa_bits_to_float(bits: str, frexp_compatible: bool = False) -> float:
@@ -22,3 +25,53 @@ def get_minifloat_value(exponent: str, mantissa: str, exponent_bias: int) -> float:
     exponent_value = int(exponent, 2)
     mantissa_value = mantissa_bits_to_float(mantissa)
     return (2 ** (exponent_value - exponent_bias)) * mantissa_value
+
+
+def get_max_available_float(
+        exponent_bit_width: torch.Tensor,
+        mantissa_bit_width: torch.Tensor,
+        exponent_bias: torch.Tensor,
+        nan_values: Tuple[str],
+        inf_values: Tuple[str],
+        saturating: bool) -> torch.Tensor:
+    # Idea: take the smallest NaN/inf value, set max_value to the next smaller one
+    # inf without NaN not possible
+    if inf_values is None and nan_values is None:
+        # saturating has to be True if no NaN/inf value are used
+        assert saturating, 'cannot be non-saturating without NaN/inf values'
+        # no special cases, max_value is using all bits for exponent and mantissa
+        exponent = '1' * exponent_bit_width
+        mantissa = '1' * mantissa_bit_width
+    elif nan_values is not None:
+        # we at least have values for NaN, so initiate MaxValInfNaN
+        special_values = nan_values + inf_values if inf_values is not None else nan_values
+
+        # check that NaN/inf values are all mantissa_bit_width long
+        if any(map(lambda x: len(x) > mantissa_bit_width, special_values)):
+            raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.')
+
+        # get the minimum special case, our max value is the next smaller value
+        min_special_case = min(map(lambda x: int(x, 2), special_values))
+
+        max_value_mantissa = min_special_case - 1
+
+        if max_value_mantissa < 0:
+            # all mantissa values are used, so we need to use decrease exponent values
+            exponent = '1' * (exponent_bit_width - 1)
+            # add trailing 0 to reach bit width
+            exponent += '0'
+            # since we decreased exponent, we can use full mantissa
+            mantissa = '1' * mantissa_bit_width
+        else:
+            # there is a free mantissa code, so use full exponent
+            exponent = '1' * exponent_bit_width
+            # get binary code for max_value_mantissa in the number of mantissa bits
+            mantissa = format(max_value_mantissa, f'0{mantissa_bit_width}b')
+    else:
+        # no NaN values but inf values
+        raise RuntimeError('Minifloat Error: inf value cannot exist without NaN value.')
+
+    # we don't need the sign since we're looking for the max value
+    max_value = get_minifloat_value(
+        exponent=exponent, mantissa=mantissa, exponent_bias=exponent_bias)
+    return max_value
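A worked check of the helper for the two OCP formats wired up above; the values are hand-derived, and the call assumes plain Python ints are acceptable for the bit widths (the string operations suggest so, the torch.Tensor annotations notwithstanding):

    from brevitas.utils.float_quant_utils import get_max_available_float

    # OCP e4m3: mantissa code '111' is NaN, so the largest usable code is '110'
    # -> get_minifloat_value('1111', '110', 7) = 1.75 * 2^(15 - 7) = 448.0
    assert get_max_available_float(4, 3, 7, ('111',), None, True) == 448.

    # OCP e5m2: every mantissa code at the top exponent is inf/NaN, so the
    # exponent drops to '11110' -> 1.75 * 2^(30 - 15) = 57344.0
    assert get_max_available_float(5, 2, 15, ('01', '11', '10'), ('00',), True) == 57344.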
21 changes: 7 additions & 14 deletions tests/brevitas/core/test_clamp.py
@@ -22,32 +22,25 @@
     'minifloat, expected_max_val',
     ((format, max_val) for format, max_val in FORMAT_MAXVAL_MAP.items()))
 def test_max_value(minifloat, expected_max_val):
-    inf_values = minifloat.float_clamp_impl.inf_values
-    nan_values = minifloat.float_clamp_impl.nan_values
-    saturating = minifloat.float_clamp_impl.saturating
     max_val = max_float(
         torch.tensor(minifloat.exponent_bit_width, dtype=torch.float32),
         torch.tensor(minifloat.mantissa_bit_width, dtype=torch.float32),
-        torch.tensor(minifloat.exponent_bias, dtype=torch.float32),
-        nan_values,
-        inf_values,
-        saturating)
+        torch.tensor(minifloat.exponent_bias, dtype=torch.float32))
+    max_available_float = minifloat.float_clamp_impl.max_available_float
+    max_val = max_val if max_available_float is None else torch.min(max_val, max_available_float())
 
     assert expected_max_val == max_val
 
 
 @given(inp=float_tensor_random_shape_st())
 def test_float_clamp(inp, fp8_clamp):
-    inf_values = fp8_clamp.float_clamp_impl.inf_values
-    nan_values = fp8_clamp.float_clamp_impl.nan_values
-    saturating = fp8_clamp.float_clamp_impl.saturating
 
     max_val = max_float(
         torch.tensor(fp8_clamp.exponent_bit_width, dtype=torch.float32),
         torch.tensor(fp8_clamp.mantissa_bit_width, dtype=torch.float32),
-        torch.tensor(fp8_clamp.exponent_bias, dtype=torch.float32),
-        nan_values,
-        inf_values,
-        saturating)
+        torch.tensor(fp8_clamp.exponent_bias, dtype=torch.float32))
+    max_available_float = fp8_clamp.float_clamp_impl.max_available_float
+    max_val = max_val if max_available_float is None else torch.min(max_val, max_available_float())
     # get values that exceed max_val
     over_limit_mask = inp.abs() > max_val

9 changes: 4 additions & 5 deletions tests/brevitas/core/test_float_quant.py
@@ -187,11 +187,10 @@ def test_inner_scale(inp, minifloat_format, scale):
     max_val = max_float(
         torch.tensor(exponent_bit_width),
         torch.tensor(mantissa_bit_width),
-        torch.tensor(exponent_bias),
-        None,
-        None,
-        True)
-
+        torch.tensor(exponent_bias))
+    max_available_float = float_clamp.max_available_float
+    max_val = max_val if max_available_float is None else torch.min(
+        max_val, max_available_float())
     # call internal scale
     internal_scale = float_quant.internal_scale(scaled_inp)
     val_fp_quant = internal_scale * float_quant.float_to_int_impl(scaled_inp / internal_scale)
