diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py
index 65f56a134..09dcc248a 100644
--- a/src/brevitas/core/quant/float.py
+++ b/src/brevitas/core/quant/float.py
@@ -68,10 +68,6 @@ def __init__(
 
     @brevitas.jit.script_method
     def quantize(self, x: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.float_scaling_impl is not None:
-            float_scaling_impl_value = self.float_scaling_impl(
-                self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
-            scale = scale / float_scaling_impl_value
         x = self.input_view_impl(x)
         scaled_x = x / scale
         internal_scale = float_internal_scale(
@@ -85,7 +81,12 @@ def dequantize(self, y, scale):
 
     @brevitas.jit.script_method
     def forward(self, x):
-        scale = self.scaling_impl(x)
+        if self.float_scaling_impl is not None:
+            float_scaling_impl_value = self.float_scaling_impl(
+                self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
+        else:
+            float_scaling_impl_value = None
+        scale = self.scaling_impl(x, float_scaling_impl_value)
         if self.observer_only:
             y = x
         saturating, inf_values, nan_values = self.float_clamp_impl.saturating, self.float_clamp_impl.inf_values, self.float_clamp_impl.nan_values
diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py
index e1cc271d8..328ad63b3 100644
--- a/src/brevitas/core/quant/int.py
+++ b/src/brevitas/core/quant/int.py
@@ -150,9 +150,8 @@ def __init__(
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         bit_width = self.msb_clamp_bit_width_impl()
-        threshold = self.scaling_impl(x)
         int_threshold = self.int_scaling_impl(bit_width)
-        scale = threshold / int_threshold
+        scale = self.scaling_impl(x, int_threshold)
         zero_point = self.zero_point_impl(x, scale, bit_width)
         if self.observer_only:
             y = x
@@ -189,8 +188,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Te
         pre_threshold = self.pre_scaling_impl(x)
         pre_scale = pre_threshold / int_threshold
         pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
-        threshold = self.scaling_impl(x)
-        scale = threshold / int_threshold
+        scale = self.scaling_impl(x, int_threshold)
         zero_point = self.zero_point_impl(x, scale, bit_width)
         if self.observer_only:
             y = x
@@ -258,8 +256,7 @@ def forward(self, x: Tensor, input_bit_width: Tensor,
         pre_threshold = self.pre_scaling_impl(x, input_bit_width, input_is_signed)
         pre_scale = pre_threshold / int_threshold
         pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
-        threshold = self.scaling_impl(x)
-        scale = threshold / int_threshold
+        scale = self.scaling_impl(x, int_threshold)
         zero_point = self.zero_point_impl(x, scale, bit_width)
         if self.observer_only:
             y = x
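The int quantizers above now hand int_threshold to scaling_impl instead of dividing the returned threshold themselves, so any scale restriction (e.g. rounding to a power of two) is applied to the final scale rather than to the raw threshold. A minimal, self-contained sketch in plain PyTorch (not Brevitas code; the numbers are made up) of why the two orderings agree when int_threshold is itself a power of two:

    import torch

    threshold = torch.tensor(6.0)       # stats-derived threshold, e.g. an abs-max
    int_threshold = torch.tensor(128.)  # int_scaling_impl(bit_width) for signed 8-bit

    # Old order: restrict the threshold first, divide by int_threshold afterwards
    old_po2_scale = 2 ** torch.ceil(torch.log2(threshold)) / int_threshold

    # New order: divide first, then restrict the resulting scale
    new_po2_scale = 2 ** torch.ceil(torch.log2(threshold / int_threshold))

    # Both give 0.0625 here because int_threshold is an exact power of two
    assert torch.isclose(old_po2_scale, new_po2_scale)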
diff --git a/src/brevitas/core/restrict_val.py b/src/brevitas/core/restrict_val.py
index 449318765..59b3fe8ec 100644
--- a/src/brevitas/core/restrict_val.py
+++ b/src/brevitas/core/restrict_val.py
@@ -36,7 +36,7 @@ def __init__(self, scaling_min_val: Optional[float], restrict_value_impl: Option
             self.restrict_value_impl = Identity()
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.restrict_value_impl(x)
         x = self.clamp_min_ste(x)
         return x
@@ -52,7 +52,7 @@ def __init__(self, restrict_value_impl: Optional[Module]):
             self.restrict_value_impl = Identity()
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.restrict_value_impl(x)
         return x
 
@@ -68,7 +68,7 @@ def __init__(self, scaling_min_val: Optional[float]):
         self.min_val = scaling_min_val
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.clamp_min_ste(x)
         return x
 
@@ -90,8 +90,11 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()
 
+    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
+        return x / threshold
+
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor) -> Tensor:
+    def forward(self, x: Tensor) -> Tensor:
         return x
 
 
@@ -104,7 +107,7 @@ def __init__(self):
     def restrict_init_float(self, x: float):
         return math.log2(x)
 
-    def restrict_init_tensor(self, x: torch.Tensor):
+    def restrict_init_tensor(self, x: Tensor):
         return torch.log2(x)
 
     def restrict_init_module(self):
@@ -113,8 +116,11 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()
 
+    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
+        return x - threshold
+
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.power_of_two(x)
         return x
 
@@ -128,7 +134,7 @@ def __init__(self, restrict_value_float_to_int_impl: Module = RoundSte()):
     def restrict_init_float(self, x: float):
         return x
 
-    def restrict_init_tensor(self, x: torch.Tensor):
+    def restrict_init_tensor(self, x: Tensor):
         return x
 
     def restrict_init_module(self):
@@ -137,8 +143,11 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()
 
+    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
+        return x / threshold
+
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.float_to_int_impl(x)
         return x
 
@@ -153,7 +162,7 @@ def __init__(self, restrict_value_float_to_int_impl: Module = RoundSte()):
     def restrict_init_float(self, x: float):
         return math.log2(x)
 
-    def restrict_init_tensor(self, x: torch.Tensor):
+    def restrict_init_tensor(self, x: Tensor):
         return torch.log2(x)
 
     def restrict_init_module(self):
@@ -162,8 +171,11 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()
 
+    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
+        return x - threshold
+
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: Tensor):
         x = self.float_to_int_impl(x)
         x = self.power_of_two(x)
         return x
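combine_scale_threshold folds the threshold into the value in whatever domain the restriction works in: a plain division for the float/int restrictors, a subtraction for the log2-domain power-of-two restrictors. A small sketch of the equivalence (plain PyTorch, scalar tensors assumed):

    import torch

    value, threshold = torch.tensor(3.0), torch.tensor(128.)

    # Linear domain (FloatRestrictValue, IntRestrictValue): x / threshold
    linear = value / threshold

    # Log2 domain (PowerOfTwo* restrictors): values are kept as log2, so the
    # division turns into a subtraction: log2(v / t) == log2(v) - log2(t)
    log_domain = torch.log2(value) - torch.log2(threshold)

    assert torch.isclose(2 ** log_domain, linear)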
diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py
index e4333186d..f11eb1f2a 100644
--- a/src/brevitas/core/scaling/runtime.py
+++ b/src/brevitas/core/scaling/runtime.py
@@ -11,6 +11,7 @@
 import brevitas.config as config
 from brevitas.core.function_wrapper import Identity
 from brevitas.core.restrict_val import _RestrictClampValue
+from brevitas.core.restrict_val import FloatRestrictValue
 from brevitas.core.stats import _ParameterListStats
 from brevitas.core.stats import _RuntimeStats
 from brevitas.core.stats import DEFAULT_MOMENTUM
@@ -27,8 +28,8 @@ def __init__(
             scaling_stats_input_view_shape_impl: Module,
             scaling_stats_input_concat_dim: int,
             tracked_parameter_list: List[torch.nn.Parameter],
-            restrict_scaling_impl: Module,
             scaling_shape: Tuple[int, ...],
+            restrict_scaling_impl: Module = FloatRestrictValue(),
             affine_rescaling: bool = False,
             affine_shift_scale: bool = False,
             scaling_min_val: Optional[float] = None,
@@ -51,9 +52,12 @@ def __init__(
             device)
 
     @brevitas.jit.script_method
-    def forward(self, ignored: torch.Tensor) -> torch.Tensor:
+    def forward(
+            self, ignored: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
         stats = self.parameter_list_stats()
-        return self.stats_scaling_impl(stats)
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats)
+        return self.stats_scaling_impl(stats, threshold)
 
 
 class _StatsScaling(brevitas.jit.ScriptModule):
@@ -78,10 +82,16 @@ def __init__(
             self.affine_rescaling = Identity()
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
         self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
+        self.restrict_scaling_impl = restrict_scaling_impl
 
     @brevitas.jit.script_method
-    def forward(self, stats: torch.Tensor) -> torch.Tensor:
+    def forward(
+            self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats)
+        threshold = self.restrict_scaling_pre(threshold)
         stats = self.restrict_scaling_pre(stats)
+        stats = self.restrict_scaling_impl.combine_scale_threshold(stats, threshold)
         stats = self.affine_rescaling(stats)
         stats = self.restrict_clamp_scaling(stats)
         return stats
@@ -93,10 +103,10 @@ def __init__(
             self,
             scaling_stats_impl: Module,
             scaling_stats_input_view_shape_impl: Module,
-            restrict_scaling_impl: Module,
             scaling_shape: Tuple[int, ...],
             affine_rescaling: bool = False,
             affine_shift_scale: bool = False,
+            restrict_scaling_impl: Module = FloatRestrictValue(),
             scaling_stats_momentum: float = DEFAULT_MOMENTUM,
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
@@ -120,9 +130,9 @@ def __init__(
             device)
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
         stats = self.runtime_stats(x)
-        return self.stats_scaling_impl(stats)
+        return self.stats_scaling_impl(stats, threshold)
 
 
 class _AffineRescaling(brevitas.jit.ScriptModule):
@@ -163,13 +173,13 @@ def _load_from_state_dict(
 class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule):
 
     def __init__(
-        self,
-        group_size: int,
-        group_dim: int,
-        input_view_impl: torch.nn.Module,
-        scaling_stats_impl: torch.nn.Module,
-        scaling_min_val: Optional[float],
-        restrict_scaling_impl: Optional[torch.nn.Module]) -> None:
+            self,
+            group_size: int,
+            group_dim: int,
+            input_view_impl: Module,
+            scaling_stats_impl: Module,
+            scaling_min_val: Optional[float],
+            restrict_scaling_impl: Module = FloatRestrictValue()) -> None:
         super(RuntimeDynamicGroupStatsScaling, self).__init__()
         self.group_size = group_size
         self.group_dim = group_dim
@@ -179,9 +189,14 @@ def __init__(
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
 
     @brevitas.jit.script_method
-    def forward(self, stats_input) -> torch.Tensor:
+    def forward(
+            self,
+            stats_input: torch.Tensor,
+            threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats_input)
         stats_input_reshaped = self.input_view_impl(stats_input)
-        out = self.scaling_stats_impl(stats_input_reshaped)
+        out = self.scaling_stats_impl(stats_input_reshaped) / threshold
         # Scaling min val
         out = self.restrict_clamp_scaling(out)
         return out
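All the runtime scaling forwards now accept an optional threshold that defaults to a tensor of ones, so existing single-argument call sites keep their old behaviour. A toy stand-in for that pattern (illustrative only, not the Brevitas classes):

    from typing import Optional

    import torch

    class ThresholdedScaling(torch.nn.Module):
        # Mimics the forward signature added above: threshold is optional and
        # defaults to ones, which makes the division a no-op.

        def forward(self, stats: torch.Tensor,
                    threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
            if threshold is None:
                threshold = torch.ones(1).type_as(stats)
            return stats / threshold

    m = ThresholdedScaling()
    x = torch.tensor([4.0])
    assert m(x) == m(x, torch.ones(1))          # one-argument call unchanged
    assert m(x, torch.tensor([2.0])) == 2.0     # explicit threshold halves the scale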
diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py
index 53f389331..5198444b1 100644
--- a/src/brevitas/core/scaling/standalone.py
+++ b/src/brevitas/core/scaling/standalone.py
@@ -15,6 +15,7 @@
 from brevitas.core.restrict_val import _ClampValue
 from brevitas.core.restrict_val import _RestrictClampValue
 from brevitas.core.restrict_val import _RestrictValue
+from brevitas.core.restrict_val import FloatRestrictValue
 from brevitas.core.scaling.runtime import _StatsScaling
 from brevitas.core.stats import _ParameterListStats
 from brevitas.core.stats import _Stats
@@ -60,7 +61,7 @@ class ConstScaling(brevitas.jit.ScriptModule):
     def __init__(
             self,
             scaling_init: Union[float, Tensor],
-            restrict_scaling_impl: Optional[Module] = None,
+            restrict_scaling_impl: Module = FloatRestrictValue(),
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
             device: Optional[torch.device] = None) -> None:
@@ -68,18 +69,23 @@ def __init__(
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
         if isinstance(scaling_init, Tensor):
             scaling_init = scaling_init.to(device=device, dtype=dtype)
-            if restrict_scaling_impl is not None:
-                scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
+            scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
+            self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
             self.value = StatelessBuffer(scaling_init.detach())
         else:
-            if restrict_scaling_impl is not None:
-                scaling_init = restrict_scaling_impl.restrict_init_float(scaling_init)
+            scaling_init = restrict_scaling_impl.restrict_init_float(scaling_init)
+            self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
             self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device))
 
     @brevitas.jit.script_method
-    def forward(self, placeholder: Tensor) -> Tensor:
-        value = self.value()
-        restricted_value = self.restrict_clamp_scaling(value)
+    def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(placeholder)
+        # We first apply any restriction to scaling
+        # For IntQuant, this is no-op, retrocompatible.
+        threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold))
+        restricted_value = self.restrict_clamp_scaling(self.value())
+        restricted_value = restricted_value / threshold
         return restricted_value
 
 
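With the new default, ConstScaling always stores its constant in the restricted domain (e.g. as log2 for a power-of-two restriction) and maps the incoming threshold into the same domain before dividing. A worked example of the arithmetic under a power-of-two restriction (plain Python/PyTorch, not the actual module):

    import math

    import torch

    scaling_init, threshold = 0.25, torch.tensor(2.0)

    stored = math.log2(scaling_init)          # restrict_init_float at __init__ time
    threshold_log2 = torch.log2(threshold)    # restrict_init_module applied to the threshold

    scale = 2 ** torch.tensor(stored) / 2 ** threshold_log2
    assert torch.isclose(scale, torch.tensor(0.125))  # 0.25 / 2.0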
@@ -126,7 +132,7 @@ def __init__(
             self,
             scaling_init: Union[float, Tensor],
             scaling_shape: Optional[Tuple[int, ...]] = None,
-            restrict_scaling_impl: Optional[Module] = None,
+            restrict_scaling_impl: Module = FloatRestrictValue(),
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
             device: Optional[torch.device] = None) -> None:
@@ -141,17 +147,24 @@ def __init__(
             scaling_init = scaling_init.detach()
         else:
             scaling_init = torch.tensor(scaling_init, dtype=dtype, device=device)
-        if restrict_scaling_impl is not None:
-            scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
+
+        scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
+        self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
+
         if scaling_init.shape == SCALAR_SHAPE and scaling_shape is not None:
             scaling_init = torch.full(scaling_shape, scaling_init, dtype=dtype, device=device)
         self.value = Parameter(scaling_init)
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
 
     @brevitas.jit.script_method
-    def forward(self, placeholder: Tensor) -> Tensor:
+    def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(placeholder)
+        # We first apply any restriction to scaling
+        # For IntQuant, this is no-op, retrocompatible.
+        threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold))
         value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
-        return value
+        return value / threshold
 
     def _load_from_state_dict(
             self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
@@ -178,8 +191,8 @@ def __init__(
             scaling_stats_input_view_shape_impl: Module,
             scaling_stats_input_concat_dim: int,
             tracked_parameter_list: List[torch.nn.Parameter],
-            restrict_scaling_impl: Module,
             scaling_shape: Tuple[int, ...],
+            restrict_scaling_impl: Module = FloatRestrictValue(),
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
             device: Optional[torch.device] = None) -> None:
@@ -190,20 +203,26 @@ def __init__(
             scaling_stats_input_view_shape_impl,
             scaling_stats_input_concat_dim,
             tracked_parameter_list)
+        self.restrict_scaling_impl = restrict_scaling_impl
         self.stats_scaling_impl = _StatsScaling(
             restrict_scaling_impl, scaling_shape, scaling_min_val, False, False, dtype, device)
         self.init_done: bool = brevitas.jit.Attribute(False, bool)
         self.local_loss_mode: bool = brevitas.jit.Attribute(False, bool)
-        if restrict_scaling_impl is not None:
-            self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
-        else:
-            self.restrict_inplace_preprocess = Identity()
+        self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
+        self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
         self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
 
     @brevitas.jit.script_method
-    def forward(self, ignored: torch.Tensor) -> torch.Tensor:
+    def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(ignored)
+        # Threshold division must happen after we update self.value, but before we apply restrict_preproces
+        # This is because we don't want to store a parameter dependant on a runtime value (threshold)
+        # And because restrict needs to happen after we divide by threshold
         if self.init_done:
-            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
+            threshold = self.restrict_inplace_preprocess(threshold)
+            value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
+            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
             return value
         else:
             stats = self.parameter_list_stats()
@@ -212,8 +231,10 @@ def forward(self, ignored: torch.Tensor) -> torch.Tensor:
             if self.local_loss_mode:
                 return self.stats_scaling_impl(stats)
             stats = self.restrict_inplace_preprocess(stats)
+            threshold = self.restrict_inplace_preprocess(threshold)
             inplace_tensor_mul(self.value.detach(), stats)
-            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
+            value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
+            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
             self.init_done = True
             return value
 
@@ -290,7 +311,7 @@ def __init__(
             scaling_stats_impl: Module,
             scaling_stats_input_view_shape_impl: Module = OverBatchOverTensorView(),
             scaling_shape: Tuple[int, ...] = SCALAR_SHAPE,
-            restrict_scaling_impl: Optional[Module] = None,
+            restrict_scaling_impl: Module = FloatRestrictValue(),
             scaling_stats_momentum: Optional[float] = DEFAULT_MOMENTUM,
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
@@ -305,19 +326,19 @@ def __init__(
             scaling_stats_momentum, Optional[float])
         self.register_buffer('buffer', torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
         self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
+        self.restrict_scaling_impl = restrict_scaling_impl
         self.restrict_scaling = _RestrictValue(restrict_scaling_impl)
         self.clamp_scaling = _ClampValue(scaling_min_val)
         self.local_loss_mode: bool = brevitas.jit.Attribute(
             False, bool)  # required to support MSE eval or variants
-        if restrict_scaling_impl is not None:
-            self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
-            self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
-        else:
-            self.restrict_inplace_preprocess = Identity()
-            self.restrict_preprocess = Identity()
+        self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
+        self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
 
     @brevitas.jit.script_method
-    def training_forward(self, stats_input: Tensor) -> Tensor:
+    def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
+        # Threshold division must happen after we update self.value, but before we apply restrict_preproces
+        # This is because we don't want to store a parameter dependent on a runtime value (threshold)
+        # And because restrict needs to happen after we divide by threshold
         if self.counter < self.collect_stats_steps:
             stats_input = self.stats_input_view_shape_impl(stats_input)
             stats = self.stats(stats_input)
@@ -327,32 +348,41 @@ def training_forward(self, stats_input: Tensor) -> Tensor:
             new_counter = self.counter + 1
             # Whenever we are in local loss mode, we don't update the counter nor the buffer
             if self.local_loss_mode:
-                return abs_binary_sign_grad(clamped_stats)
+                # Local loss mode, we early exit and divide by threshold
+                return abs_binary_sign_grad(clamped_stats / threshold)
             if self.counter == 0:
                 inplace_tensor_mul(self.buffer, clamped_stats.detach())
             else:
                 inplace_momentum_update(
                     self.buffer, clamped_stats.detach(), self.momentum, self.counter, new_counter)
             self.counter = new_counter
-            return abs_binary_sign_grad(clamped_stats)
+            return abs_binary_sign_grad(clamped_stats / threshold)
         elif self.counter == self.collect_stats_steps:
             self.restrict_inplace_preprocess(self.buffer)
             inplace_tensor_mul(self.value.detach(), self.buffer)
+            threshold = self.restrict_preprocess(threshold)
+            value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
             self.counter = self.counter + 1
-            return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(self.value)))
+            return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
         else:
-            return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(self.value)))
+            threshold = self.restrict_preprocess(threshold)
+            value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
+            return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
 
     @brevitas.jit.script_method
-    def forward(self, stats_input: Tensor) -> Tensor:
+    def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Tensor:
+        if threshold is None:
+            threshold = torch.ones(1).type_as(stats_input)
         if self.training:
-            return self.training_forward(stats_input)
+            # Threshold division handled inside the training_forward
+            return self.training_forward(stats_input, threshold)
         else:
            if self.counter <= self.collect_stats_steps:
-                out = self.buffer
+                out = self.buffer / threshold
                 out = self.restrict_preprocess(out)
             else:
-                out = self.value
+                threshold = self.restrict_preprocess(threshold)
+                out = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold)
             out = abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out)))
             return out
 
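For the parameter-based scalings above, the learned value keeps tracking threshold-free statistics; the runtime threshold is mapped into the restricted domain and combined only at forward time, so nothing threshold-dependent is ever stored in the state dict. The log2-domain case looks roughly like this (illustrative arithmetic, not the module itself):

    import torch

    log2_value = torch.tensor(3.0)      # learned parameter in log2 domain (log2 of 8.0)
    threshold = torch.tensor(2.0)       # runtime quantity, never stored

    log2_threshold = torch.log2(threshold)              # restrict_preprocess
    combined = log2_value - log2_threshold              # combine_scale_threshold (log2 domain)
    scale = 2 ** torch.round(combined)                  # power-of-two restriction, int exponent

    assert scale == torch.tensor(4.0)                   # 8.0 / 2.0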
diff --git a/src/brevitas_examples/common/generative/quant_blocks.py b/src/brevitas_examples/common/generative/quant_blocks.py
index 93cc235e2..776f1f6b2 100644
--- a/src/brevitas_examples/common/generative/quant_blocks.py
+++ b/src/brevitas_examples/common/generative/quant_blocks.py
@@ -25,10 +25,10 @@ def __init__(
         self.stats_impl = scaling_stats_impl
         self.dynamic_scaling_broadcastable_fn = dynamic_scaling_broadcastable_fn
 
-    def forward(self, x) -> Tensor:
+    def forward(self, x, threshold) -> Tensor:
         shape = x.shape
         x = self.scaling_stats_input_view_shape_impl(x)
-        x = self.stats_impl(x)
+        x = self.stats_impl(x) / threshold
         x = self.dynamic_scaling_broadcastable_fn(x, shape)
         return x
 
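Unlike the core scaling modules, this forward in the generative example stack takes threshold as a required positional argument, so its callers have to pass one explicitly (a tensor of ones recovers the old behaviour). A toy version of the dynamic-stats computation (the real module also reshapes through the view/broadcast callables):

    import torch

    def dynamic_scale(x: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
        # abs-max statistic over the last dim, divided by an explicit threshold
        stats = x.abs().amax(dim=-1, keepdim=True)
        return stats / threshold

    x = torch.tensor([[1.0, -3.0, 2.0]])
    print(dynamic_scale(x, torch.ones(1)))         # tensor([[3.]])
    print(dynamic_scale(x, torch.tensor([2.0])))   # tensor([[1.5000]])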
diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py
index 552472717..6c7e26f31 100644
--- a/tests/brevitas/core/test_float_quant.py
+++ b/tests/brevitas/core/test_float_quant.py
@@ -109,9 +109,10 @@ def test_float_to_quant_float(inp, minifloat_format):
 @given(inp=float_tensor_random_shape_st(), minifloat_format=random_minifloat_format())
 @jit_disabled_for_mock()
 def test_scaling_impls_called_once(inp, minifloat_format):
+    float_scaling_impl_return = 1.
     bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
-    scaling_impl = mock.Mock(side_effect=lambda x: 1.)
-    float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.)
+    scaling_impl = mock.Mock(side_effect=lambda x, y: 1.)
+    float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: float_scaling_impl_return)
     if exponent_bit_width == 0 or mantissa_bit_width == 0:
         with pytest.raises(RuntimeError):
             float_quant = FloatQuant(
@@ -142,14 +143,15 @@ def test_scaling_impls_called_once(inp, minifloat_format):
             scaling_impl=scaling_impl,
             float_scaling_impl=float_scaling_impl,
             float_clamp_impl=float_clamp)
-        scale = float_quant.scaling_impl(inp)
+        float_scaling = float_scaling_impl(exponent_bit_width, mantissa_bit_width, exponent_bias)
+        scale = float_quant.scaling_impl(inp, float_scaling)
         _ = float_quant.quantize(inp, scale)
         # scaling implementations should be called exaclty once on the input
         float_scaling_impl.assert_called_once_with(
             torch.tensor(exponent_bit_width),
             torch.tensor(mantissa_bit_width),
             torch.tensor(exponent_bias))
-        scaling_impl.assert_called_once_with(inp)
+        scaling_impl.assert_called_once_with(inp, float_scaling_impl_return)
 
 
 @given(
@@ -161,7 +163,7 @@ def test_inner_scale(inp, minifloat_format, scale):
     bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
     # set scaling_impl to scale and float_scaling_impl to 1 to use the same scale as we are here
     float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.)
-    scaling_impl = mock.Mock(side_effect=lambda x: scale)
+    scaling_impl = mock.Mock(side_effect=lambda x, y: scale)
     if exponent_bit_width == 0 or mantissa_bit_width == 0:
         with pytest.raises(RuntimeError):
             float_quant = FloatQuant(
diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py
index b22994275..10d8f7e7c 100644
--- a/tests/brevitas/graph/test_calibration.py
+++ b/tests/brevitas/graph/test_calibration.py
@@ -12,6 +12,7 @@
 from brevitas.graph.calibrate import bias_correction_mode
 from brevitas.graph.calibrate import calibration_mode
 from brevitas.graph.calibrate import load_quant_model_mode
+from brevitas.inject.enum import RestrictValueType
 import brevitas.nn as qnn
 from brevitas.quant import Int8ActPerTensorFixedPoint
 from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
@@ -27,7 +28,9 @@
 BATCH = 1
 REFERENCE_SCALES = {
     'int_quant': (0.00935234408825635910, 0.01362917013466358185),
-    'fp_quant': (0.00249395845457911491, 0.00363444536924362183)}
+    'fp_quant': (0.00249395845457911491, 0.00363444536924362183),
+    'int_po2_quant': (0.015625, 0.015625),
+    'fp_po2_quant': (0.001953125, 0.00390625),}
 REFERENCE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]])
 REFERENCE_WEIGHTS = torch.tensor([[1.0023, 0.0205, 1.4604], [-0.2918, -1.8218, -0.7010],
                                   [1.4573, -0.9074, -0.2708]])
@@ -44,9 +47,9 @@ def reference_implementation_scale_factors_po2(
     quant = compute_quantile(x, q)
     quant = torch.max(min_val, quant)
     quant_float_to_int = torch.ceil(
-        torch.log2(quant))  # Float to Int Implementation for PowerOfTwo scale
+        torch.log2(quant / int_scale))  # Float to Int Implementation for PowerOfTwo scale
 
-    scale = torch.pow(torch.tensor(2.), quant_float_to_int) / int_scale
+    scale = torch.pow(torch.tensor(2.), quant_float_to_int)
 
     return scale
 
@@ -75,7 +78,15 @@ def forward(self, x):
     assert torch.allclose(expected_scale, scale)
 
 
-QUANTS = {'int_quant': Int8ActPerTensorFloat, 'fp_quant': Fp8e4m3ActPerTensorFloat}
+class Fp8e4m3ActPerTensorFixedPoint(Fp8e4m3ActPerTensorFloat):
+    restrict_scaling_type = RestrictValueType.POWER_OF_TWO
+
+
+QUANTS = {
+    'int_quant': Int8ActPerTensorFloat,
+    'fp_quant': Fp8e4m3ActPerTensorFloat,
+    'int_po2_quant': Int8ActPerTensorFixedPoint,
+    'fp_po2_quant': Fp8e4m3ActPerTensorFixedPoint}
 
 
 @pytest_cases.parametrize("act_quant", QUANTS.items(), ids=QUANTS.keys())
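The new test cases derive power-of-two-restricted variants purely by overriding restrict_scaling_type on an existing quantizer, as in the Fp8e4m3ActPerTensorFixedPoint class added above. A usage sketch along the same lines (assumes a Brevitas build with the experimental FP8 quantizers; calibration details elided):

    import torch

    import brevitas.nn as qnn
    from brevitas.inject.enum import RestrictValueType
    from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat

    class Fp8e4m3ActPerTensorFixedPoint(Fp8e4m3ActPerTensorFloat):
        restrict_scaling_type = RestrictValueType.POWER_OF_TWO

    act = qnn.QuantIdentity(act_quant=Fp8e4m3ActPerTensorFixedPoint, return_quant_tensor=True)
    act.train()
    _ = act(torch.randn(8, 16))   # a few forward passes collect scale statistics
    act.eval()
    out = act(torch.randn(8, 16))
    print(out.scale)              # expected to be an exact power of two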