Skip to content

Commit

Permalink
Rename
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Aug 1, 2024
1 parent 09373b4 commit b876ff1
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 72 deletions.
2 changes: 1 addition & 1 deletion src/brevitas/export/onnx/standard/qcdq/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def validate(self, module):
# Below 8b quantization is not supported.
self.validate_8b_bit_width(module.bit_width(), le_then=False)
# Only per tensor quantization is supported
assert module.quant_injector.scaling_per_output_type == ScalingPerOutputType.TENSOR, "Only per tensor scaling supported"
assert module.quant_injector.scaling_per_output == ScalingPerOutputType.TENSOR, "Only per tensor scaling supported"

def quantize_fn(self, x, dtype):
return DynamicQuantizeLinearFn.apply(x, dtype)
Expand Down
30 changes: 15 additions & 15 deletions src/brevitas/quant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,31 +229,31 @@ class ShiftedParamFromPercentileUintQuant(ExtendedInjector):
class PerChannelFloatScaling8bit(ExtendedInjector):
"""
"""
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL
restrict_scaling_type = RestrictValueType.FP
bit_width = 8


class PerTensorFloatScaling8bit(ExtendedInjector):
"""
"""
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR
restrict_scaling_type = RestrictValueType.FP
bit_width = 8


class PerChannelPoTScaling8bit(ExtendedInjector):
"""
"""
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL
restrict_scaling_type = RestrictValueType.FP
bit_width = 8


class PerTensorPoTScaling8bit(ExtendedInjector):
"""
"""
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR
restrict_scaling_type = RestrictValueType.POWER_OF_TWO
bit_width = 8
restrict_value_float_to_int_impl = CeilSte
Expand All @@ -264,15 +264,15 @@ class SignedBinaryClampedConst(ExtendedInjector):
scaling_impl_type = ScalingImplType.CONST
restrict_scaling_type = RestrictValueType.FP
float_to_int_impl_type = FloatToIntImplType.ROUND
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR
narrow_range = True
signed = True


class PerTensorConstScaling2bit(ExtendedInjector):
scaling_impl_type = ScalingImplType.CONST
restrict_scaling_type = RestrictValueType.FP
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR
bit_width = 2


Expand Down Expand Up @@ -327,7 +327,7 @@ class WeightPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum,
signed = True
scaling_stats_input_view_shape_impl = OverOutputChannelView
stats_reduce_dim = SCALING_STATS_REDUCE_DIM
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL


class WeightNormPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum,
Expand Down Expand Up @@ -364,7 +364,7 @@ def scaling_init(scaling_init_impl, bit_width):
pre_scaling_impl = ParameterPreScalingWeightNorm
restrict_pre_scaling_impl = LogFloatRestrictValue
normalize_stats_impl = L2Norm
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL
pre_scaling_shape = this.scaling_shape # TODO: decouple pre_scaling_shape from scaling_shape
int_scaling_impl = SingleArgStatelessBuffer(1.)
zero_point_impl = ZeroZeroPoint
Expand Down Expand Up @@ -424,19 +424,19 @@ class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
class MSESubInjectorBase(ExtendedInjector):

@value
def inner_stats_input_view_shape_impl(scaling_per_output_type):
if scaling_per_output_type == ScalingPerOutputType.CHANNEL:
def inner_stats_input_view_shape_impl(scaling_per_output):
if scaling_per_output == ScalingPerOutputType.CHANNEL:
return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
elif scaling_per_output_type == ScalingPerOutputType.TENSOR:
elif scaling_per_output == ScalingPerOutputType.TENSOR:
return StatsInputViewShapeImpl.OVER_TENSOR
elif scaling_per_output_type == ScalingPerOutputType.GROUP:
elif scaling_per_output == ScalingPerOutputType.GROUP:
raise RuntimeError("Not implemented yet")

permute_dims = (this << 1).permute_dims


class MSESymmetricScaleSubInjector(MSESubInjectorBase):
scaling_per_output_type = (this << 1).scaling_per_output_type
scaling_per_output = (this << 1).scaling_per_output
proxy_module = (this << 1).proxy_module
mse_init_op = AbsMax
stats_impl = MSE
Expand All @@ -446,7 +446,7 @@ class MSESymmetricScaleSubInjector(MSESubInjectorBase):


class MSEAsymmetricScaleSubInjector(MSESubInjectorBase):
scaling_per_output_type = (this << 1).scaling_per_output_type
scaling_per_output = (this << 1).scaling_per_output
proxy_module = (this << 1).proxy_module
mse_init_op = AbsMinMax
stats_impl = MSE
Expand All @@ -457,7 +457,7 @@ class MSEAsymmetricScaleSubInjector(MSESubInjectorBase):

class MSEZeroPointSubInjector(MSESubInjectorBase):
# zp is per channel when scaling is per channel
scaling_per_output_type = (this << 1).scaling_per_output_type
scaling_per_output = (this << 1).scaling_per_output
proxy_module = (this << 1).proxy_module
mse_init_op = NegativeMinOrZero
mse_search_method = 'grid'
Expand Down
28 changes: 14 additions & 14 deletions src/brevitas/quant/experimental/float.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,99 +43,99 @@ class Fp8e4m3WeightPerTensorFloat(Fp8e4m3Mixin, ScaledFloatWeightBase):
"""
FP8 signed E3M4 weight quantizer with per-tensor absmax-based scaling.
"""
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR


class Fp8e5m2WeightPerTensorFloat(Fp8e5m2Mixin, ScaledFloatWeightBase):
"""
FP8 signed E5M2 weight quantizer with per-tensor absmax-based scaling.
"""
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR


class Fp8e4m3ActPerTensorFloat(Fp8e4m3Mixin, ScaledFloatActBase):
"""
FP8 signed E4M3 activation quantizer with per-tensor static percentile-based scaling.
"""
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR


class Fp8e5m2ActPerTensorFloat(Fp8e5m2Mixin, ScaledFloatActBase):
"""
FP8 signed E5M2 activation quantizer with per-tensor static percentile-based scaling.
"""
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR


class Fp8e4m3WeightPerChannelFloat(Fp8e4m3Mixin, ScaledFloatWeightBase):
"""
FP8 signed E3M4 weight quantizer with per-channel absmax-based scaling.
"""
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL


class Fp8e5m2WeightPerChannelFloat(Fp8e5m2Mixin, ScaledFloatWeightBase):
"""
FP8 signed E5M2 weight quantizer with per-channel absmax-based scaling.
"""
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL


class Fp8e4m3ActPerChannelFloat2d(Fp8e4m3Mixin, ScaledFloatActBase):
"""
FP8 signed E4M3 activation quantizer with per-channel static percentile-based scaling.
"""
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL
scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e5m2ActPerChannelFloat2d(Fp8e5m2Mixin, ScaledFloatActBase):
"""
FP8 signed E5M2 activation quantizer with per-channel static percentile-based scaling.
"""
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL
scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e4m3ActPerTensorFloatMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatActBase):
"""
FP8 signed E4M3 activation quantizer with per-tensor static MSE-based scaling.
"""
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR


class Fp8e5m2ActPerTensorFloatMSE(Fp8e5m2Mixin, MSESymmetricScale, ScaledFloatActBase):
"""
FP8 signed E5M2 activation quantizer with per-tensor static MSE-based scaling.
"""
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR


class Fp8e4m3ActPerChannelFloat2dMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatActBase):
"""
FP8 signed E4M3 activation quantizer with per-channel static MSE-based scaling.
"""
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL
scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e5m2ActPerChannelFloat2dMSE(Fp8e5m2Mixin, MSESymmetricScale, ScaledFloatActBase):
"""
FP8 signed E5M2 activation quantizer with per-channel static MSE-based scaling.
"""
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL
scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e4m3WeightPerChannelFloatMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatWeightBase):
"""
FP8 signed E3M4 weight quantizer with per-channel MSE-based scaling.
"""
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL


class Fp8e4m3WeightPerTensorFloatMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatWeightBase):
"""
FP8 signed E3M4 weight quantizer with per-tensor MSE-based scaling.
"""
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR
16 changes: 8 additions & 8 deletions src/brevitas/quant/experimental/float_quant_fnuz.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,58 +40,58 @@ class FpFNUZWeightPerTensorFloat(FpFNUZMixin, ScaledFloatWeightBase):
"""
FP8 signed weight quantizer with per-tensor absmax-based scaling.
"""
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR


class FpFNUZActPerTensorFloat(FpFNUZMixin, ScaledFloatActBase):
"""
FP8 signed activation quantizer with per-tensor static percentile-based scaling.
"""
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR


class FpFNUZWeightPerChannelFloat(FpFNUZMixin, ScaledFloatWeightBase):
"""
FP8 signed weight quantizer with per-channel absmax-based scaling.
"""
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL


class FpFNUZActPerChannelFloat2d(FpFNUZMixin, ScaledFloatActBase):
"""
FP8 signed activation quantizer with per-channel static percentile-based scaling.
"""
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL
scaling_stats_permute_dims = (1, 0, 2, 3)


class FpFNUZActPerTensorFloatMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatActBase):
"""
FP8 signed activation quantizer with per-tensor static MSE-based scaling.
"""
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR


class FpFNUZActPerChannelFloat2dMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatActBase):
"""
FP8 signed activation quantizer with per-channel static MSE-based scaling.
"""
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL
scaling_stats_permute_dims = (1, 0, 2, 3)


class FpFNUZWeightPerChannelFloatMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatWeightBase):
"""
FP8 signed weight quantizer with per-channel MSE-based scaling.
"""
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL


class FpFNUZWeightPerTensorFloatMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatWeightBase):
"""
FP8 signed weight quantizer with per-tensor MSE-based scaling.
"""
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR


## Predefined FP8 Quantizers
Expand Down
16 changes: 8 additions & 8 deletions src/brevitas/quant/experimental/float_quant_ocp.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,58 +69,58 @@ class FpOCPWeightPerTensorFloat(FpOCPMixin, ScaledFloatWeightBase):
"""
OCP FP signed E3M4 weight quantizer with per-tensor absmax-based scaling.
"""
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR


class FpOCPActPerTensorFloat(FpOCPMixin, ScaledFloatActBase):
"""
OCP FP signed activation quantizer with per-tensor static percentile-based scaling.
"""
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR


class FpOCPWeightPerChannelFloat(FpOCPMixin, ScaledFloatWeightBase):
"""
OCP FP signed E3M4 weight quantizer with per-channel absmax-based scaling.
"""
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL


class FpOCPActPerChannelFloat2d(FpOCPMixin, ScaledFloatActBase):
"""
OCP FP signed activation quantizer with per-channel static percentile-based scaling.
"""
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL
scaling_stats_permute_dims = (1, 0, 2, 3)


class FpOCPActPerTensorFloatMSE(FpOCPMixin, MSESymmetricScale, ScaledFloatActBase):
"""
OCP FP signed activation quantizer with per-tensor static MSE-based scaling.
"""
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR


class FpOCPActPerChannelFloat2dMSE(FpOCPMixin, MSESymmetricScale, ScaledFloatActBase):
"""
OCP FP signed activation quantizer with per-channel static MSE-based scaling.
"""
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL
scaling_stats_permute_dims = (1, 0, 2, 3)


class FpOCPWeightPerChannelFloatMSE(FpOCPMixin, MSESymmetricScale, ScaledFloatWeightBase):
"""
OCP FP signed E3M4 weight quantizer with per-channel MSE-based scaling.
"""
scaling_output_type = ScalingPerOutputType.CHANNEL
scaling_per_output_type = ScalingPerOutputType.CHANNEL


class FpOCPWeightPerTensorFloatMSE(FpOCPMixin, MSESymmetricScale, ScaledFloatWeightBase):
"""
OCP FP signed E3M4 weight quantizer with per-tensor MSE-based scaling.
"""
scaling_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.TENSOR


## Predefined FP8 Quantizers
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/quant/experimental/mx_quant_ocp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class MXFloatWeightMixin(ExtendedInjector):
group_size = 32
restrict_scaling_type = RestrictValueType.POWER_OF_TWO
restrict_value_float_to_int_impl = CeilSte
scaling_output_type = ScalingPerOutputType.GROUP
scaling_per_output_type = ScalingPerOutputType.GROUP


class MXFloatActMixin(ExtendedInjector):
Expand All @@ -37,7 +37,7 @@ class MXFloatActMixin(ExtendedInjector):
restrict_scaling_type = RestrictValueType.POWER_OF_TWO
restrict_value_float_to_int_impl = CeilSte
scaling_impl = RuntimeDynamicGroupStatsScaling
scaling_output_type = ScalingPerOutputType.GROUP
scaling_per_output_type = ScalingPerOutputType.GROUP

@value
def stats_reduce_dim(group_dim):
Expand Down
Loading

0 comments on commit b876ff1

Please sign in to comment.