Skip to content

Commit

Permalink
last fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Aug 1, 2024
1 parent e978fa7 commit 09373b4
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/brevitas/quant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,8 @@ class WeightPerTensorFloatDecoupledL2Param(SolveWeightScalingStatsInputDimsFromM
signed = True


class WeightPerChannelFloatDecoupled(SolveWeightScalingStatsInputDimsFromModule,
class WeightPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum,
SolveWeightScalingStatsInputDimsFromModule,
SolveWeightScalingPerOutputChannelShapeFromModule,
SolveParameterScalingShape):
"""
Expand Down
125 changes: 121 additions & 4 deletions src/brevitas/quant/experimental/float_quant_fnuz.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from brevitas.quant.base import MSESymmetricScale
from brevitas.quant.experimental.float_base import FloatActBase
from brevitas.quant.experimental.float_base import FloatWeightBase
from brevitas.quant.experimental.float_base import Fp8e4m3Mixin
from brevitas.quant.experimental.float_base import Fp8e5m2Mixin
from brevitas.quant.experimental.float_base import ScaledFloatActBase
from brevitas.quant.experimental.float_base import ScaledFloatWeightBase

Expand Down Expand Up @@ -36,7 +38,7 @@ class FpFNUZAct(FpFNUZMixin, FloatActBase):

class FpFNUZWeightPerTensorFloat(FpFNUZMixin, ScaledFloatWeightBase):
    """
    FP8 signed weight quantizer with per-tensor absmax-based scaling.
    """
    scaling_output_type = ScalingPerOutputType.TENSOR

Expand All @@ -50,7 +52,7 @@ class FpFNUZActPerTensorFloat(FpFNUZMixin, ScaledFloatActBase):

class FpFNUZWeightPerChannelFloat(FpFNUZMixin, ScaledFloatWeightBase):
    """
    FP8 signed weight quantizer with per-channel absmax-based scaling.
    """
    scaling_output_type = ScalingPerOutputType.CHANNEL

Expand Down Expand Up @@ -80,13 +82,128 @@ class FpFNUZActPerChannelFloat2dMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatA

class FpFNUZWeightPerChannelFloatMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatWeightBase):
    """
    FP8 signed weight quantizer with per-channel MSE-based scaling.
    """
    scaling_output_type = ScalingPerOutputType.CHANNEL


class FpFNUZWeightPerTensorFloatMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatWeightBase):
    """
    FP8 signed weight quantizer with per-tensor MSE-based scaling.
    """
    scaling_output_type = ScalingPerOutputType.TENSOR


## Predefined FP8 Quantizers


class Fp8e4m3FNUZWeight(FpFNUZWeight, Fp8e4m3Mixin):
    """
    Signed FNUZ FP8 weight quantizer in the E4M3 format.
    """


class Fp8e4m3FNUZAct(FpFNUZAct, Fp8e4m3Mixin):
    """
    Signed FNUZ FP8 activation quantizer in the E4M3 format.
    """


class Fp8e4m3FNUZWeightPerTensorFloat(FpFNUZWeightPerTensorFloat, Fp8e4m3Mixin):
    """
    Signed FNUZ FP8 E4M3 weight quantizer with per-tensor scaling.
    """


class Fp8e4m3FNUZWeightPerChannelFloat(FpFNUZWeightPerChannelFloat, Fp8e4m3Mixin):
    """
    Signed FNUZ FP8 E4M3 weight quantizer with per-channel scaling.
    """


class Fp8e4m3FNUZActPerTensorFloat(FpFNUZActPerTensorFloat, Fp8e4m3Mixin):
    """
    Signed FNUZ FP8 E4M3 activation quantizer with per-tensor scaling.
    """


class Fp8e4m3FNUZActPerTensorFloatMSE(FpFNUZActPerTensorFloatMSE, Fp8e4m3Mixin):
    """
    Signed FNUZ FP8 E4M3 activation quantizer with MSE-based per-tensor scaling.
    """


class Fp8e4m3FNUZWeightPerTensorFloatMSE(FpFNUZWeightPerTensorFloatMSE, Fp8e4m3Mixin):
    """
    Signed FNUZ FP8 E4M3 weight quantizer with MSE-based per-tensor scaling.
    """


class Fp8e4m3FNUZWeightPerChannelFloatMSE(FpFNUZWeightPerChannelFloatMSE, Fp8e4m3Mixin):
    """
    Signed FNUZ FP8 E4M3 weight quantizer with MSE-based per-channel scaling.
    """


class Fp8e5m2FNUZWeight(FpFNUZWeight, Fp8e5m2Mixin):
    """
    Signed FNUZ FP8 weight quantizer in the E5M2 format.
    """


class Fp8e5m2FNUZAct(FpFNUZAct, Fp8e5m2Mixin):
    """
    Signed FNUZ FP8 activation quantizer in the E5M2 format.
    """


class Fp8e5m2FNUZWeightPerTensorFloat(FpFNUZWeightPerTensorFloat, Fp8e5m2Mixin):
    """
    Signed FNUZ FP8 E5M2 weight quantizer with per-tensor scaling.
    """


class Fp8e5m2FNUZWeightPerChannelFloat(FpFNUZWeightPerChannelFloat, Fp8e5m2Mixin):
    """
    Signed FNUZ FP8 E5M2 weight quantizer with per-channel scaling.
    """


class Fp8e5m2FNUZActPerTensorFloat(FpFNUZActPerTensorFloat, Fp8e5m2Mixin):
    """
    Signed FNUZ FP8 E5M2 activation quantizer with per-tensor scaling.
    """


class Fp8e5m2FNUZActPerTensorFloatMSE(FpFNUZActPerTensorFloatMSE, Fp8e5m2Mixin):
    """
    Signed FNUZ FP8 E5M2 activation quantizer with MSE-based per-tensor scaling.
    """


class Fp8e5m2FNUZWeightPerTensorFloatMSE(FpFNUZWeightPerTensorFloatMSE, Fp8e5m2Mixin):
    """
    Signed FNUZ FP8 E5M2 weight quantizer with MSE-based per-tensor scaling.
    """


class Fp8e5m2FNUZWeightPerChannelFloatMSE(FpFNUZWeightPerChannelFloatMSE, Fp8e5m2Mixin):
    """
    Signed FNUZ FP8 E5M2 weight quantizer with MSE-based per-channel scaling.
    """

0 comments on commit 09373b4

Please sign in to comment.