Commit

update interface for groupwise export
Giuseppe5 committed Aug 21, 2024
1 parent caa3b72 commit 76119af
Showing 1 changed file with 17 additions and 49 deletions.
src/brevitas_examples/llm/llm_quant/export.py (66 changes: 17 additions & 49 deletions)
@@ -25,6 +25,7 @@
from brevitas.function.ops import min_int
from brevitas.nn import QuantLinear
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector


# TODO: Improve Groupwise export
@@ -56,30 +57,6 @@ def __init__(self):
self.bit_width = None
self.dtype = None

def scaling_impl(self, proxy_module):
return proxy_module.tensor_quant.scaling_impl

def zero_point_impl(self, proxy_module):
return proxy_module.tensor_quant.zero_point_impl

def bit_width_impl(self, proxy_module):
return proxy_module.tensor_quant.msb_clamp_bit_width_impl

def export_scale(self, proxy_module, bit_width):
scaling_impl = self.scaling_impl(proxy_module)
int_scaling_impl = proxy_module.tensor_quant.int_scaling_impl
int_threshold = int_scaling_impl(bit_width)
if hasattr(scaling_impl, 'wrapped_scaling_impl'):
threshold = scaling_impl.wrapped_scaling_impl.stats_scaling_impl(
scaling_impl.wrapped_scaling_impl.parameter_list_stats())
else:
threshold = scaling_impl.stats_scaling_impl(scaling_impl.parameter_list_stats())
return threshold / int_threshold

def export_zero_point(self, proxy_module, scale, bit_width):
zero_point_impl = self.zero_point_impl(proxy_module)
return zero_point_impl.unexpanded_zero_point(scale, bit_width)

@abstractmethod
def prepare_for_export(self, module):
pass
@@ -90,6 +67,7 @@ def forward(self, x):


class WeightBlockQuantProxyHandler(WeightBlockQuantHandlerBase):
handled_layer = GroupwiseWeightQuantProxyFromInjector

def __init__(self):
super().__init__()
@@ -100,20 +78,18 @@ def __init__(self):

def prepare_for_export(self, module):
assert len(module.tracked_module_list) == 1, "Shared quantizers not supported."
self.bit_width = self.bit_width_impl(module)()
assert self.bit_width <= 8., "Only 8b or lower is supported."
quant_layer = module.tracked_module_list[0]
quant_weight = quant_layer.quant_weight()
self.bit_width = quant_weight.bit_width
assert self.bit_width <= 8., "Only 8b or lower is supported."
signed = module.is_signed
self.int_dtype = torch.int8 if signed else torch.uint8
self.dtype = quant_weight.value.dtype
self.scale = self.export_scale(module, self.bit_width).detach()
self.expanded_scaling_shape = self.scaling_impl(module).expanded_scaling_shape
self.reshaped_scaling_shape = self.scaling_impl(module).reshaped_scaling_shape
self.scale = quant_weight.scale_
self.expanded_scaling_shape = quant_weight.value_.shape
self.reshaped_scaling_shape = quant_weight.value.shape
if (quant_weight.zero_point != 0.).any():
self.zero_point = self.export_zero_point(module, self.scale, self.bit_width).detach()
self.expanded_zero_point_shape = self.zero_point_impl(module).expanded_zero_point_shape
self.reshaped_zero_point_shape = self.zero_point_impl(module).reshaped_zero_point_shape
self.zero_point = quant_weight.zero_point_
else:
self.zero_point = None

@@ -138,15 +114,9 @@ def forward(self, x):
x = (x.type(self.dtype) - zero_point) * scale

# Fix shape post quantization
scale = scale.expand(self.expanded_scaling_shape).contiguous().view(
self.reshaped_scaling_shape)
# If zero_point is not defined, propagate same shape as scale
if self.zero_point is None:
zero_point = torch.zeros_like(scale).type(self.int_dtype)
else:
zero_point = zero_point.expand(self.expanded_zero_point_shape).contiguous().view(
self.reshaped_zero_point_shape).type(self.int_dtype)
x = x.view(self.reshaped_scaling_shape)

return x, scale, zero_point, bit_width

@@ -215,18 +185,17 @@ def lcm(x, y):
raise ValueError(f"Bit width {bit_width} not supported.")

def prepare_for_export(self, module):
self.bit_width = self.bit_width_impl(module.weight_quant)()
assert self.bit_width <= 8., "Only 8b or lower is supported."
quant_weight = module.quant_weight()
self.bit_width = quant_weight.bit_width
assert self.bit_width <= 8., "Only 8b or lower is supported."
self.bias = module.bias
self.scale = self.export_scale(module.weight_quant, self.bit_width)
self.scale = quant_weight.scale_
if (quant_weight.zero_point != 0.).any():
self.zero_point = self.export_zero_point(
module.weight_quant, self.scale, self.bit_width)
self.zero_point = quant_weight.zero_point_
else:
# if there is no zero-point, export zeroes in the shape of scale
self.zero_point = torch.zeros_like(self.scale)
self.group_size = module.weight_quant.quant_injector.group_size
self.group_size = quant_weight.group_size
self.bit_width = int(self.bit_width.cpu().item())
self.int_weight = self.pack_int_weights(self.bit_width, quant_weight.int().detach())

@@ -346,14 +315,13 @@ def pack_int_weights(self, bit_width, int_weights, zero_point):
return torch.tensor(packed), packed_zp

def prepare_for_export(self, module):
self.bit_width = self.bit_width_impl(module.weight_quant)()
assert self.bit_width <= 8., "Only 8b or lower is supported."
quant_weight = module.quant_weight()
self.bit_width = quant_weight.bit_width
assert self.bit_width <= 8., "Only 8b or lower is supported."
self.bias = module.bias
self.scale = self.export_scale(module.weight_quant, self.bit_width)
self.scale = quant_weight.scale_
if (quant_weight.zero_point != 0.).any():
self.zero_point = self.export_zero_point(
module.weight_quant, self.scale, self.bit_width)
self.zero_point = quant_weight.zero_point_
else:
# if there is no zero-point, export zeroes in the shape of scale
self.zero_point = torch.zeros_like(self.scale)
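Note on the interface change (not part of the commit): the export handlers no longer rebuild scale and zero-point from the quantizer internals (scaling_impl, zero_point_impl, export_scale, export_zero_point); they read everything directly from the layer's quant_weight(). The sketch below restates that pattern outside the diff, assuming a Brevitas QuantLinear with a groupwise weight quantizer; the helper name read_groupwise_export_metadata is illustrative and uses only attributes that appear in the diff above.

import torch

def read_groupwise_export_metadata(quant_layer):
    # quant_layer is assumed to be a brevitas.nn.QuantLinear whose weight
    # quantizer is groupwise, matching the layers handled in this file.
    quant_weight = quant_layer.quant_weight()

    bit_width = quant_weight.bit_width
    assert bit_width <= 8., "Only 8b or lower is supported."

    # Groupwise scale and zero-point come straight from the quantized-weight view.
    scale = quant_weight.scale_
    if (quant_weight.zero_point != 0.).any():
        zero_point = quant_weight.zero_point_
    else:
        # If there is no zero-point, export zeroes in the shape of scale.
        zero_point = torch.zeros_like(scale)

    group_size = quant_weight.group_size
    int_weight = quant_weight.int().detach()
    return int_weight, scale, zero_point, int(bit_width.cpu().item()), group_size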
