diff --git a/src/brevitas/export/onnx/standard/function.py b/src/brevitas/export/onnx/standard/function.py
index aed2782bc..03916c299 100644
--- a/src/brevitas/export/onnx/standard/function.py
+++ b/src/brevitas/export/onnx/standard/function.py
@@ -4,9 +4,43 @@
 import onnx
 import torch
 from torch.autograd import Function
+from torch.onnx.symbolic_helper import _get_tensor_sizes
 
 from brevitas.export.onnx import onnx_export_opset
 
+
+class MatMulNBitsFn(Function):
+
+    @staticmethod
+    def symbolic(g, x, int_weights, scales, zero_points, K, N, bits, block_size):
+        ret = g.op(
+            'com.microsoft::MatMulNBits',
+            x,
+            int_weights,
+            scales,
+            zero_points,
+            K_i=K,
+            N_i=N,
+            bits_i=bits,
+            block_size_i=block_size)
+        output_size = _get_tensor_sizes(x)
+        output_size[-1] = N
+        ret.setType(x.type().with_sizes(output_size))
+        return ret
+
+    @staticmethod
+    def forward(g, x, int_weights, scales, zero_points, K, N, bits, block_size):
+        dtype = x.dtype
+        device = x.device
+        shape = x.shape
+        out_shape = list(shape)
+        out_shape[-1] = N
+        # Only tensor metadata (shape, dtype, device) are preserved in the forward pass during
+        # tracing, not the correct value
+        out = torch.empty(out_shape, dtype=dtype, device=device)
+        return out
+
+
 AXIS_OPSET = 13
 
 DATATYPE_DICT = {
diff --git a/src/brevitas_examples/llm/llm_quant/export.py b/src/brevitas_examples/llm/llm_quant/export.py
index 5716c6f50..fe060b8c5 100644
--- a/src/brevitas_examples/llm/llm_quant/export.py
+++ b/src/brevitas_examples/llm/llm_quant/export.py
@@ -10,6 +10,8 @@
 import numpy as np
 import torch
+from torch.nn import Module
+from torch.onnx import register_custom_op_symbolic
 
 from brevitas.export.common.handler.base import BaseHandler
 from brevitas.export.manager import _set_layer_export_handler
@@ -19,6 +21,8 @@
 from brevitas.export.manager import BaseManager
 from brevitas.function.ops import max_int
 from brevitas.function.ops import min_int
+from brevitas.export.onnx.handler import ONNXBaseHandler
+from brevitas.export.onnx.standard.function import MatMulNBitsFn
 from brevitas.nn import QuantLinear
 from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
@@ -65,8 +69,11 @@ def export_scale(self, proxy_module, bit_width):
         scaling_impl = self.scaling_impl(proxy_module)
         int_scaling_impl = proxy_module.tensor_quant.int_scaling_impl
         int_threshold = int_scaling_impl(bit_width)
-        threshold = scaling_impl.wrapped_scaling_impl.stats_scaling_impl(
-            scaling_impl.wrapped_scaling_impl.parameter_list_stats())
+        if hasattr(scaling_impl, 'wrapped_scaling_impl'):
+            threshold = scaling_impl.wrapped_scaling_impl.stats_scaling_impl(
+                scaling_impl.wrapped_scaling_impl.parameter_list_stats())
+        else:
+            threshold = scaling_impl.stats_scaling_impl(scaling_impl.parameter_list_stats())
         return threshold / int_threshold
 
     def export_zero_point(self, proxy_module, scale, bit_width):
@@ -219,7 +226,7 @@ def prepare_for_export(self, module):
         else:
             # if there is no zero-point, export zeroes in the shape of scale
             self.zero_point = torch.zeros_like(self.scale)
-        self.group_size = module.weight_quant.quant_injector.block_size
+        self.group_size = module.weight_quant.quant_injector.group_size
         self.bit_width = int(self.bit_width.cpu().item())
         self.int_weight = self.pack_int_weights(self.bit_width, quant_weight.int().detach())
@@ -237,10 +244,12 @@ def set_export_handler(cls, module):
         _set_proxy_export_handler(cls, module)
 
 
-def block_quant_layer_level_manager(export_handlers):
+def block_quant_layer_level_manager(export_handlers, target=None, custom_fns_to_register=None):
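+    # 'target' and 'custom_fns_to_register' are exposed as class attributes below; they are
+    # presumably consumed by the export machinery to register the custom symbolic functions
+    # (e.g. MatMulNBitsFn) via register_custom_op_symbolic before torch.onnx.export runs.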
 
     class BlockQuantLayerLevelManager(BaseManager):
         handlers = export_handlers
+        target_name = '' if target is None else target
+        custom_fns = [] if custom_fns_to_register is None else custom_fns_to_register
 
         @classmethod
         def set_export_handler(cls, module):
@@ -281,3 +290,93 @@ def replace_call_fn_target(graph_model, src, target):
             node.target = target
     graph_model.graph.lint()
     graph_model.recompile()
+
+
+class ONNXLinearWeightBlockQuantHandlerFwd(ONNXBaseHandler, WeightBlockQuantHandlerBase):
+    handled_layer = QuantLinear
+
+    def __init__(self):
+        super(ONNXLinearWeightBlockQuantHandlerFwd, self).__init__()
+        self.group_size = None
+
+    def pack_int_weights(self, bit_width, int_weights, zero_point):
+        assert int_weights.dtype in [torch.uint8, torch.int8], "Packing requires (u)int8 input."
+        assert bit_width == 4, "Only 4 bit quantization export is supported at the moment"
+
+        is_symmetric = torch.sum(zero_point) == 0
+        zero_point = zero_point.to(torch.uint8)
+        rows, cols = int_weights.shape
+        group_size = self.group_size
+        blob_size = group_size // 2
+        k_blocks = (rows + group_size - 1) // group_size
+        padded_rows = k_blocks * group_size
+        pad_len = padded_rows - rows
+
+        # The ONNX operator assumes an implicit zero point of 8 (the magnitude of the most
+        # negative value in the signed 4-bit range).
+        # If we are in a "symmetric" quantized scenario, we need to add this implicit zero point;
+        # otherwise it has already been added during the conversion to integer.
+        # This allows weights to always be packed as unsigned integers.
+        zp = 8 if int_weights.dtype == torch.int8 else 0
+        int_weights += zp
+        if pad_len > 0:
+            int_weights = torch.nn.functional.pad(int_weights, (0, 0, 0, pad_len))
+        packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
+        rows, cols = int_weights.shape
+        int_weights = int_weights.t()
+        for n in range(cols):
+            for k_id in range(0, rows, group_size):
+                blk_int0 = (int_weights[n, k_id:k_id + group_size:2].numpy()).astype("uint8")
+                blk_int1 = (int_weights[n, k_id + 1:k_id + group_size:2].numpy()).astype("uint8")
+                packed[n, k_id // group_size] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4))
+
+        zero_point = zero_point.to(torch.uint8).flatten()
+
+        # The constant value 136 is derived from the source code of the ORT test suite.
+        # https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py
+        base_zp = 136 if is_symmetric else 0
+        packed_zp = base_zp * torch.ones(
+            (zero_point.shape[0] + 1) // 2, device=int_weights.device, dtype=torch.uint8)
+
+        i = 0
+        for column in range(packed_zp.shape[0]):
+            for j in range(i, i + (8 // bit_width)):
+                shift_factor = (bit_width * (j - i))
+                packed_zp[column] |= zero_point[j] << shift_factor
+            i += 8 // bit_width
+        return torch.tensor(packed), packed_zp
+
+    def prepare_for_export(self, module):
+        self.bit_width = self.bit_width_impl(module.weight_quant)()
+        assert self.bit_width <= 8., "Only 8b or lower is supported."
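+        # Collect scale, zero-point and group size from the quantized module, then pack the
+        # transposed integer weights (and zero points) into the MatMulNBits blob layout
+        # produced by pack_int_weights above.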
+        quant_weight = module.quant_weight()
+        self.bias = module.bias
+        self.scale = self.export_scale(module.weight_quant, self.bit_width)
+        if (quant_weight.zero_point != 0.).any():
+            self.zero_point = self.export_zero_point(
+                module.weight_quant, self.scale, self.bit_width)
+        else:
+            # if there is no zero-point, export zeroes in the shape of scale
+            self.zero_point = torch.zeros_like(self.scale)
+        self.group_size = module.weight_quant.quant_injector.group_size
+        self.bit_width = int(self.bit_width.cpu().item())
+        self.int_weight, self.zero_point = self.pack_int_weights(
+            self.bit_width, quant_weight.int().t().detach(), self.zero_point)
+        self.weight_shape = module.weight.shape
+
+    def symbolic_execution(self, x):
+        int_weights = self.int_weight
+        scale = self.scale
+        bit_width = self.bit_width
+        N, K = self.weight_shape
+        out = MatMulNBitsFn.apply(
+            x, int_weights, scale.flatten(), self.zero_point, K, N, bit_width, self.group_size)
+        return out
+
+
+def export_packed_onnx(model, input, export_path):
+    export_class = block_quant_layer_level_manager(
+        export_handlers=[ONNXLinearWeightBlockQuantHandlerFwd],
+        target='',
+        custom_fns_to_register=MatMulNBitsFn)
+
+    with torch.inference_mode(), brevitas_layer_export_mode(model, export_class):
+        torch.onnx.export(model, input, export_path)
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 5237c31c7..7a9e6ba4e 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -170,6 +170,7 @@
     choices=[
         None,
         'onnx_qcdq',
+        'packed_onnx',
         'torch_qcdq',
         'sharded_torchmlir_group_weight',
         'sharded_packed_torchmlir_group_weight'],
@@ -190,6 +191,8 @@ def model_export(model, ref_input, args):
         from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import \
             sharded_weight_group_export
         sharded_weight_group_export(model, no_custom_packed_export=False)
+    elif args.export_target == 'packed_onnx':
+        export_packed_onnx(model, ref_input, export_path=f"{args.model.replace('/', '-')}.onnx")
     elif args.export_target == 'onnx_qcdq':
         if args.weight_quant_granularity == 'per_group':
             export_manager = BlockQuantProxyLevelManager
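
For reference, a minimal usage sketch of the new packed export path (not part of the diff; quant_model and the input shape are placeholders for a model already quantized with per-group weights, e.g. via the flow in main.py):

    import torch

    from brevitas_examples.llm.llm_quant.export import export_packed_onnx

    # quant_model: placeholder for an LLM whose QuantLinear layers use per-group weight quantization
    ref_input = torch.randint(0, 32000, (1, 128))  # hypothetical tokenized prompt
    export_packed_onnx(quant_model, ref_input, export_path="model-packed.onnx")

The resulting graph contains com.microsoft::MatMulNBits nodes, so it is expected to run only on ONNX Runtime (e.g. onnxruntime.InferenceSession("model-packed.onnx")), which implements that contrib operator.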