Extended tests
Giuseppe5 committed Nov 18, 2024
1 parent 955bb55 commit ecb1aad
Showing 2 changed files with 19 additions and 15 deletions.
14 changes: 7 additions & 7 deletions src/brevitas/core/zero_point.py
@@ -61,16 +61,14 @@ def forward(self, zero_point: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:


 class _ScaleShiftQuantZeroPoint(brevitas.jit.ScriptModule):
-    __constants__ = ['quantize_zero_point']

-    def __init__(self, zp_int_quant: Module, quantize_zero_point: bool) -> None:
+    def __init__(self, zp_int_quant: Module) -> None:
         super(_ScaleShiftQuantZeroPoint, self).__init__()
         self.zp_int_quant = zp_int_quant
-        self.quantize_zero_point = quantize_zero_point

     @brevitas.jit.script_method
     def forward(self, zero_point: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
-        quant_zp, scale, *_ = self.zp_int_quant(zero_point)
+        quant_zp, *_ = self.zp_int_quant(zero_point)
         return quant_zp

@@ -85,18 +83,20 @@ def __init__(
             zero_point_stats_impl: Module,
             zero_point_shape: Tuple[int, ...],
             tracked_parameter_list: List[torch.nn.Parameter],
-            scale_shit_zero_point_impl: Optional[Module] = None) -> None:
+            scale_shift_zero_point_impl: Optional[Module] = None) -> None:
         super(StatsFromParameterZeroPoint, self).__init__()
         self.parameter_list_stats = _ParameterListStats(
             zero_point_stats_impl,
             zero_point_shape,
             zero_point_stats_input_view_shape_impl,
             zero_point_stats_input_concat_dim,
             tracked_parameter_list)
-        if scale_shit_zero_point_impl is None:
+        # This is for backward compatibility. Having int_quant/quantize_zero_point required for this
+        # interface but not for the else seems a bit off and might require some clean-up.
+        if scale_shift_zero_point_impl is None:
             self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)
         else:
-            self.scale_shift_zero_point = scale_shit_zero_point_impl
+            self.scale_shift_zero_point = scale_shift_zero_point_impl

     @brevitas.jit.script_method
     def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor:
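
Note: the hunks above fix the misspelled keyword (scale_shit_zero_point_impl -> scale_shift_zero_point_impl) and remove the quantize_zero_point flag from _ScaleShiftQuantZeroPoint, while the constructor keeps the legacy _ScaleShiftZeroPoint path when no implementation is injected. A minimal standalone sketch of that selection pattern, using simplified stand-in modules rather than the real brevitas classes:

from typing import Optional

import torch
from torch import Tensor, nn


class DefaultScaleShiftZeroPoint(nn.Module):
    # Stand-in for the legacy _ScaleShiftZeroPoint behaviour (assumed simplification).
    def forward(self, zero_point: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
        return torch.round(zero_point / scale)


class ZeroPointFromStats(nn.Module):
    # Mirrors only the impl-selection logic of StatsFromParameterZeroPoint.__init__ above.
    def __init__(self, scale_shift_zero_point_impl: Optional[nn.Module] = None) -> None:
        super().__init__()
        if scale_shift_zero_point_impl is None:
            # Backward-compatible default path.
            self.scale_shift_zero_point = DefaultScaleShiftZeroPoint()
        else:
            # Injected implementation, e.g. a quantizing variant such as
            # _ScaleShiftQuantZeroPoint in the hunk above.
            self.scale_shift_zero_point = scale_shift_zero_point_impl

Since the injector resolves arguments by name, the rename also lets the test below supply scale_shift_zero_point_impl = _ScaleShiftQuantZeroPoint and have it picked up instead of the default.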
20 changes: 12 additions & 8 deletions tests/brevitas/core/test_scaling_quant.py
@@ -11,9 +11,12 @@
 from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat

+ZP_BIT_WIDTH = 6
+SCALE_BIT_WIDTH = 5
+

 class QuantScalingInt(Int8WeightPerTensorFloat):
-    bit_width = 8
+    bit_width = SCALE_BIT_WIDTH
     module = (this << 1).module
     tracked_parameter_list = (this << 1).tracked_parameter_list
     upstream_scaling = (this << 1).scaling_per_output_type
@@ -50,12 +53,11 @@ def scaling_shape(


 class QuantZPInt(Int8WeightPerTensorFloat):
-    bit_width = 8
     module = (this << 1).module
     tracked_parameter_list = (this << 1).tracked_parameter_list
     upstream_scaling = (this << 1).scaling_per_output_type
     rescaling_int_quant = RescalingIntQuant
-    bit_width = 6
+    bit_width = ZP_BIT_WIDTH
     quantize_zero_point = True
     scaling_per_output_type = ScalingPerOutputType.CHANNEL
@@ -86,13 +88,13 @@ def scaling_shape(
         return scaling


-class QuantScaleInt8WeightPerTensorFloat(ShiftedUint8WeightPerTensorFloat):
+class QuantScaleQuantZPInt8WeightPerTensorFloat(ShiftedUint8WeightPerTensorFloat):
     proxy_class = GroupwiseWeightQuantProxyFromInjector
     scaling_int_quant = QuantScalingInt
     zp_int = QuantZPInt
     restrict_scaling_impl = QuantRestrictValue
     scaling_per_output_type = ScalingPerOutputType.GROUP
-    scale_shit_zero_point_impl = _ScaleShiftQuantZeroPoint
+    scale_shift_zero_point_impl = _ScaleShiftQuantZeroPoint
     group_size = 32

     @value
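
The composite quantizer above quantizes the weight scale via QuantScalingInt (SCALE_BIT_WIDTH = 5) and the zero point via QuantZPInt (ZP_BIT_WIDTH = 6). As a purely illustrative toy of what quantizing a scale means (not brevitas code; values are made up):

import torch

SCALE_BIT_WIDTH = 5
weight_scale = torch.tensor(0.0123)                      # hypothetical float weight scale
# Second-level scale mapping the weight scale onto a 5-bit signed integer grid.
meta_scale = weight_scale.abs() / (2 ** (SCALE_BIT_WIDTH - 1) - 1)
quant_weight_scale = torch.round(weight_scale / meta_scale) * meta_scale
# The quantized scale is an integer multiple of meta_scale, which is the property
# the hooks in the next hunk assert.
assert torch.allclose(quant_weight_scale / meta_scale,
                      torch.round(quant_weight_scale / meta_scale))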
@@ -108,15 +110,17 @@ def test_quant_scale():

     def hook_scale(module, inp):
         inp = inp[0]
-        quant_scale, scale, *_ = module.float_to_int_impl(inp)
+        quant_scale, scale, zp, bit_width = module.float_to_int_impl(inp)
+        assert bit_width == SCALE_BIT_WIDTH
         assert torch.allclose(quant_scale / scale, torch.round(quant_scale / scale))

     def hook_zp(module, inp):
         inp = inp[0]
-        quant_scale, scale, *_ = module.zp_int_quant(inp)
+        quant_scale, scale, zp, bit_width = module.zp_int_quant(inp)
+        assert bit_width == ZP_BIT_WIDTH
         assert torch.allclose(quant_scale / scale, torch.round(quant_scale / scale))

-    linear = qnn.QuantLinear(64, 768, weight_quant=QuantScaleInt8WeightPerTensorFloat)
+    linear = qnn.QuantLinear(64, 768, weight_quant=QuantScaleQuantZPInt8WeightPerTensorFloat)
     for module in linear.modules():
         if isinstance(module, QuantRestrictValue):
             module.register_forward_pre_hook(hook_scale)
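
The test intercepts the inputs of the scale and zero-point sub-modules with forward pre-hooks and asserts their bit widths. A minimal sketch of that interception pattern on a plain nn.Module (hypothetical module and check, unrelated to brevitas):

import torch
from torch import nn

probe = nn.Linear(8, 4)

def check_input(module, inp):
    # Forward pre-hooks receive the positional inputs as a tuple.
    x = inp[0]
    assert x.shape[-1] == module.in_features, "unexpected input width"

probe.register_forward_pre_hook(check_input)
probe(torch.randn(2, 8))   # the hook's assertion runs before the forward pass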
