diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index f1d108068..23707344f 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -158,3 +158,38 @@ def _load_from_state_dict( missing_keys.remove(affine_weight_key) if config.IGNORE_MISSING_KEYS and affine_bias_key in missing_keys: missing_keys.remove(affine_bias_key) + + +class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule): + + def __init__( + self, + group_size: int, + group_dim: int, + scaling_stats_impl: torch.nn.Module, + scaling_min_val: Optional[float], + restrict_scaling_impl: Optional[torch.nn.Module]) -> None: + super(RuntimeDynamicGroupStatsScaling, self).__init__() + self.group_size = group_size + self.group_dim = group_dim + self.scaling_stats_impl = scaling_stats_impl + self.scaling_min_val = scaling_min_val + self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) + + @brevitas.jit.script_method + def group_scaling_reshape(self, stats_input): + tensor_shape = stats_input.shape + tensor_shape_list = list(tensor_shape) + tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size) + block_dim = self.group_dim + 1 if self.group_dim != -1 else -1 + tensor_shape_list.insert(block_dim, self.group_size) + stats_input = stats_input.view(tensor_shape_list) + return stats_input + + @brevitas.jit.script_method + def forward(self, stats_input) -> torch.Tensor: + stats_input_reshaped = self.group_scaling_reshape(stats_input) + out = self.scaling_stats_impl(stats_input_reshaped) + # Scaling min val + out = self.restrict_clamp_scaling(out) + return out diff --git a/src/brevitas/proxy/groupwise_float_runtime_quant.py b/src/brevitas/proxy/groupwise_float_runtime_quant.py index 735b30163..b81504c40 100644 --- a/src/brevitas/proxy/groupwise_float_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_float_runtime_quant.py @@ -37,7 +37,21 @@ def forward( # otherwise return a simple Tensor # We exclude the last two values (inf_values and nan_values) if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])): - out = GroupwiseFloatQuantTensor(*y, signed=self.is_signed, training=self.training) + value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = y + out = GroupwiseFloatQuantTensor( + value, + scale, + zero_point, + self.group_size, + self.group_dim, + exponent_bit_width, + mantissa_bit_width, + exponent_bias, + saturating, + inf_values, + nan_values, + signed=self.is_signed, + training=self.training) elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant if isinstance(y, tuple): y = y[0] diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py index 1b7191037..e17791841 100644 --- a/src/brevitas/quant/experimental/float_base.py +++ b/src/brevitas/quant/experimental/float_base.py @@ -62,3 +62,24 @@ class Fp8e5m2Mixin(ExtendedInjector): exponent_bit_width = 5 mantissa_bit_width = 2 saturating = True + + +class Fp6e3m2Mixin(ExtendedInjector): + bit_width = 6 + exponent_bit_width = 3 + mantissa_bit_width = 2 + saturating = True + + +class Fp6e2m3Mixin(ExtendedInjector): + bit_width = 6 + exponent_bit_width = 2 + mantissa_bit_width = 3 + saturating = True + + +class Fp4e2m1Mixin(ExtendedInjector): + bit_width = 4 + exponent_bit_width = 2 + mantissa_bit_width = 1 + saturating = True diff --git a/src/brevitas/quant/experimental/float_quant_fnuz.py b/src/brevitas/quant/experimental/float_quant_fnuz.py index 7d7035cb6..a1b48ff02 100644 --- a/src/brevitas/quant/experimental/float_quant_fnuz.py +++ b/src/brevitas/quant/experimental/float_quant_fnuz.py @@ -3,160 +3,88 @@ from dependencies import value +from brevitas.inject import ExtendedInjector from brevitas.quant.base import MSESymmetricScale from brevitas.quant.experimental.float_base import FloatActBase from brevitas.quant.experimental.float_base import FloatWeightBase -from brevitas.quant.experimental.float_base import Fp8e4m3Mixin -from brevitas.quant.experimental.float_base import Fp8e5m2Mixin from brevitas.quant.experimental.float_base import ScaledFloatActBase from brevitas.quant.experimental.float_base import ScaledFloatWeightBase -class Fp8e4m3FNUZMixin(Fp8e4m3Mixin): - nan_values = None - inf_values = None +class FpFNUZMixin(ExtendedInjector): + saturating = True @value def exponent_bias(exponent_bit_width): return 2 ** (exponent_bit_width - 1) -class Fp8e5m2FNUZMixin(Fp8e5m2Mixin): - nan_values = None - inf_values = None - - @value - def exponent_bias(exponent_bit_width): - return 2 ** (exponent_bit_width - 1) - - -class Fp8e4m3FNUZWeight(Fp8e4m3FNUZMixin, FloatWeightBase): - """ - FP8 signed E3M4 weight quantizer. - """ - pass - - -class Fp8e5m2FNUZWeight(Fp8e5m2FNUZMixin, FloatWeightBase): - """ - FP8 signed E5M2 weight quantizer. +class FpFNUZWeight(FpFNUZMixin, FloatWeightBase): """ - pass - - -class Fp8e4m3FNUZAct(Fp8e4m3FNUZMixin, FloatActBase): - """ - FP8 signed E4M3 activation quantizer. + FNUZ FP8 signed weight quantizer. """ pass -class Fp8e5m2FNUZAct(Fp8e5m2FNUZMixin, FloatActBase): +class FpFNUZAct(FpFNUZMixin, FloatActBase): """ - FP8 signed E5M2 activation quantizer. + FP8 signed activation quantizer. """ pass -class Fp8e4m3FNUZWeightPerTensorFloat(Fp8e4m3FNUZMixin, ScaledFloatWeightBase): +class FpFNUZWeightPerTensorFloat(FpFNUZMixin, ScaledFloatWeightBase): """ FP8 signed E3M4 weight quantizer with per-tensor absmax-based scaling. """ scaling_per_output_channel = False -class Fp8e5m2FNUZWeightPerTensorFloat(Fp8e5m2FNUZMixin, ScaledFloatWeightBase): - """ - FP8 signed E5M2 weight quantizer with per-tensor absmax-based scaling. - """ - scaling_per_output_channel = False - - -class Fp8e4m3FNUZActPerTensorFloat(Fp8e4m3FNUZMixin, ScaledFloatActBase): - """ - FP8 signed E4M3 activation quantizer with per-tensor static percentile-based scaling. - """ - scaling_per_output_channel = False - - -class Fp8e5m2FNUZActPerTensorFloat(Fp8e5m2FNUZMixin, ScaledFloatActBase): +class FpFNUZActPerTensorFloat(FpFNUZMixin, ScaledFloatActBase): """ - FP8 signed E5M2 activation quantizer with per-tensor static percentile-based scaling. + FP8 signed activation quantizer with per-tensor static percentile-based scaling. """ scaling_per_output_channel = False -class Fp8e4m3FNUZWeightPerChannelFloat(Fp8e4m3FNUZMixin, ScaledFloatWeightBase): +class FpFNUZWeightPerChannelFloat(FpFNUZMixin, ScaledFloatWeightBase): """ FP8 signed E3M4 weight quantizer with per-channel absmax-based scaling. """ scaling_per_output_channel = True -class Fp8e5m2FNUZWeightPerChannelFloat(Fp8e5m2FNUZMixin, ScaledFloatWeightBase): - """ - FP8 signed E5M2 weight quantizer with per-channel absmax-based scaling. - """ - scaling_per_output_channel = True - - -class Fp8e4m3FNUZActPerChannelFloat2d(Fp8e4m3FNUZMixin, ScaledFloatActBase): - """ - FP8 signed E4M3 activation quantizer with per-channel static percentile-based scaling. - """ - scaling_per_output_channel = True - scaling_stats_permute_dims = (1, 0, 2, 3) - - -class Fp8e5m2FNUZActPerChannelFloat2d(Fp8e5m2FNUZMixin, ScaledFloatActBase): +class FpFNUZActPerChannelFloat2d(FpFNUZMixin, ScaledFloatActBase): """ - FP8 signed E5M2 activation quantizer with per-channel static percentile-based scaling. + FP8 signed activation quantizer with per-channel static percentile-based scaling. """ scaling_per_output_channel = True scaling_stats_permute_dims = (1, 0, 2, 3) -class Fp8e4m3FNUZActPerTensorFloatMSE(Fp8e4m3FNUZMixin, MSESymmetricScale, ScaledFloatActBase): - """ - FP8 signed E4M3 activation quantizer with per-tensor static MSE-based scaling. - """ - scaling_per_output_channel = False - - -class Fp8e5m2FNUZActPerTensorFloatMSE(Fp8e5m2FNUZMixin, MSESymmetricScale, ScaledFloatActBase): +class FpFNUZActPerTensorFloatMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatActBase): """ - FP8 signed E5M2 activation quantizer with per-tensor static MSE-based scaling. + FP8 signed activation quantizer with per-tensor static MSE-based scaling. """ scaling_per_output_channel = False -class Fp8e4m3FNUZActPerChannelFloat2dMSE(Fp8e4m3FNUZMixin, MSESymmetricScale, ScaledFloatActBase): - """ - FP8 signed E4M3 activation quantizer with per-channel static MSE-based scaling. - """ - scaling_per_output_channel = True - scaling_stats_permute_dims = (1, 0, 2, 3) - - -class Fp8e5m2FNUZActPerChannelFloat2dMSE(Fp8e5m2FNUZMixin, MSESymmetricScale, ScaledFloatActBase): +class FpFNUZActPerChannelFloat2dMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatActBase): """ - FP8 signed E5M2 activation quantizer with per-channel static MSE-based scaling. + FP8 signed activation quantizer with per-channel static MSE-based scaling. """ scaling_per_output_channel = True scaling_stats_permute_dims = (1, 0, 2, 3) -class Fp8e4m3FNUZWeightPerChannelFloatMSE(Fp8e4m3FNUZMixin, - MSESymmetricScale, - ScaledFloatWeightBase): +class FpFNUZWeightPerChannelFloatMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatWeightBase): """ FP8 signed E3M4 weight quantizer with per-channel MSE-based scaling. """ scaling_per_output_channel = True -class Fp8e4m3FNUZWeightPerTensorFloatMSE(Fp8e4m3FNUZMixin, MSESymmetricScale, - ScaledFloatWeightBase): +class FpFNUZWeightPerTensorFloatMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatWeightBase): """ FP8 signed E3M4 weight quantizer with per-tensor MSE-based scaling. """ diff --git a/src/brevitas/quant/experimental/float_quant_ocp.py b/src/brevitas/quant/experimental/float_quant_ocp.py index f2b148482..f22e08d67 100644 --- a/src/brevitas/quant/experimental/float_quant_ocp.py +++ b/src/brevitas/quant/experimental/float_quant_ocp.py @@ -3,9 +3,13 @@ from dependencies import value +from brevitas.inject import ExtendedInjector from brevitas.quant.base import MSESymmetricScale from brevitas.quant.experimental.float_base import FloatActBase from brevitas.quant.experimental.float_base import FloatWeightBase +from brevitas.quant.experimental.float_base import Fp4e2m1Mixin +from brevitas.quant.experimental.float_base import Fp6e2m3Mixin +from brevitas.quant.experimental.float_base import Fp6e3m2Mixin from brevitas.quant.experimental.float_base import Fp8e4m3Mixin from brevitas.quant.experimental.float_base import Fp8e5m2Mixin from brevitas.quant.experimental.float_base import ScaledFloatActBase @@ -13,26 +17,28 @@ from brevitas.utils.float_quant_utils import get_max_available_float -class Fp8e4m3OCPMixin(Fp8e4m3Mixin): - nan_values = (('111',)) - inf_values = None +class FpOCPMixin(ExtendedInjector): + saturating = True @value - def max_available_float( - exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values, - saturating): - return get_max_available_float( - exponent_bit_width, - mantissa_bit_width, - exponent_bias, - nan_values, - inf_values, - saturating) - + def inf_values(bit_width, mantissa_bit_width, exponent_bit_width): + if bit_width == 8: + if mantissa_bit_width == 3 and exponent_bit_width == 4: + return None + if mantissa_bit_width == 2 and exponent_bit_width == 5: + return (('00',)) + else: + return None -class Fp8e5m2OCPMixin(Fp8e5m2Mixin): - nan_values = ('01', '11', '10') - inf_values = (('00',)) + @value + def nan_values(bit_width, mantissa_bit_width, exponent_bit_width): + if bit_width == 8: + if mantissa_bit_width == 3 and exponent_bit_width == 4: + return (('111',)) + if mantissa_bit_width == 2 and exponent_bit_width == 5: + return ('01', '11', '10') + else: + return None @value def max_available_float( @@ -47,130 +53,72 @@ def max_available_float( saturating) -class Fp8e4m3OCPWeight(Fp8e4m3OCPMixin, FloatWeightBase): +class FpOCPWeight(FpOCPMixin, FloatWeightBase): """ - FP8 signed E3M4 weight quantizer. + OCP FP8 signed weight quantizer. """ pass -class Fp8e5m2OCPWeight(Fp8e5m2OCPMixin, FloatWeightBase): +class FpOCPAct(FpOCPMixin, FloatActBase): """ - FP8 signed E5M2 weight quantizer. + FP8 signed activation quantizer. """ pass -class Fp8e4m3OCPAct(Fp8e4m3OCPMixin, FloatActBase): - """ - FP8 signed E4M3 activation quantizer. - """ - pass - - -class Fp8e5m2OCPAct(Fp8e5m2OCPMixin, FloatActBase): - """ - FP8 signed E5M2 activation quantizer. - """ - pass - - -class Fp8e4m3OCPWeightPerTensorFloat(Fp8e4m3OCPMixin, ScaledFloatWeightBase): +class FpOCPWeightPerTensorFloat(FpOCPMixin, ScaledFloatWeightBase): """ FP8 signed E3M4 weight quantizer with per-tensor absmax-based scaling. """ scaling_per_output_channel = False -class Fp8e5m2OCPWeightPerTensorFloat(Fp8e5m2OCPMixin, ScaledFloatWeightBase): - """ - FP8 signed E5M2 weight quantizer with per-tensor absmax-based scaling. - """ - scaling_per_output_channel = False - - -class Fp8e4m3OCPActPerTensorFloat(Fp8e4m3OCPMixin, ScaledFloatActBase): +class FpOCPActPerTensorFloat(FpOCPMixin, ScaledFloatActBase): """ - FP8 signed E4M3 activation quantizer with per-tensor static percentile-based scaling. + FP8 signed activation quantizer with per-tensor static percentile-based scaling. """ scaling_per_output_channel = False -class Fp8e5m2OCPActPerTensorFloat(Fp8e5m2OCPMixin, ScaledFloatActBase): - """ - FP8 signed E5M2 activation quantizer with per-tensor static percentile-based scaling. - """ - scaling_per_output_channel = False - - -class Fp8e4m3OCPWeightPerChannelFloat(Fp8e4m3OCPMixin, ScaledFloatWeightBase): +class FpOCPWeightPerChannelFloat(FpOCPMixin, ScaledFloatWeightBase): """ FP8 signed E3M4 weight quantizer with per-channel absmax-based scaling. """ scaling_per_output_channel = True -class Fp8e5m2OCPWeightPerChannelFloat(Fp8e5m2OCPMixin, ScaledFloatWeightBase): - """ - FP8 signed E5M2 weight quantizer with per-channel absmax-based scaling. - """ - scaling_per_output_channel = True - - -class Fp8e4m3OCPActPerChannelFloat2d(Fp8e4m3OCPMixin, ScaledFloatActBase): +class FpOCPActPerChannelFloat2d(FpOCPMixin, ScaledFloatActBase): """ - FP8 signed E4M3 activation quantizer with per-channel static percentile-based scaling. + FP8 signed activation quantizer with per-channel static percentile-based scaling. """ scaling_per_output_channel = True scaling_stats_permute_dims = (1, 0, 2, 3) -class Fp8e5m2OCPActPerChannelFloat2d(Fp8e5m2OCPMixin, ScaledFloatActBase): +class FpOCPActPerTensorFloatMSE(FpOCPMixin, MSESymmetricScale, ScaledFloatActBase): """ - FP8 signed E5M2 activation quantizer with per-channel static percentile-based scaling. - """ - scaling_per_output_channel = True - scaling_stats_permute_dims = (1, 0, 2, 3) - - -class Fp8e4m3OCPActPerTensorFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatActBase): - """ - FP8 signed E4M3 activation quantizer with per-tensor static MSE-based scaling. - """ - scaling_per_output_channel = False - - -class Fp8e5m2OCPActPerTensorFloatMSE(Fp8e5m2OCPMixin, MSESymmetricScale, ScaledFloatActBase): - """ - FP8 signed E5M2 activation quantizer with per-tensor static MSE-based scaling. + FP8 signed activation quantizer with per-tensor static MSE-based scaling. """ scaling_per_output_channel = False -class Fp8e4m3OCPActPerChannelFloat2dMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatActBase): - """ - FP8 signed E4M3 activation quantizer with per-channel static MSE-based scaling. - """ - scaling_per_output_channel = True - scaling_stats_permute_dims = (1, 0, 2, 3) - - -class Fp8e5m2OCPActPerChannelFloat2dMSE(Fp8e5m2OCPMixin, MSESymmetricScale, ScaledFloatActBase): +class FpOCPActPerChannelFloat2dMSE(FpOCPMixin, MSESymmetricScale, ScaledFloatActBase): """ - FP8 signed E5M2 activation quantizer with per-channel static MSE-based scaling. + FP8 signed activation quantizer with per-channel static MSE-based scaling. """ scaling_per_output_channel = True scaling_stats_permute_dims = (1, 0, 2, 3) -class Fp8e4m3OCPWeightPerChannelFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatWeightBase): +class FpOCPWeightPerChannelFloatMSE(FpOCPMixin, MSESymmetricScale, ScaledFloatWeightBase): """ FP8 signed E3M4 weight quantizer with per-channel MSE-based scaling. """ scaling_per_output_channel = True -class Fp8e4m3OCPWeightPerTensorFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatWeightBase): +class FpOCPWeightPerTensorFloatMSE(FpOCPMixin, MSESymmetricScale, ScaledFloatWeightBase): """ FP8 signed E3M4 weight quantizer with per-tensor MSE-based scaling. """ diff --git a/src/brevitas/quant/experimental/mx_quant_ocp.py b/src/brevitas/quant/experimental/mx_quant_ocp.py index 16286ecb5..9fbdcf0de 100644 --- a/src/brevitas/quant/experimental/mx_quant_ocp.py +++ b/src/brevitas/quant/experimental/mx_quant_ocp.py @@ -1,59 +1,73 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from dependencies import value + from brevitas.core.function_wrapper.ops_ste import CeilSte +from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling from brevitas.inject import ExtendedInjector from brevitas.inject.enum import RestrictValueType +from brevitas.proxy.groupwise_float_parameter_quant import \ + GroupwiseWeightFloatQuantProxyFromInjector +from brevitas.proxy.groupwise_float_runtime_quant import GroupwiseActFloatQuantProxyFromInjector +from brevitas.quant.base import MSESymmetricScale +from brevitas.quant.experimental.float_base import ScaledFloatActBase from brevitas.quant.experimental.float_base import ScaledFloatWeightBase -from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat -from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPWeightPerTensorFloat - - -class Fp6e3m2OCPMixin(ExtendedInjector): - bit_width = 6 - exponent_bit_width = 3 - mantissa_bit_width = 2 - nan_values = None - inf_values = None - - -class Fp6e2m3OCPMixin(ExtendedInjector): - bit_width = 6 - exponent_bit_width = 2 - mantissa_bit_width = 3 - nan_values = None - inf_values = None +from brevitas.quant.experimental.float_quant_ocp import FpOCPAct +from brevitas.quant.experimental.float_quant_ocp import FpOCPActPerTensorFloat +from brevitas.quant.experimental.float_quant_ocp import FpOCPWeight +from brevitas.quant.experimental.float_quant_ocp import FpOCPWeightPerChannelFloat +from brevitas.quant.experimental.float_quant_ocp import FpOCPWeightPerChannelFloatMSE +from brevitas.quant.experimental.float_quant_ocp import FpOCPWeightPerTensorFloat +from brevitas.quant.experimental.float_quant_ocp import FpOCPWeightPerTensorFloatMSE -class Fp4e2m1OCPMixin(ExtendedInjector): - bit_width = 4 - exponent_bit_width = 2 - mantissa_bit_width = 1 - nan_values = None - inf_values = None +class MXFloatWeightMixin(ExtendedInjector): + proxy_class = GroupwiseWeightFloatQuantProxyFromInjector + group_size = 32 + restrict_scaling_type = RestrictValueType.POWER_OF_TWO + restrict_value_float_to_int_impl = CeilSte -class MXWeightMixIn(ExtendedInjector): +class MXFloatActMixin(ExtendedInjector): + proxy_class = GroupwiseActFloatQuantProxyFromInjector group_size = 32 restrict_scaling_type = RestrictValueType.POWER_OF_TWO restrict_value_float_to_int_impl = CeilSte + scaling_impl = RuntimeDynamicGroupStatsScaling - -class MXFp8e4m3OCPWeightPerTensorFloat(Fp8e4m3OCPWeightPerTensorFloat, MXWeightMixIn): - pass + @value + def stats_reduce_dim(group_dim): + # If group_dim = -1, we need a workaround to avoid selecting wrong dim + if group_dim == -1: + return -1 + else: + return group_dim + 1 -class MXFp8e5m2OCPWeightPerTensorFloat(Fp8e5m2OCPWeightPerTensorFloat, MXWeightMixIn): +class MXFloatWeight(MXFloatWeightMixin, FpOCPWeight, ScaledFloatWeightBase): + """ + MX Float signed weight quantizer. + """ pass -class MXFp6e3m2OCPWeightPerTensorFloat(Fp6e3m2OCPMixin, ScaledFloatWeightBase, MXWeightMixIn): +class MXFloatAct(MXFloatActMixin, FpOCPAct, ScaledFloatActBase): + """ + MX Float signed activation quantizer. + """ pass -class MXFp6e2m3OCPWeightPerTensorFloat(Fp6e2m3OCPMixin, ScaledFloatWeightBase, MXWeightMixIn): +class MXFloatActMSE(MXFloatAct, MSESymmetricScale): + """ + MX Float signed activation quantizer with per-tensor static percentile-based scaling. + """ pass -class MXFp4e2m1OCPWeightPerTensorFloat(Fp4e2m1OCPMixin, ScaledFloatWeightBase, MXWeightMixIn): +class MXFloatWeightFloatMSE(MXFloatWeight, MSESymmetricScale): + """ + MX Float signed weight quantizer with per-channel MSE-based scaling. + """ pass diff --git a/src/brevitas/quant/solver/common.py b/src/brevitas/quant/solver/common.py index 67c18f68d..0c1cb7554 100644 --- a/src/brevitas/quant/solver/common.py +++ b/src/brevitas/quant/solver/common.py @@ -171,8 +171,11 @@ def int_scaling_impl(restrict_scaling_type): class SolveStatsReduceDimFromEnum(ExtendedInjector): @value - def stats_reduce_dim(scaling_stats_op, scaling_per_output_channel, group_dim=None): - + def stats_reduce_dim(scaling_stats_op, scaling_per_output_channel=None, group_dim=None): + if group_dim is None: + assert scaling_per_output_channel is not None, 'scaling_per_output_channel parameter required' + if scaling_per_output_channel is None: + assert group_dim is not None, 'group_dim required' if group_dim is not None: return SCALING_STATS_REDUCE_DIM + 1 elif scaling_stats_op == StatsOp.MAX_AVE or scaling_per_output_channel: @@ -192,7 +195,13 @@ class SolveScalingStatsInputViewShapeImplFromEnum(ExtendedInjector): @value def scaling_stats_input_view_shape_impl( - scaling_per_output_channel, scaling_stats_op, group_dim=None): + scaling_stats_op, scaling_per_output_channel=None, group_dim=None): + + if group_dim is None: + assert scaling_per_output_channel is not None, 'scaling_per_output_channel parameter required' + if scaling_per_output_channel is None: + assert group_dim is not None, 'group_dim required' + if group_dim is not None: return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK elif scaling_per_output_channel or scaling_stats_op == StatsOp.MAX_AVE: diff --git a/src/brevitas/quant/solver/parameter.py b/src/brevitas/quant/solver/parameter.py index 29f321809..ad748232d 100644 --- a/src/brevitas/quant/solver/parameter.py +++ b/src/brevitas/quant/solver/parameter.py @@ -108,8 +108,9 @@ def scaling_impl(scaling_impl_type): class SolveParameterScalingShape(ExtendedInjector): @value - def scaling_shape(module, scaling_per_output_channel, group_size=None): + def scaling_shape(module, scaling_per_output_channel=None, group_size=None): if group_size is None: + assert scaling_per_output_channel is not None # this pattern of returning this.something allows to resolve scaling_output_channel_shape # only when scaling_per_output_channel is True if scaling_per_output_channel: diff --git a/src/brevitas/quant/solver/weight.py b/src/brevitas/quant/solver/weight.py index 57f7dd8b4..516c7526a 100644 --- a/src/brevitas/quant/solver/weight.py +++ b/src/brevitas/quant/solver/weight.py @@ -62,11 +62,12 @@ class SolveWeightScalingStatsInputDimsFromModule(ExtendedInjector): # such that output channels are dim 0 and the remaining features are dim 1, # along which we concatenate @value - def scaling_stats_input_concat_dim(scaling_per_output_channel): - if scaling_per_output_channel: - return 1 - else: - return 0 + def scaling_stats_input_concat_dim(scaling_per_output_channel=None): + if scaling_per_output_channel is not None: + if scaling_per_output_channel: + return 1 + else: + return 0 @value def permute_dims(module, output_channel_dim): diff --git a/src/brevitas_examples/common/generative/quant_blocks.py b/src/brevitas_examples/common/generative/quant_blocks.py index 881697901..18149578d 100644 --- a/src/brevitas_examples/common/generative/quant_blocks.py +++ b/src/brevitas_examples/common/generative/quant_blocks.py @@ -62,30 +62,3 @@ def forward(self, x, scale, bit_width) -> Tensor: x = abs_binary_sign_grad(x) x = self.scale_shift_zero_point(x, scale, bit_width) return x - - -class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule): - - def __init__(self, group_size: int, group_dim: int, scaling_stats_impl: nn.Module) -> None: - super(RuntimeDynamicGroupStatsScaling, self).__init__() - self.group_size = group_size - self.group_dim = group_dim - self.scaling_stats_impl = scaling_stats_impl - - @brevitas.jit.script_method - def group_scaling_reshape(self, stats_input): - tensor_shape = stats_input.shape - tensor_shape_list = list(tensor_shape) - tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size) - block_dim = self.group_dim + 1 if self.group_dim != -1 else -1 - tensor_shape_list.insert(block_dim, self.group_size) - stats_input = stats_input.view(tensor_shape_list) - return stats_input - - @brevitas.jit.script_method - def forward(self, stats_input) -> Tensor: - stats_input_reshaped = self.group_scaling_reshape(stats_input) - out = self.scaling_stats_impl(stats_input_reshaped) - # Scaling min val - out = torch.clamp_min(out, min=torch.tensor(1e-6, device=out.device, dtype=out.dtype)) - return out diff --git a/src/brevitas_examples/common/generative/quantizers.py b/src/brevitas_examples/common/generative/quantizers.py index 8a8d0e02f..60efa1136 100644 --- a/src/brevitas_examples/common/generative/quantizers.py +++ b/src/brevitas_examples/common/generative/quantizers.py @@ -3,20 +3,20 @@ # SPDX-License-Identifier: BSD-3-Clause """ -from brevitas.proxy.groupwise_float_runtime_quant import GroupwiseActFloatQuantProxyFromInjector from torch import nn from brevitas.core.function_wrapper.shape import OverOutputFeaturesView from brevitas.core.function_wrapper.shape import OverTensorView +from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling from brevitas.core.stats import AbsMinMax from brevitas.core.stats import NegativeMinOrZero from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE -from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint from brevitas.inject import ExtendedInjector from brevitas.inject import this from brevitas.inject import value from brevitas.proxy.groupwise_float_parameter_quant import \ GroupwiseWeightFloatQuantProxyFromInjector +from brevitas.proxy.groupwise_float_runtime_quant import GroupwiseActFloatQuantProxyFromInjector from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat @@ -89,9 +89,7 @@ class Int8DynamicActPerGroupFloat(DynamicActProxyMixin, Int8ActPerTensorFloat): """ proxy_class = GroupwiseActFloatQuantProxyFromInjector scaling_impl = RuntimeDynamicGroupStatsScaling - keepdim = True scaling_stats_op = 'min_max' - scaling_per_output_channel = True @value def stats_reduce_dim(group_dim):