From 76119af4a0785d9ef3393cbcaceb3050801ec1c1 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Wed, 21 Aug 2024 18:58:03 +0100
Subject: [PATCH] update interface for groupwise export

---
 src/brevitas_examples/llm/llm_quant/export.py | 66 +++++--------------
 1 file changed, 17 insertions(+), 49 deletions(-)

diff --git a/src/brevitas_examples/llm/llm_quant/export.py b/src/brevitas_examples/llm/llm_quant/export.py
index bd8bb3df2..0998fd417 100644
--- a/src/brevitas_examples/llm/llm_quant/export.py
+++ b/src/brevitas_examples/llm/llm_quant/export.py
@@ -25,6 +25,7 @@
 from brevitas.function.ops import min_int
 from brevitas.nn import QuantLinear
 from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
+from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
 
 
 # TODO: Improve Groupwise export
@@ -56,30 +57,6 @@ def __init__(self):
         self.bit_width = None
         self.dtype = None
 
-    def scaling_impl(self, proxy_module):
-        return proxy_module.tensor_quant.scaling_impl
-
-    def zero_point_impl(self, proxy_module):
-        return proxy_module.tensor_quant.zero_point_impl
-
-    def bit_width_impl(self, proxy_module):
-        return proxy_module.tensor_quant.msb_clamp_bit_width_impl
-
-    def export_scale(self, proxy_module, bit_width):
-        scaling_impl = self.scaling_impl(proxy_module)
-        int_scaling_impl = proxy_module.tensor_quant.int_scaling_impl
-        int_threshold = int_scaling_impl(bit_width)
-        if hasattr(scaling_impl, 'wrapped_scaling_impl'):
-            threshold = scaling_impl.wrapped_scaling_impl.stats_scaling_impl(
-                scaling_impl.wrapped_scaling_impl.parameter_list_stats())
-        else:
-            threshold = scaling_impl.stats_scaling_impl(scaling_impl.parameter_list_stats())
-        return threshold / int_threshold
-
-    def export_zero_point(self, proxy_module, scale, bit_width):
-        zero_point_impl = self.zero_point_impl(proxy_module)
-        return zero_point_impl.unexpanded_zero_point(scale, bit_width)
-
     @abstractmethod
     def prepare_for_export(self, module):
         pass
@@ -90,6 +67,7 @@ def forward(self, x):
 
 
 class WeightBlockQuantProxyHandler(WeightBlockQuantHandlerBase):
+    handled_layer = GroupwiseWeightQuantProxyFromInjector
 
     def __init__(self):
         super().__init__()
@@ -100,20 +78,18 @@ def __init__(self):
 
     def prepare_for_export(self, module):
         assert len(module.tracked_module_list) == 1, "Shared quantizers not supported."
-        self.bit_width = self.bit_width_impl(module)()
-        assert self.bit_width <= 8., "Only 8b or lower is supported."
         quant_layer = module.tracked_module_list[0]
         quant_weight = quant_layer.quant_weight()
+        self.bit_width = quant_weight.bit_width
+        assert self.bit_width <= 8., "Only 8b or lower is supported."
         signed = module.is_signed
         self.int_dtype = torch.int8 if signed else torch.uint8
         self.dtype = quant_weight.value.dtype
-        self.scale = self.export_scale(module, self.bit_width).detach()
-        self.expanded_scaling_shape = self.scaling_impl(module).expanded_scaling_shape
-        self.reshaped_scaling_shape = self.scaling_impl(module).reshaped_scaling_shape
+        self.scale = quant_weight.scale_
+        self.expanded_scaling_shape = quant_weight.value_.shape
+        self.reshaped_scaling_shape = quant_weight.value.shape
         if (quant_weight.zero_point != 0.).any():
-            self.zero_point = self.export_zero_point(module, self.scale, self.bit_width).detach()
-            self.expanded_zero_point_shape = self.zero_point_impl(module).expanded_zero_point_shape
-            self.reshaped_zero_point_shape = self.zero_point_impl(module).reshaped_zero_point_shape
+            self.zero_point = quant_weight.zero_point_
         else:
             self.zero_point = None
 
@@ -138,15 +114,9 @@ def forward(self, x):
         x = (x.type(self.dtype) - zero_point) * scale
 
         # Fix shape post quantization
-        scale = scale.expand(self.expanded_scaling_shape).contiguous().view(
-            self.reshaped_scaling_shape)
         # If zero_point is not defined, propagate same shape as scale
         if self.zero_point is None:
             zero_point = torch.zeros_like(scale).type(self.int_dtype)
-        else:
-            zero_point = zero_point.expand(self.expanded_zero_point_shape).contiguous().view(
-                self.reshaped_zero_point_shape).type(self.int_dtype)
-        x = x.view(self.reshaped_scaling_shape)
 
         return x, scale, zero_point, bit_width
 
@@ -215,18 +185,17 @@ def lcm(x, y):
             raise ValueError(f"Bit width {bit_width} not supported.")
 
     def prepare_for_export(self, module):
-        self.bit_width = self.bit_width_impl(module.weight_quant)()
-        assert self.bit_width <= 8., "Only 8b or lower is supported."
         quant_weight = module.quant_weight()
+        self.bit_width = quant_weight.bit_width
+        assert self.bit_width <= 8., "Only 8b or lower is supported."
         self.bias = module.bias
-        self.scale = self.export_scale(module.weight_quant, self.bit_width)
+        self.scale = quant_weight.scale_
         if (quant_weight.zero_point != 0.).any():
-            self.zero_point = self.export_zero_point(
-                module.weight_quant, self.scale, self.bit_width)
+            self.zero_point = quant_weight.zero_point_
         else:
             # if there is no zero-point, export zeroes in the shape of scale
             self.zero_point = torch.zeros_like(self.scale)
-        self.group_size = module.weight_quant.quant_injector.group_size
+        self.group_size = quant_weight.group_size
         self.bit_width = int(self.bit_width.cpu().item())
         self.int_weight = self.pack_int_weights(self.bit_width, quant_weight.int().detach())
 
@@ -346,14 +315,13 @@ def pack_int_weights(self, bit_width, int_weights, zero_point):
         return torch.tensor(packed), packed_zp
 
     def prepare_for_export(self, module):
-        self.bit_width = self.bit_width_impl(module.weight_quant)()
-        assert self.bit_width <= 8., "Only 8b or lower is supported."
         quant_weight = module.quant_weight()
+        self.bit_width = quant_weight.bit_width
+        assert self.bit_width <= 8., "Only 8b or lower is supported."
         self.bias = module.bias
-        self.scale = self.export_scale(module.weight_quant, self.bit_width)
+        self.scale = quant_weight.scale_
         if (quant_weight.zero_point != 0.).any():
-            self.zero_point = self.export_zero_point(
-                module.weight_quant, self.scale, self.bit_width)
+            self.zero_point = quant_weight.zero_point_
         else:
             # if there is no zero-point, export zeroes in the shape of scale
             self.zero_point = torch.zeros_like(self.scale)
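Note for reviewers: below is a minimal sketch of the layer-level access pattern this patch moves to, assembled only from the accessors used in the hunks above (module.quant_weight(), scale_, zero_point_, group_size, bit_width). The standalone helper and its name are illustrative assumptions, not part of the patch.

# Sketch of the new export interface exercised by prepare_for_export above.
# Assumption: `module` is a Brevitas QuantLinear whose weight quantizer is groupwise;
# the helper name `collect_groupwise_export_params` is hypothetical.
import torch

from brevitas.nn import QuantLinear


def collect_groupwise_export_params(module: QuantLinear):
    quant_weight = module.quant_weight()
    bit_width = quant_weight.bit_width
    assert bit_width <= 8., "Only 8b or lower is supported."
    # Unexpanded per-group scale, replacing the removed export_scale() proxy walk
    scale = quant_weight.scale_
    # Asymmetric quantizers expose an unexpanded zero-point; otherwise export zeros
    if (quant_weight.zero_point != 0.).any():
        zero_point = quant_weight.zero_point_
    else:
        zero_point = torch.zeros_like(scale)
    group_size = quant_weight.group_size
    int_weight = quant_weight.int().detach()
    return int_weight, scale, zero_point, group_size, int(bit_width.cpu().item())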