
Commit 6f2f0b0

Fix
Giuseppe5 committed Jun 18, 2024
1 parent df44966 commit 6f2f0b0
Showing 4 changed files with 12 additions and 12 deletions.
6 changes: 2 additions & 4 deletions src/brevitas/proxy/float_parameter_quant.py
@@ -65,11 +65,9 @@ def nan_values(self):
    @property
    def is_ocp(self):
        is_e4m3 = self.mantissa_bit_width() == 3 and self.exponent_bit_width() == 4

        is_ocp_e4m3 = is_e4m3 and self.inf_values() is None and self.nan_values() == (('111',))

        is_e5m2 = self.mantissa_bit_width() == 5 and self.exponent_bit_width() == 2

        is_ocp_e5m2 = is_e5m2 and self.inf_values() == (
            ('00',)) and self.nan_values() == ('01', '11', '10')

@@ -79,11 +77,11 @@ def is_ocp(self):
    def is_fnuz(self):
        is_e4m3 = self.mantissa_bit_width() == 3 and self.exponent_bit_width() == 4
        is_fnuz_e4m3 = is_e4m3 and self.inf_values() is None and self.nan_values(
-       ) is None and self.exponent_bias == 8
+       ) is None and self.exponent_bias() == 8

        is_e5m2 = self.mantissa_bit_width() == 5 and self.exponent_bit_width() == 2
        is_fnuz_e5m2 = is_e5m2 and self.inf_values() is None and self.nan_values(
-       ) is None and self.exponent_bias == 16
+       ) is None and self.exponent_bias() == 16
        return is_fnuz_e4m3 or is_fnuz_e5m2

    def forward(self, x: torch.Tensor) -> Union[Tensor, FloatQuantTensor]:
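
The change in both proxy files is the same fix: exponent_bias is a method, so the old expression compared the bound method object itself against an integer, which always evaluates to False, and is_fnuz could never report an FNUZ configuration. A minimal standalone sketch of the failure mode (the Proxy class below is illustrative, not the Brevitas one):

class Proxy:

    def exponent_bias(self):
        return 8

p = Proxy()
# Comparing the bound method object to an int never succeeds, whatever value it returns.
print(p.exponent_bias == 8)    # False
# Calling the method compares the returned value, which is what the corrected check does.
print(p.exponent_bias() == 8)  # True
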
6 changes: 2 additions & 4 deletions src/brevitas/proxy/float_runtime_quant.py
@@ -40,11 +40,9 @@ def nan_values(self, force_eval=True):
    @property
    def is_ocp(self):
        is_e4m3 = self.mantissa_bit_width() == 3 and self.exponent_bit_width() == 4

        is_ocp_e4m3 = is_e4m3 and self.inf_values() is None and self.nan_values() == (('111',))

        is_e5m2 = self.mantissa_bit_width() == 5 and self.exponent_bit_width() == 2

        is_ocp_e5m2 = is_e5m2 and self.inf_values() == (
            ('00',)) and self.nan_values() == ('01', '11', '10')

@@ -54,11 +52,11 @@ def is_ocp(self):
    def is_fnuz(self):
        is_e4m3 = self.mantissa_bit_width() == 3 and self.exponent_bit_width() == 4
        is_fnuz_e4m3 = is_e4m3 and self.inf_values() is None and self.nan_values(
-       ) is None and self.exponent_bias == 8
+       ) is None and self.exponent_bias() == 8

        is_e5m2 = self.mantissa_bit_width() == 5 and self.exponent_bit_width() == 2
        is_fnuz_e5m2 = is_e5m2 and self.inf_values() is None and self.nan_values(
-       ) is None and self.exponent_bias == 16
+       ) is None and self.exponent_bias() == 16
        return is_fnuz_e4m3 or is_fnuz_e5m2

    def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuantTensor]:
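
For context (not part of the commit): the bias values that is_fnuz now actually checks are what distinguish the FNUZ FP8 variants from the OCP ones, together with the absence of Inf/NaN encodings. A short sketch of the standard parameters, stated here as background rather than taken from the Brevitas sources:

# OCP  E4M3: bias 7,  NaN encoded (no Inf)  -> max finite 1.75  * 2**8  = 448
# OCP  E5M2: bias 15, Inf and NaN encoded   -> max finite 1.75  * 2**15 = 57344
# FNUZ E4M3: bias 8,  no Inf/NaN encodings  -> max finite 1.875 * 2**7  = 240
# FNUZ E5M2: bias 16, no Inf/NaN encodings  -> max finite 1.75  * 2**15 = 57344
assert 1.75 * 2 ** 8 == 448
assert 1.875 * 2 ** 7 == 240
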
5 changes: 4 additions & 1 deletion tests/brevitas_ort/common.py
@@ -18,6 +18,8 @@
from brevitas.nn import QuantConvTranspose2d
from brevitas.nn import QuantConvTranspose3d
from brevitas.nn import QuantLinear
+from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZActPerTensorFloat
+from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZWeightPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint
@@ -62,7 +64,8 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant):
    (Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint),
    'weight_symmetric_activation_dynamic_asymmetric_per_tensor_float':
    (Int8WeightPerTensorFloat, ShiftedUint8DynamicActPerTensorFloat),
-   'fp8_per_tensor_float': (Fp8e4m3OCPWeightPerTensorFloat, Fp8e4m3OCPActPerTensorFloat)}
+   'fp8_ocp_per_tensor_float': (Fp8e4m3OCPWeightPerTensorFloat, Fp8e4m3OCPActPerTensorFloat),
+   'fp8_fnuz_per_tensor_float': (Fp8e4m3FNUZWeightPerTensorFloat, Fp8e4m3FNUZActPerTensorFloat)}
LSTM_QUANTIZERS = {
    'asymmetric_per_tensor_float':
    (ShiftedUint8WeightPerTensorFloat, ShiftedUint8ActPerTensorFloat),
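
The new 'fp8_fnuz_per_tensor_float' entry runs the FNUZ quantizer pair through the same parametrised ONNX export tests as the OCP pair. As a rough usage sketch (layer sizes and keyword choices are illustrative, not taken from the test suite, which drives them through its own IN_CH/OUT_CH constants), the pair is applied to a Brevitas layer like this:

import torch

from brevitas.nn import QuantLinear
from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZActPerTensorFloat
from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZWeightPerTensorFloat

# Hypothetical sizes; weights and inputs are both quantized to FP8 E4M3 FNUZ per tensor.
layer = QuantLinear(
    16,
    32,
    bias=True,
    weight_quant=Fp8e4m3FNUZWeightPerTensorFloat,
    input_quant=Fp8e4m3FNUZActPerTensorFloat,
    return_quant_tensor=False)

out = layer(torch.randn(4, 16))
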
7 changes: 4 additions & 3 deletions tests/brevitas_ort/quant_module_cases.py
@@ -27,7 +27,8 @@ def case_quant_wbiol(
    set_case_id(request.node.callspec.id, QuantWBIOLCases.case_quant_wbiol)

    weight_quant, io_quant = quantizers
-   if weight_quant == Fp8e4m3OCPWeightPerTensorFloat:
+   is_fp8 = weight_quant == Fp8e4m3OCPWeightPerTensorFloat or weight_quant == Fp8e4m3FNUZWeightPerTensorFloat
+   if is_fp8:
        if weight_bit_width < 8 or input_bit_width < 8 or output_bit_width < 8:
            pytest.skip('FP8 export requires total bitwidth equal to 8')
        torch.use_deterministic_algorithms(False)
@@ -40,9 +41,9 @@
    layer_kwargs = {
        'in_channels': IN_CH, 'out_channels': OUT_CH, 'kernel_size': KERNEL_SIZE}

-   bias_quantizer = None if weight_quant == Fp8e4m3OCPWeightPerTensorFloat else Int32Bias
+   bias_quantizer = None if is_fp8 else Int32Bias
    # Required because of numpy error with FP8 data type. Export itself works fine.
-   return_quant_tensor = False if weight_quant == Fp8e4m3OCPWeightPerTensorFloat else True
+   return_quant_tensor = False if is_fp8 else True

    class Model(nn.Module):

