Commit 2b82908: Update
Giuseppe5 committed Jun 19, 2024 · 1 parent 526617b

Showing 11 changed files with 197 additions and 255 deletions.
src/brevitas/core/scaling/runtime.py: 35 additions & 0 deletions

@@ -158,3 +158,38 @@ def _load_from_state_dict(
             missing_keys.remove(affine_weight_key)
         if config.IGNORE_MISSING_KEYS and affine_bias_key in missing_keys:
             missing_keys.remove(affine_bias_key)
+
+
+class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule):
+
+    def __init__(
+            self,
+            group_size: int,
+            group_dim: int,
+            scaling_stats_impl: torch.nn.Module,
+            scaling_min_val: Optional[float],
+            restrict_scaling_impl: Optional[torch.nn.Module]) -> None:
+        super(RuntimeDynamicGroupStatsScaling, self).__init__()
+        self.group_size = group_size
+        self.group_dim = group_dim
+        self.scaling_stats_impl = scaling_stats_impl
+        self.scaling_min_val = scaling_min_val
+        self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
+
+    @brevitas.jit.script_method
+    def group_scaling_reshape(self, stats_input):
+        tensor_shape = stats_input.shape
+        tensor_shape_list = list(tensor_shape)
+        tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size)
+        block_dim = self.group_dim + 1 if self.group_dim != -1 else -1
+        tensor_shape_list.insert(block_dim, self.group_size)
+        stats_input = stats_input.view(tensor_shape_list)
+        return stats_input
+
+    @brevitas.jit.script_method
+    def forward(self, stats_input) -> torch.Tensor:
+        stats_input_reshaped = self.group_scaling_reshape(stats_input)
+        out = self.scaling_stats_impl(stats_input_reshaped)
+        # Restrict and clamp the statistic to scaling_min_val before use as a scale
+        out = self.restrict_clamp_scaling(out)
+        return out
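For intuition, a minimal standalone sketch (not part of the commit) of the reshape that group_scaling_reshape performs, assuming group_dim=1, a dimension divisible by group_size, and an absmax statistic standing in for scaling_stats_impl:

import torch

x = torch.randn(2, 8)
group_size, group_dim = 4, 1

# Split the grouped dimension into (num_groups, group_size) blocks,
# mirroring group_scaling_reshape above.
shape = list(x.shape)
shape[group_dim] = shape[group_dim] // group_size  # 8 columns -> 2 groups
shape.insert(group_dim + 1, group_size)            # [2, 8] -> [2, 2, 4]
x_grouped = x.view(shape)

# The stats impl then reduces over the inserted block dimension,
# yielding one scale per group of 4 elements.
per_group_absmax = x_grouped.abs().amax(dim=group_dim + 1, keepdim=True)
print(per_group_absmax.shape)  # torch.Size([2, 2, 1])

Because the statistic is recomputed on every forward pass, no scale parameter is learned or stored; this is what makes the scaling runtime-dynamic.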
src/brevitas/proxy/groupwise_float_runtime_quant.py: 15 additions & 1 deletion

@@ -37,7 +37,21 @@ def forward(
         # otherwise return a simple Tensor
         # We exclude the last two values (inf_values and nan_values)
         if isinstance(y, tuple) and not any(map(lambda f: f is None, y[:-2])):
-            out = GroupwiseFloatQuantTensor(*y, signed=self.is_signed, training=self.training)
+            value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = y
+            out = GroupwiseFloatQuantTensor(
+                value,
+                scale,
+                zero_point,
+                self.group_size,
+                self.group_dim,
+                exponent_bit_width,
+                mantissa_bit_width,
+                exponent_bias,
+                saturating,
+                inf_values,
+                nan_values,
+                signed=self.is_signed,
+                training=self.training)
         elif self.is_passthrough_act:  # preserve scale/zp/bit/sign even without output quant
             if isinstance(y, tuple):
                 y = y[0]
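The one-line splat is replaced because GroupwiseFloatQuantTensor takes group_size and group_dim positionally between zero_point and exponent_bit_width, so the nine-element tuple returned by the tensor quantizer can no longer be passed as *y. A sketch of the assumed tuple layout (illustrative values only, not from the commit):

import torch

# Assumed ordering of the quantizer output unpacked above; the two
# trailing entries may be None, which is why the guard checks y[:-2].
y = (
    torch.randn(4, 8),  # value
    torch.tensor(0.1),  # scale
    torch.tensor(0.),   # zero_point
    4,                  # exponent_bit_width
    3,                  # mantissa_bit_width
    7,                  # exponent_bias
    True,               # saturating
    None,               # inf_values
    None)               # nan_values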
src/brevitas/quant/experimental/float_base.py: 21 additions & 0 deletions

@@ -62,3 +62,24 @@ class Fp8e5m2Mixin(ExtendedInjector):
     exponent_bit_width = 5
     mantissa_bit_width = 2
     saturating = True
+
+
+class Fp6e3m2Mixin(ExtendedInjector):
+    bit_width = 6
+    exponent_bit_width = 3
+    mantissa_bit_width = 2
+    saturating = True
+
+
+class Fp6e2m3Mixin(ExtendedInjector):
+    bit_width = 6
+    exponent_bit_width = 2
+    mantissa_bit_width = 3
+    saturating = True
+
+
+class Fp4e2m1Mixin(ExtendedInjector):
+    bit_width = 4
+    exponent_bit_width = 2
+    mantissa_bit_width = 1
+    saturating = True
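For orientation, the dynamic range these bit allocations imply, assuming the IEEE-like default bias 2**(E-1) - 1 and no encodings reserved for inf/nan (a back-of-the-envelope sketch, not code from the commit; minifloat_max is a hypothetical helper):

def minifloat_max(exponent_bit_width: int, mantissa_bit_width: int) -> float:
    # Largest normal value: full mantissa at the top exponent.
    bias = 2 ** (exponent_bit_width - 1) - 1
    max_exponent = 2 ** exponent_bit_width - 1 - bias
    max_mantissa = 2.0 - 2.0 ** -mantissa_bit_width
    return max_mantissa * 2.0 ** max_exponent

print(minifloat_max(3, 2))  # Fp6e3m2 -> 28.0
print(minifloat_max(2, 3))  # Fp6e2m3 -> 7.5
print(minifloat_max(2, 1))  # Fp4e2m1 -> 6.0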
src/brevitas/quant/experimental/float_quant_fnuz.py: 19 additions & 91 deletions

@@ -3,160 +3,88 @@

 from dependencies import value
 
+from brevitas.inject import ExtendedInjector
 from brevitas.quant.base import MSESymmetricScale
 from brevitas.quant.experimental.float_base import FloatActBase
 from brevitas.quant.experimental.float_base import FloatWeightBase
-from brevitas.quant.experimental.float_base import Fp8e4m3Mixin
-from brevitas.quant.experimental.float_base import Fp8e5m2Mixin
 from brevitas.quant.experimental.float_base import ScaledFloatActBase
 from brevitas.quant.experimental.float_base import ScaledFloatWeightBase
 
 
-class Fp8e4m3FNUZMixin(Fp8e4m3Mixin):
-    nan_values = None
-    inf_values = None
+class FpFNUZMixin(ExtendedInjector):
+    saturating = True
 
     @value
     def exponent_bias(exponent_bit_width):
         return 2 ** (exponent_bit_width - 1)
 
 
-class Fp8e5m2FNUZMixin(Fp8e5m2Mixin):
-    nan_values = None
-    inf_values = None
-
-    @value
-    def exponent_bias(exponent_bit_width):
-        return 2 ** (exponent_bit_width - 1)
-
-
-class Fp8e4m3FNUZWeight(Fp8e4m3FNUZMixin, FloatWeightBase):
-    """
-    FP8 signed E3M4 weight quantizer.
-    """
-    pass
-
-
-class Fp8e5m2FNUZWeight(Fp8e5m2FNUZMixin, FloatWeightBase):
-    """
-    FP8 signed E5M2 weight quantizer.
-    """
-    pass
+class FpFNUZWeight(FpFNUZMixin, FloatWeightBase):
+    """
+    FNUZ FP8 signed weight quantizer.
+    """
+    pass
 
 
-class Fp8e4m3FNUZAct(Fp8e4m3FNUZMixin, FloatActBase):
-    """
-    FP8 signed E4M3 activation quantizer.
-    """
-    pass
-
-
-class Fp8e5m2FNUZAct(Fp8e5m2FNUZMixin, FloatActBase):
+class FpFNUZAct(FpFNUZMixin, FloatActBase):
     """
-    FP8 signed E5M2 activation quantizer.
+    FP8 signed activation quantizer.
     """
     pass
 
 
-class Fp8e4m3FNUZWeightPerTensorFloat(Fp8e4m3FNUZMixin, ScaledFloatWeightBase):
+class FpFNUZWeightPerTensorFloat(FpFNUZMixin, ScaledFloatWeightBase):
     """
     FP8 signed E3M4 weight quantizer with per-tensor absmax-based scaling.
     """
     scaling_per_output_channel = False
 
 
-class Fp8e5m2FNUZWeightPerTensorFloat(Fp8e5m2FNUZMixin, ScaledFloatWeightBase):
-    """
-    FP8 signed E5M2 weight quantizer with per-tensor absmax-based scaling.
-    """
-    scaling_per_output_channel = False
-
-
-class Fp8e4m3FNUZActPerTensorFloat(Fp8e4m3FNUZMixin, ScaledFloatActBase):
-    """
-    FP8 signed E4M3 activation quantizer with per-tensor static percentile-based scaling.
-    """
-    scaling_per_output_channel = False
-
-
-class Fp8e5m2FNUZActPerTensorFloat(Fp8e5m2FNUZMixin, ScaledFloatActBase):
+class FpFNUZActPerTensorFloat(FpFNUZMixin, ScaledFloatActBase):
     """
-    FP8 signed E5M2 activation quantizer with per-tensor static percentile-based scaling.
+    FP8 signed activation quantizer with per-tensor static percentile-based scaling.
     """
     scaling_per_output_channel = False
 
 
-class Fp8e4m3FNUZWeightPerChannelFloat(Fp8e4m3FNUZMixin, ScaledFloatWeightBase):
+class FpFNUZWeightPerChannelFloat(FpFNUZMixin, ScaledFloatWeightBase):
     """
     FP8 signed E3M4 weight quantizer with per-channel absmax-based scaling.
     """
     scaling_per_output_channel = True
 
 
-class Fp8e5m2FNUZWeightPerChannelFloat(Fp8e5m2FNUZMixin, ScaledFloatWeightBase):
-    """
-    FP8 signed E5M2 weight quantizer with per-channel absmax-based scaling.
-    """
-    scaling_per_output_channel = True
-
-
-class Fp8e4m3FNUZActPerChannelFloat2d(Fp8e4m3FNUZMixin, ScaledFloatActBase):
-    """
-    FP8 signed E4M3 activation quantizer with per-channel static percentile-based scaling.
-    """
-    scaling_per_output_channel = True
-    scaling_stats_permute_dims = (1, 0, 2, 3)
-
-
-class Fp8e5m2FNUZActPerChannelFloat2d(Fp8e5m2FNUZMixin, ScaledFloatActBase):
+class FpFNUZActPerChannelFloat2d(FpFNUZMixin, ScaledFloatActBase):
     """
-    FP8 signed E5M2 activation quantizer with per-channel static percentile-based scaling.
+    FP8 signed activation quantizer with per-channel static percentile-based scaling.
     """
     scaling_per_output_channel = True
     scaling_stats_permute_dims = (1, 0, 2, 3)
 
 
-class Fp8e4m3FNUZActPerTensorFloatMSE(Fp8e4m3FNUZMixin, MSESymmetricScale, ScaledFloatActBase):
-    """
-    FP8 signed E4M3 activation quantizer with per-tensor static MSE-based scaling.
-    """
-    scaling_per_output_channel = False
-
-
-class Fp8e5m2FNUZActPerTensorFloatMSE(Fp8e5m2FNUZMixin, MSESymmetricScale, ScaledFloatActBase):
+class FpFNUZActPerTensorFloatMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatActBase):
     """
-    FP8 signed E5M2 activation quantizer with per-tensor static MSE-based scaling.
+    FP8 signed activation quantizer with per-tensor static MSE-based scaling.
     """
     scaling_per_output_channel = False
 
 
-class Fp8e4m3FNUZActPerChannelFloat2dMSE(Fp8e4m3FNUZMixin, MSESymmetricScale, ScaledFloatActBase):
-    """
-    FP8 signed E4M3 activation quantizer with per-channel static MSE-based scaling.
-    """
-    scaling_per_output_channel = True
-    scaling_stats_permute_dims = (1, 0, 2, 3)
-
-
-class Fp8e5m2FNUZActPerChannelFloat2dMSE(Fp8e5m2FNUZMixin, MSESymmetricScale, ScaledFloatActBase):
+class FpFNUZActPerChannelFloat2dMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatActBase):
     """
-    FP8 signed E5M2 activation quantizer with per-channel static MSE-based scaling.
+    FP8 signed activation quantizer with per-channel static MSE-based scaling.
     """
     scaling_per_output_channel = True
     scaling_stats_permute_dims = (1, 0, 2, 3)
 
 
-class Fp8e4m3FNUZWeightPerChannelFloatMSE(Fp8e4m3FNUZMixin,
-                                          MSESymmetricScale,
-                                          ScaledFloatWeightBase):
+class FpFNUZWeightPerChannelFloatMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatWeightBase):
     """
     FP8 signed E3M4 weight quantizer with per-channel MSE-based scaling.
     """
     scaling_per_output_channel = True
 
 
-class Fp8e4m3FNUZWeightPerTensorFloatMSE(Fp8e4m3FNUZMixin, MSESymmetricScale,
-                                         ScaledFloatWeightBase):
+class FpFNUZWeightPerTensorFloatMSE(FpFNUZMixin, MSESymmetricScale, ScaledFloatWeightBase):
     """
     FP8 signed E3M4 weight quantizer with per-tensor MSE-based scaling.
     """