From a7fdf2736d7138c0b8aeb6c0dfd3ae3282385d1b Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 17 Jul 2023 14:06:12 +0100 Subject: [PATCH 01/12] GPFQ support --- src/brevitas/graph/calibrate.py | 8 +- src/brevitas/graph/{gptq.py => gpxq.py} | 442 +++++++++++++----- .../imagenet_classification/ptq/ptq_common.py | 17 +- .../ptq/ptq_evaluate.py | 8 +- 4 files changed, 363 insertions(+), 112 deletions(-) rename src/brevitas/graph/{gptq.py => gpxq.py} (65%) diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index d206e016a..a2d235e8a 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from abc import ABC +from copy import deepcopy from functools import partial import sys @@ -279,11 +280,16 @@ def forward_hook_wbiol(self, module, inp, output, name): # Compute float reference self.disable_act_quantization(module, is_training=False) self.disable_param_quantization(module, is_training=False) + quant_weight = dict() + if hasattr(module, 'weight_orig_data'): + quant_weight[module] = deepcopy(module.weight.data) + module.weight.data = module.weight_orig_data out_float = module.forward(*inp) # Required to avoid infinite recursion self.collect_float_mean(module, out_float, name) self.enable_act_quantization(module, is_training=False) self.enable_param_quantization(module, is_training=False) - + for module, value in quant_weight.items(): + module.weight.data = value # Compute quant output # We need to disable output_quant while out_quant is being computed # or we are going to apply bias correction on post quant values instead of pre quant diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gpxq.py similarity index 65% rename from src/brevitas/graph/gptq.py rename to src/brevitas/graph/gpxq.py index 8f8ffb6ae..c731773fd 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gpxq.py @@ -1,6 +1,7 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from abc import ABC, abstractmethod from copy import deepcopy from dataclasses import dataclass from dataclasses import field @@ -18,6 +19,7 @@ LinAlgError = RuntimeError import unfoldNd +import numpy as np from brevitas.graph.calibrate import DisableEnableQuantization import brevitas.nn as qnn @@ -37,42 +39,22 @@ class LayerHandler: forward_count: int = 0 -class gptq_mode: - """ - Apply GPTQ algorithm https://arxiv.org/abs/2210.17323. - - Args: - model (Module): The model to quantize with GPTQ - inplace (bool): Wheter to apply GPTQ inplace or perform a deepcopy. Default: True - use_quant_activations (bool): Wheter to leave quantize activations enabled while performing - GPTQ. 
Default: False - - Example: - >>> with torch.no_grad(): - >>> with gptq_mode(model) as gptq: - >>> gptq_model = gptq.model - >>> for i in tqdm(range(gptq.num_layers)): - >>> for img, t in calib_loader: - >>> img = img.cuda() - >>> gptq_model(img) - >>> gptq.update() - """ - +class gpxq_mode(ABC): def __init__( self, model, group_of_parallel_layers: Optional[List[str]] = None, inplace: bool = True, use_quant_activations: bool = True, - num_blocks: int = 100, act_order: bool = False, return_forward_output: bool = False) -> None: + if not inplace: model = deepcopy(model) self.model = model self.use_quant_activations = use_quant_activations self.hook_dict = dict() - self.gptq_layers = dict() + self.gpxq_layers = dict() # reference for each layer to update self.current_layer = LayerHandler() # How many layer to optimize @@ -80,11 +62,9 @@ def __init__( # Quantize following magnitude of activation self.act_order = act_order # How many subblock to use during GPTQ for each layer - self.num_blocks = num_blocks self.disable_quant_inference = DisableEnableQuantization() - self.orig_forward = self.model.forward - self.model.forward = self.catch_stopfwd + self.group_of_parallel_layers = group_of_parallel_layers self.return_forward_output = return_forward_output @@ -115,25 +95,25 @@ def __enter__(self): (name, attrgetter(name)(self.model)) for name in parallel_layers] # Print warning if hooks are attached to any module, since the normal forward flow of the - # network is highly disrupted during GPTQ + # network is highly disrupted during GPxQ for _, parallel_layers in dict_of_layers.items(): for name, module in parallel_layers: if len(module._forward_hooks) > 0 or len(module._forward_pre_hooks): warnings.warn( - f'Hooks detected during setup for GPTQ. ' + f'Hooks detected during setup for GPxQ. ' f'Behaviour might deviate from what expected.') # Attach hooks for GPTQ if self._is_module_supported(module): - gptq = GPTQ( + gpxq = self.class_implementation( module, name, - num_blocks=self.num_blocks, + # num_blocks=self.num_blocks, act_order=self.act_order, parallel_layers=parallel_layers) - hook_fn = partial(gptq.update_batch, current_layer=self.current_layer) + hook_fn = partial(gpxq.update_batch, current_layer=self.current_layer) self.hook_dict[name] = module.register_forward_pre_hook(hook_fn) - self.gptq_layers[name] = gptq + self.gpxq_layers[name] = gpxq if not self.use_quant_activations: self.disable_quant_inference.disable_act_quantization( self.model, is_training=self.model.training) @@ -153,10 +133,54 @@ def __exit__(self, type, value, traceback): def update(self): for name in self.current_layer.layer_names: - self.gptq_layers[name].single_layer_update() + self.gpxq_layers[name].single_layer_update() self.hook_dict[name].remove() self.current_layer.layer_names.clear() + @abstractmethod + def catch_stopfwd(self, *args, **kwargs): + pass + +class gptq_mode(gpxq_mode): + """ + Apply GPTQ algorithm https://arxiv.org/abs/2210.17323. + + Args: + model (Module): The model to quantize with GPTQ + inplace (bool): Wheter to apply GPTQ inplace or perform a deepcopy. Default: True + use_quant_activations (bool): Wheter to leave quantize activations enabled while performing + GPTQ. 
Default: False + + Example: + >>> with torch.no_grad(): + >>> with gptq_mode(model) as gptq: + >>> gptq_model = gptq.model + >>> for i in tqdm(range(gptq.num_layers)): + >>> for img, t in calib_loader: + >>> img = img.cuda() + >>> gptq_model(img) + >>> gptq.update() + """ + + def __init__( + self, + model, + group_of_parallel_layers: Optional[List[str]] = None, + inplace: bool = True, + use_quant_activations: bool = True, + num_blocks: int = 100, + act_order: bool = False) -> None: + if not inplace: + model = deepcopy(model) + super().__init__(model, group_of_parallel_layers, inplace, use_quant_activations, act_order) + + self.orig_forward = self.model.forward + self.model.forward = self.catch_stopfwd + # How many subblock to use during GPTQ for each layer + self.num_blocks = num_blocks + self.class_implementation = GPTQ + GPTQ.num_blocks = num_blocks + def catch_stopfwd(self, *args, **kwargs): try: self.orig_forward(*args, **kwargs) @@ -172,35 +196,84 @@ def catch_stopfwd(self, *args, **kwargs): gptq_class.disable_pre_forward_hook = False return out +class gpfq_mode(gpxq_mode): + """ + Apply GPTQ algorithm https://arxiv.org/abs/2210.17323. -class GPTQ(): + Args: + model (Module): The model to quantize with GPTQ + inplace (bool): Wheter to apply GPTQ inplace or perform a deepcopy. Default: True + use_quant_activations (bool): Wheter to leave quantize activations enabled while performing + GPTQ. Default: False + + Example: + >>> with torch.no_grad(): + >>> with gptq_mode(model) as gptq: + >>> gptq_model = gptq.model + >>> for i in tqdm(range(gptq.num_layers)): + >>> for img, t in calib_loader: + >>> img = img.cuda() + >>> gptq_model(img) + >>> gptq.update() """ - Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE: - Copyright 2023 IST-DASLab + def __init__( + self, + model, + group_of_parallel_layers: Optional[List[str]] = None, + inplace: bool = True, + use_quant_activations: bool = True, + act_order: bool = False) -> None: + if not inplace: + model = deepcopy(model) + super().__init__(model, group_of_parallel_layers, inplace, use_quant_activations, act_order) - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at + self.orig_forward = self.model.forward + self.model.forward = self.catch_stopfwd + self.class_implementation = GPFQ - http://www.apache.org/licenses/LICENSE-2.0 + def catch_stopfwd(self, *args, **kwargs): + # Collect quant input + try: + self.orig_forward(*args, **kwargs) + except StopFwdException: + pass + # Before collecting float input, restore original float weights if they have been modified + quant_weight = dict() + for module in self.model.modules(): + if hasattr(module, 'weight_orig_data'): + quant_weight[module] = deepcopy(module.weight.data) + module.weight.data = module.weight_orig_data + # Disable quantization + self.disable_quant_inference.disable_param_quantization(self.model, is_training=False) + self.disable_quant_inference.disable_act_quantization(self.model, is_training=False) + # Collect float input + try: + self.orig_forward(*args, **kwargs) + except StopFwdException: + pass + # Restore correct weights + for module in self.model.modules(): + if hasattr(module, 'weight_orig_data'): + module.weight.data = quant_weight[module] + # Re-enable quantization. 
If activation quantization is disabled, + # we also disable bias quantization + self.disable_quant_inference.enable_param_quantization(self.model, is_training=False) + if self.use_quant_activations: + self.disable_quant_inference.enable_act_quantization(self.model, is_training=False) + else: + self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False) - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - def __init__(self, layer, name, num_blocks, act_order, parallel_layers=1) -> None: +class GPxQ(ABC): + + def __init__(self, layer, name, act_order, parallel_layers=1) -> None: self.layer = layer self.name = name - self.num_blocks = num_blocks self.act_order = act_order weight = layer.weight.data - dev = weight.device - + self.layer.weight_orig_data = deepcopy(weight) # By default, use groups = 1 self.groups = 1 if isinstance(self.layer, SUPPORTED_CONV_OP): @@ -213,29 +286,16 @@ def __init__(self, layer, name, num_blocks, act_order, parallel_layers=1) -> Non self.rows = weight.shape[0] # Number of columns is equal to the input channels (IC) self.columns = weight.shape[1] - - # Define how many columns to update in each mini-block - self.blocksize = math.ceil(self.columns / self.num_blocks) - - # Initialize Hessian matrix and counter. We need it in float32 to compute the inverse - self.H = torch.zeros((self.groups, self.columns, self.columns), - device=dev, - dtype=torch.float32) - self.nsamples = 0 self.parallel_layers = parallel_layers self.disable_pre_forward_hook = False # Some layers require knowledge from quant inputs to compute quant weights self.quant_input = None - def update_batch(self, module, input, current_layer): - if self.disable_pre_forward_hook: - return input - # Update reference to current layer - current_layer.layer_names.add(self.name) - + + def process_input(self, inp): # Input is a tuple, so we take first element - inp = input[0] + inp = inp[0] # If using Quant Activations, inp could be QuantTensor if isinstance(inp, QuantTensor): if self.layer.weight_quant_requires_quant_input: @@ -259,6 +319,105 @@ def update_batch(self, module, input, current_layer): batch_dim = inp.names.index('N') inp.rename_(None) inp = inp.transpose(0, batch_dim) + return inp + + @abstractmethod + def update_batch(self, module, input, current_layer): + pass + + @abstractmethod + def single_layer_update(self, percdamp=.01): + pass + + def get_quant_weights(self, i, i1, permutation_list): + # We need to recompute quant weights at runtime since our float weights are being updated + # Add offset in case of blockwise computation (e.g., GPTQ) + i = i1 + i + # For QuantLinear and for some QuantConvolutional layers, we exploit the possibility + # of quantizing only a subset of the entire matrix speeding up the computation of GPTQ + if isinstance(self.layer, qnn.QuantLinear): + index = permutation_list[0][i] + subtensor_slice_list = [None, (index, index + 1)] + q = self.layer.quant_weight( + subtensor_slice_list=subtensor_slice_list, + quant_input=self.quant_input).value.unsqueeze(0) # [1, OC, 1] + elif isinstance(self.layer, SUPPORTED_CONV_OP): + # For depthwise and ConvTranspose we fall back to quantizing the entire martix. 
+ # For all other cases, we create a mask that represent the slicing we will perform on the weight matrix + # and we quantize only the selected dimensions. + if self.groups > 1 or (self.groups == 1 and isinstance( + self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d))): + + quant_weight = self.layer.quant_weight(quant_input=self.quant_input) + quant_weight = quant_weight.value + + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + quant_weight = quant_weight.transpose(1, 0) # This performs a view + quant_weight = quant_weight.flatten(1) + quant_weight = quant_weight.view(self.groups, -1, quant_weight.shape[-1]) + + if self.act_order: + for ii, perm in enumerate(permutation_list): + quant_weight[ii, :, :] = quant_weight[ii, :, perm] + + q = quant_weight[:, :, i:i + 1] # [groups, OC/groups, 1] + else: + index = permutation_list[0][i] + shapes = self.layer.weight.shape[1:] + index_2d_to_nd = [] + residual_index = index.item() + for shape in shapes[::-1]: + index_2d_to_nd.append((residual_index % shape, residual_index % shape + 1)) + residual_index = residual_index // shape + index_2d_to_nd = index_2d_to_nd[::-1] + index_2d_to_nd.insert(0, None) + q = self.layer.quant_weight( + subtensor_slice_list=index_2d_to_nd, + quant_input=self.quant_input).value.flatten(1) # [OC, 1] + q = q.unsqueeze(0) # [1, OC, 1] + # We need to remove the last dim + q = q.squeeze(2) # [groups, OC/groups] or [1, OC] + return q + + +class GPTQ(GPxQ): + """ + Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE: + + Copyright 2023 IST-DASLab + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + num_blocks = 100 + + def __init__(self, layer, name, act_order, parallel_layers=1) -> None: + super().__init__(layer, name, act_order, parallel_layers) + + dev = self.layer.weight.device + + # Define how many columns to update in each mini-block + self.blocksize = math.ceil(self.columns / GPTQ.num_blocks) + + # Initialize Hessian matrix and counter. 
We need it in float32 to compute the inverse + self.H = torch.zeros((self.groups, self.columns, self.columns), + device=dev, + dtype=torch.float32) + self.nsamples = 0 + + def update_batch(self, module, input, current_layer): + # Update reference to current layer + current_layer.layer_names.add(self.name) + inp = self.process_input(input) batch_size = inp.shape[0] # Preprocess the input to compute the Hessian @@ -391,52 +550,117 @@ def single_layer_update(self, percdamp=.01): weight[group_index, :, perm[i2:]] -= ( error_block[group_index].matmul(h_inv[group_index, i1:i2, i2:])).to(dtype) - def get_quant_weights(self, i, i1, permutation_list): - # We need to recompute quant weights at runtime since our float weights are being updated - # Add offset in case of blockwise computation (e.g., GPTQ) - i = i1 + i - # For QuantLinear and for some QuantConvolutional layers, we exploit the possibility - # of quantizing only a subset of the entire matrix speeding up the computation of GPTQ + + + +class GPFQ(GPxQ): + """ + Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main + """ + def __init__(self, layer, name, act_order, parallel_layers=1) -> None: + super().__init__(layer, name, act_order, parallel_layers) + self.float_input = None + self.quantized_input = None + self.index_computed = False + self.p = 0.25 + + + def update_batch(self, module, input, current_layer): + + # Update reference to current layer + current_layer.layer_names.add(self.name) + is_quant_disabled = module.weight_quant.disable_quant + + inp = self.process_input(input) + batch_size = inp.shape[0] + + # Preprocess the input to compute the Hessian if isinstance(self.layer, qnn.QuantLinear): - index = permutation_list[0][i] - subtensor_slice_list = [None, (index, index + 1)] - q = self.layer.quant_weight( - subtensor_slice_list=subtensor_slice_list, - quant_input=self.quant_input).value.unsqueeze(0) # [1, OC, 1] - elif isinstance(self.layer, SUPPORTED_CONV_OP): - # For depthwise and ConvTranspose we fall back to quantizing the entire martix. - # For all other cases, we create a mask that represent the slicing we will perform on the weight matrix - # and we quantize only the selected dimensions. 
- if self.groups > 1 or (self.groups == 1 and isinstance( - self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d))): + if len(inp.shape) > 2: + inp = inp.reshape((-1, sum(inp.shape[2:]))) + # For QuantLinear layer, groups will be 1 + inp_processed = inp.unsqueeze(0) - quant_weight = self.layer.quant_weight(quant_input=self.quant_input) - quant_weight = quant_weight.value + if isinstance(self.layer, SUPPORTED_CONV_OP): + # Pick the correct unfoldNd class + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + unfold_impl = unfoldNd.UnfoldTransposeNd + else: + unfold_impl = unfoldNd.UnfoldNd - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): - quant_weight = quant_weight.transpose(1, 0) # This performs a view - quant_weight = quant_weight.flatten(1) - quant_weight = quant_weight.view(self.groups, -1, quant_weight.shape[-1]) + unfold = unfold_impl( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.kernel_size) - if self.act_order: - for ii, perm in enumerate(permutation_list): - quant_weight[ii, :, :] = quant_weight[ii, :, perm] + # Split input based on how many groups in convolution + inp_by_group = torch.chunk(inp, self.groups, 1) + inp_processed = [] + # Preprocess input by group + for i, inp in enumerate(inp_by_group): - q = quant_weight[:, :, i:i + 1] # [groups, OC/groups, 1] + inp = unfold(inp) + + batch_size, num_blocks = inp.shape[0], inp.shape[-1] + inp = torch.transpose(inp, 1, 2) # shape (B, L, C*kernel_size[0]*kernel_size[1]) + inp = inp.reshape(-1, inp.size(-1)) # shape (B*L, C*kernel_size[0]*kernel_size[1]) + + if not self.index_computed: + self.index_computed = True + self.rand_indices = np.concatenate([ + np.random.choice( + np.arange(num_blocks * i, num_blocks * (i + 1)), + size=int( + self.p * num_blocks + 1 if self.p != 1 else self.p * num_blocks)) + for i in range(batch_size)]) # need to define self.p (probability) + + indexes = self.rand_indices + if np.max(self.rand_indices) > inp.shape[0]: + indexes = self.rand_indices < inp.shape[0] + indexes = self.rand_indices[indexes] + + + inp = inp[indexes] + inp_processed.append(inp) + inp_processed = torch.stack(inp_processed) + + if is_quant_disabled: + if self.float_input is None: + self.float_input = inp_processed else: - index = permutation_list[0][i] - shapes = self.layer.weight.shape[1:] - index_2d_to_nd = [] - residual_index = index.item() - for shape in shapes[::-1]: - index_2d_to_nd.append((residual_index % shape, residual_index % shape + 1)) - residual_index = residual_index // shape - index_2d_to_nd = index_2d_to_nd[::-1] - index_2d_to_nd.insert(0, None) - q = self.layer.quant_weight( - subtensor_slice_list=index_2d_to_nd, - quant_input=self.quant_input).value.flatten(1) # [OC, 1] - q = q.unsqueeze(0) # [1, OC, 1] - # We need to remove the last dim - q = q.squeeze(2) # [groups, OC/groups] or [1, OC] - return q + self.float_input = torch.cat([self.float_input, inp_processed], dim=1) + else: + if self.quantized_input is None: + self.quantized_input = inp_processed + else: + self.quantized_input = torch.cat([self.quantized_input, inp_processed], dim=1) + raise StopFwdException + + def single_layer_update(self): + weight = self.layer.weight.data + dev = weight.device + dtype = weight.dtype + if isinstance(self.layer, SUPPORTED_CONV_OP): + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + weight = weight.transpose(1, 0) # This performs a view + weight = 
weight.flatten(1) + weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] + U = torch.zeros(weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype) + + for t in range(weight.shape[-1]): + for group_index in range(self.groups): + U[group_index] += weight[group_index, :, t].unsqueeze(1) * self.float_input[group_index, :, t].unsqueeze(0) #[OC/Groups, 1] * [1, INSHAPE[1]] + norm = torch.linalg.norm(self.quantized_input[group_index, :, t], 2) ** 2 + if norm > 0: + q_arg = U[group_index].matmul(self.quantized_input[group_index, :, t]) / norm + else: + q_arg = torch.zeros_like(U[group_index, :, 0]) + + weight[group_index, :, t] = q_arg + q = self.get_quant_weights(t, 0, [torch.tensor(range(weight.shape[-1]))]) + for group_index in range(self.groups): + U[group_index] -= q[group_index].unsqueeze(1) * self.quantized_input[group_index, :, t].unsqueeze(0) + + del self.float_input + del self.quantized_input diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 7982059f8..e95e1936f 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from copy import deepcopy +from brevitas.graph.gpxq import gptq_mode, gpfq_mode import torch import torch.backends.cudnn as cudnn @@ -13,7 +14,6 @@ from brevitas.graph.calibrate import calibration_mode from brevitas.graph.calibrate import norm_correction_mode from brevitas.graph.equalize import activation_equalization_mode -from brevitas.graph.gptq import gptq_mode from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.quantize import quantize from brevitas.graph.target.flexml import quantize_flexml @@ -360,6 +360,21 @@ def apply_gptq(calib_loader, model, act_order=False): gptq.update() +def apply_gpfq(calib_loader, model, act_order=False): + model.eval() + dtype = next(model.parameters()).dtype + device = next(model.parameters()).device + with torch.no_grad(): + with gpfq_mode(model, act_order=act_order, use_quant_activations=True) as gpfq: + gpfq_model = gpfq.model + for i in tqdm(range(gpfq.num_layers)): + for i, (images, target) in enumerate(calib_loader): + images = images.to(device) + images = images.to(dtype) + gpfq_model(images) + gpfq.update() + + def apply_learned_round_learning( model, dataloader, optimizer_class=torch.optim.Adam, iters=1000, optimizer_lr=1e-1): layers = [] diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index fdf2b966c..88819e69e 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -20,7 +20,7 @@ from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.quantize import preprocess_for_quantize from brevitas.graph.target.flexml import preprocess_for_flexml_quantize -from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization +from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization, apply_gpfq from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_bias_correction from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gptq from brevitas_examples.imagenet_classification.ptq.ptq_common import 
apply_learned_round_learning @@ -166,6 +166,7 @@ default=True, help='Narrow range for weight quantization (default: enabled)') add_bool_arg(parser, 'gptq', default=True, help='GPTQ (default: enabled)') +add_bool_arg(parser, 'gpfq', default=False, help='GPTQ (default: disabled)') add_bool_arg( parser, 'gptq-act-order', default=False, help='GPTQ Act order heuristic (default: disabled)') add_bool_arg(parser, 'learned-round', default=False, help='Learned round (default: disabled)') @@ -299,9 +300,14 @@ def main(): print("Starting activation calibration:") calibrate(calib_loader, quant_model) + if args.gpfq: + print("Performing GPFQ:") + apply_gpfq(calib_loader, quant_model) + if args.gptq: print("Performing GPTQ:") apply_gptq(calib_loader, quant_model, args.gptq_act_order) + if args.learned_round: print("Applying Learned Round:") From ae870573d8a163aaaf23eaf0c93f00c6f81c45e6 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 25 Jul 2023 12:16:14 +0100 Subject: [PATCH 02/12] Review --- src/brevitas/graph/gpxq.py | 23 ++++++++----------- .../imagenet_classification/ptq/ptq_common.py | 3 ++- .../ptq/ptq_evaluate.py | 5 ++-- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index c731773fd..42d78197c 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -1,7 +1,8 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from abc import ABC, abstractmethod +from abc import ABC +from abc import abstractmethod from copy import deepcopy from dataclasses import dataclass from dataclasses import field @@ -17,9 +18,8 @@ from torch.linalg import LinAlgError except: LinAlgError = RuntimeError - -import unfoldNd import numpy as np +import unfoldNd from brevitas.graph.calibrate import DisableEnableQuantization import brevitas.nn as qnn @@ -40,6 +40,7 @@ class LayerHandler: class gpxq_mode(ABC): + def __init__( self, model, @@ -106,11 +107,7 @@ def __enter__(self): # Attach hooks for GPTQ if self._is_module_supported(module): gpxq = self.class_implementation( - module, - name, - # num_blocks=self.num_blocks, - act_order=self.act_order, - parallel_layers=parallel_layers) + module, name, act_order=self.act_order, parallel_layers=parallel_layers) hook_fn = partial(gpxq.update_batch, current_layer=self.current_layer) self.hook_dict[name] = module.register_forward_pre_hook(hook_fn) self.gpxq_layers[name] = gpxq @@ -141,6 +138,7 @@ def update(self): def catch_stopfwd(self, *args, **kwargs): pass + class gptq_mode(gpxq_mode): """ Apply GPTQ algorithm https://arxiv.org/abs/2210.17323. @@ -196,6 +194,7 @@ def catch_stopfwd(self, *args, **kwargs): gptq_class.disable_pre_forward_hook = False return out + class gpfq_mode(gpxq_mode): """ Apply GPTQ algorithm https://arxiv.org/abs/2210.17323. @@ -256,7 +255,7 @@ def catch_stopfwd(self, *args, **kwargs): for module in self.model.modules(): if hasattr(module, 'weight_orig_data'): module.weight.data = quant_weight[module] - # Re-enable quantization. If activation quantization is disabled, + # Re-enable quantization. 
If activation quantization is disabled, # we also disable bias quantization self.disable_quant_inference.enable_param_quantization(self.model, is_training=False) if self.use_quant_activations: @@ -551,12 +550,11 @@ def single_layer_update(self, percdamp=.01): error_block[group_index].matmul(h_inv[group_index, i1:i2, i2:])).to(dtype) - - class GPFQ(GPxQ): """ Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main """ + def __init__(self, layer, name, act_order, parallel_layers=1) -> None: super().__init__(layer, name, act_order, parallel_layers) self.float_input = None @@ -564,9 +562,8 @@ def __init__(self, layer, name, act_order, parallel_layers=1) -> None: self.index_computed = False self.p = 0.25 - def update_batch(self, module, input, current_layer): - + # Update reference to current layer current_layer.layer_names.add(self.name) is_quant_disabled = module.weight_quant.disable_quant diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index e95e1936f..343624950 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: BSD-3-Clause from copy import deepcopy -from brevitas.graph.gpxq import gptq_mode, gpfq_mode import torch import torch.backends.cudnn as cudnn @@ -14,6 +13,8 @@ from brevitas.graph.calibrate import calibration_mode from brevitas.graph.calibrate import norm_correction_mode from brevitas.graph.equalize import activation_equalization_mode +from brevitas.graph.gpxq import gpfq_mode +from brevitas.graph.gpxq import gptq_mode from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.quantize import quantize from brevitas.graph.target.flexml import quantize_flexml diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 88819e69e..08dde3b21 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -20,8 +20,9 @@ from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.quantize import preprocess_for_quantize from brevitas.graph.target.flexml import preprocess_for_flexml_quantize -from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization, apply_gpfq +from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_bias_correction +from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gpfq from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gptq from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_learned_round_learning from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate @@ -166,7 +167,7 @@ default=True, help='Narrow range for weight quantization (default: enabled)') add_bool_arg(parser, 'gptq', default=True, help='GPTQ (default: enabled)') -add_bool_arg(parser, 'gpfq', default=False, help='GPTQ (default: disabled)') +add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)') add_bool_arg( parser, 'gptq-act-order', default=False, help='GPTQ Act order heuristic (default: disabled)') add_bool_arg(parser, 'learned-round', default=False, help='Learned round (default: 
disabled)') From 7dbee11dae6fa7b6b6d9d0cc20cd2750e46c3c1b Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 26 Jul 2023 15:10:05 +0100 Subject: [PATCH 03/12] Fix for depthwise act_order gptq --- src/brevitas/graph/gpxq.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index 42d78197c..3e820d71b 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -616,7 +616,6 @@ def update_batch(self, module, input, current_layer): if np.max(self.rand_indices) > inp.shape[0]: indexes = self.rand_indices < inp.shape[0] indexes = self.rand_indices[indexes] - inp = inp[indexes] inp_processed.append(inp) From 60856fd97d56d1c2a0a4fd3ac53c40665416e3c2 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 26 Jul 2023 15:11:33 +0100 Subject: [PATCH 04/12] Docstring update --- src/brevitas/graph/gpxq.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index 3e820d71b..dd995ec6d 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -197,23 +197,23 @@ def catch_stopfwd(self, *args, **kwargs): class gpfq_mode(gpxq_mode): """ - Apply GPTQ algorithm https://arxiv.org/abs/2210.17323. + Apply GPFQ algorithm. Args: - model (Module): The model to quantize with GPTQ - inplace (bool): Wheter to apply GPTQ inplace or perform a deepcopy. Default: True + model (Module): The model to quantize with GPFQ + inplace (bool): Wheter to apply GPFQ inplace or perform a deepcopy. Default: True use_quant_activations (bool): Wheter to leave quantize activations enabled while performing - GPTQ. Default: False + GPFQ. Default: False Example: >>> with torch.no_grad(): - >>> with gptq_mode(model) as gptq: - >>> gptq_model = gptq.model - >>> for i in tqdm(range(gptq.num_layers)): + >>> with gpfq_mode(model) as gpfq: + >>> gpfq_model = gpfq.model + >>> for i in tqdm(range(gpfq.num_layers)): >>> for img, t in calib_loader: >>> img = img.cuda() - >>> gptq_model(img) - >>> gptq.update() + >>> gpfq_model(img) + >>> gpfq.update() """ def __init__( From 20030f409a7467609d6d96cfee57610a1506c940 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 23 Aug 2023 14:15:22 +0100 Subject: [PATCH 05/12] Update --- src/brevitas/graph/gpxq.py | 77 +++++++++++++++---- .../imagenet_classification/ptq/README.md | 6 +- .../imagenet_classification/ptq/ptq_common.py | 4 +- .../ptq/ptq_evaluate.py | 10 ++- 4 files changed, 74 insertions(+), 23 deletions(-) diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index dd995ec6d..c076b2827 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -167,10 +167,17 @@ def __init__( inplace: bool = True, use_quant_activations: bool = True, num_blocks: int = 100, + return_forward_output: bool = False, act_order: bool = False) -> None: if not inplace: model = deepcopy(model) - super().__init__(model, group_of_parallel_layers, inplace, use_quant_activations, act_order) + super().__init__( + model, + group_of_parallel_layers, + inplace, + use_quant_activations, + act_order, + return_forward_output) self.orig_forward = self.model.forward self.model.forward = self.catch_stopfwd @@ -187,11 +194,11 @@ def catch_stopfwd(self, *args, **kwargs): finally: if self.return_forward_output: # If we want to return the output of the network, we need to disable all hooks - for name, gptq_class in self.gptq_layers.items(): - gptq_class.disable_pre_forward_hook = True + for name, gpxq_class in 
self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = True out = self.orig_forward(*args, **kwargs) - for name, gptq_class in self.gptq_layers.items(): - gptq_class.disable_pre_forward_hook = False + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = False return out @@ -222,14 +229,23 @@ def __init__( group_of_parallel_layers: Optional[List[str]] = None, inplace: bool = True, use_quant_activations: bool = True, + p: int = 0.25, + return_forward_output: bool = False, act_order: bool = False) -> None: if not inplace: model = deepcopy(model) - super().__init__(model, group_of_parallel_layers, inplace, use_quant_activations, act_order) + super().__init__( + model, + group_of_parallel_layers, + inplace, + use_quant_activations, + act_order, + return_forward_output) self.orig_forward = self.model.forward self.model.forward = self.catch_stopfwd self.class_implementation = GPFQ + GPFQ.p = p def catch_stopfwd(self, *args, **kwargs): # Collect quant input @@ -263,6 +279,15 @@ def catch_stopfwd(self, *args, **kwargs): else: self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False) + if self.return_forward_output: + # If we want to return the output of the network, we need to disable all hooks + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = True + out = self.orig_forward(*args, **kwargs) + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = False + return out + class GPxQ(ABC): @@ -291,7 +316,6 @@ def __init__(self, layer, name, act_order, parallel_layers=1) -> None: # Some layers require knowledge from quant inputs to compute quant weights self.quant_input = None - def process_input(self, inp): # Input is a tuple, so we take first element inp = inp[0] @@ -321,19 +345,19 @@ def process_input(self, inp): return inp @abstractmethod - def update_batch(self, module, input, current_layer): + def update_batch(self): pass @abstractmethod - def single_layer_update(self, percdamp=.01): + def single_layer_update(self): pass def get_quant_weights(self, i, i1, permutation_list): # We need to recompute quant weights at runtime since our float weights are being updated - # Add offset in case of blockwise computation (e.g., GPTQ) + # Add offset in case of blockwise computation i = i1 + i # For QuantLinear and for some QuantConvolutional layers, we exploit the possibility - # of quantizing only a subset of the entire matrix speeding up the computation of GPTQ + # of quantizing only a subset of the entire matrix speeding up the computation of GPxQ if isinstance(self.layer, qnn.QuantLinear): index = permutation_list[0][i] subtensor_slice_list = [None, (index, index + 1)] @@ -414,6 +438,9 @@ def __init__(self, layer, name, act_order, parallel_layers=1) -> None: self.nsamples = 0 def update_batch(self, module, input, current_layer): + if self.disable_pre_forward_hook: + return input + # Update reference to current layer current_layer.layer_names.add(self.name) inp = self.process_input(input) @@ -554,15 +581,22 @@ class GPFQ(GPxQ): """ Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main """ + p = 0.25 def __init__(self, layer, name, act_order, parallel_layers=1) -> None: + + if act_order: + raise ValueError("Act_order is not supported in GPFQ") + super().__init__(layer, name, act_order, parallel_layers) self.float_input = None self.quantized_input = None self.index_computed = False - self.p = 0.25 + self.p = GPFQ.p def update_batch(self, 
module, input, current_layer): + if self.disable_pre_forward_hook: + return input # Update reference to current layer current_layer.layer_names.add(self.name) @@ -631,7 +665,13 @@ def update_batch(self, module, input, current_layer): self.quantized_input = inp_processed else: self.quantized_input = torch.cat([self.quantized_input, inp_processed], dim=1) - raise StopFwdException + # If we are executing GPFQ with group of parallel layers, we keep track of how many forward + # we executed. Once we executed as many as the number of parallel_layers, we raise + # StopFwdException + current_layer.forward_count += 1 + if current_layer.forward_count == len(self.parallel_layers): + current_layer.forward_count = 0 + raise StopFwdException def single_layer_update(self): weight = self.layer.weight.data @@ -641,12 +681,14 @@ def single_layer_update(self): if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): weight = weight.transpose(1, 0) # This performs a view weight = weight.flatten(1) - weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] - U = torch.zeros(weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype) + weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] + U = torch.zeros( + weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype) for t in range(weight.shape[-1]): for group_index in range(self.groups): - U[group_index] += weight[group_index, :, t].unsqueeze(1) * self.float_input[group_index, :, t].unsqueeze(0) #[OC/Groups, 1] * [1, INSHAPE[1]] + U[group_index] += weight[group_index, :, t].unsqueeze(1) * self.float_input[ + group_index, :, t].unsqueeze(0) #[OC/Groups, 1] * [1, INSHAPE[1]] norm = torch.linalg.norm(self.quantized_input[group_index, :, t], 2) ** 2 if norm > 0: q_arg = U[group_index].matmul(self.quantized_input[group_index, :, t]) / norm @@ -656,7 +698,8 @@ def single_layer_update(self): weight[group_index, :, t] = q_arg q = self.get_quant_weights(t, 0, [torch.tensor(range(weight.shape[-1]))]) for group_index in range(self.groups): - U[group_index] -= q[group_index].unsqueeze(1) * self.quantized_input[group_index, :, t].unsqueeze(0) + U[group_index] -= q[group_index].unsqueeze(1) * self.quantized_input[ + group_index, :, t].unsqueeze(0) del self.float_input del self.quantized_input diff --git a/src/brevitas_examples/imagenet_classification/ptq/README.md b/src/brevitas_examples/imagenet_classification/ptq/README.md index e0e7c7455..7312b8c2b 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/README.md +++ b/src/brevitas_examples/imagenet_classification/ptq/README.md @@ -85,7 +85,8 @@ usage: ptq_evaluate.py [-h] --calibration-dir CALIBRATION_DIR --validation-dir [--bias-corr | --no-bias-corr] [--graph-eq-merge-bias | --no-graph-eq-merge-bias] [--weight-narrow-range | --no-weight-narrow-range] - [--gptq | --no-gptq] + [--gpfq-p GPFQ_P] [--gptq | --no-gptq] + [--gpfq | --no-gpfq] [--gptq-act-order | --no-gptq-act-order] [--learned-round | --no-learned-round] [--calibrate-bn | --no-calibrate-bn] @@ -171,8 +172,11 @@ optional arguments: Enable Narrow range for weight quantization (default: enabled) --no-weight-narrow-range Disable Narrow range for weight quantization (default: enabled) + --gpfq-p GPFQ_P P parameter for GPFQ (default: 0.25) --gptq Enable GPTQ (default: enabled) --no-gptq Disable GPTQ (default: enabled) + --gpfq Enable GPFQ (default: disabled) + --no-gpfq Disable GPFQ (default: disabled) --gptq-act-order Enable 
GPTQ Act order heuristic (default: disabled) --no-gptq-act-order Disable GPTQ Act order heuristic (default: disabled) --learned-round Enable Learned round (default: disabled) diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 343624950..7e4afc0f8 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -361,12 +361,12 @@ def apply_gptq(calib_loader, model, act_order=False): gptq.update() -def apply_gpfq(calib_loader, model, act_order=False): +def apply_gpfq(calib_loader, model, p=0.25): model.eval() dtype = next(model.parameters()).dtype device = next(model.parameters()).device with torch.no_grad(): - with gpfq_mode(model, act_order=act_order, use_quant_activations=True) as gpfq: + with gpfq_mode(model, p=p, use_quant_activations=True) as gpfq: gpfq_model = gpfq.model for i in tqdm(range(gpfq.num_layers)): for i, (images, target) in enumerate(calib_loader): diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 08dde3b21..a560cd4c3 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -166,6 +166,8 @@ 'weight-narrow-range', default=True, help='Narrow range for weight quantization (default: enabled)') +parser.add_argument( + '--gpfq-p', default=0.25, type=float, help='P parameter for GPFQ (default: 0.25)') add_bool_arg(parser, 'gptq', default=True, help='GPTQ (default: enabled)') add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)') add_bool_arg( @@ -193,6 +195,7 @@ def main(): f"a{args.act_bit_width}" f"w{args.weight_bit_width}_" f"{'gptq_' if args.gptq else ''}" + f"{'gpfq_' if args.gpfq else ''}" f"{'gptq_act_order_' if args.gptq_act_order else ''}" f"{'learned_round_' if args.learned_round else ''}" f"{'weight_narrow_range_' if args.weight_narrow_range else ''}" @@ -213,6 +216,8 @@ def main(): f"Activation bit width: {args.act_bit_width} - " f"Weight bit width: {args.weight_bit_width} - " f"GPTQ: {args.gptq} - " + f"GPFQ: {args.gpfq} - " + f"GPFQ P: {args.gpfq_p} - " f"GPTQ Act Order: {args.gptq_act_order} - " f"Learned Round: {args.learned_round} - " f"Weight narrow range: {args.weight_narrow_range} - " @@ -303,12 +308,11 @@ def main(): if args.gpfq: print("Performing GPFQ:") - apply_gpfq(calib_loader, quant_model) + apply_gpfq(calib_loader, quant_model, p=args.gpfq_p) if args.gptq: print("Performing GPTQ:") - apply_gptq(calib_loader, quant_model, args.gptq_act_order) - + apply_gptq(calib_loader, quant_model, act_order=args.gptq_act_order) if args.learned_round: print("Applying Learned Round:") From 1055dc2be93edd890181a2e322ad0f31e701dadc Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 24 Aug 2023 14:12:44 +0100 Subject: [PATCH 06/12] Fix for llm import --- src/brevitas_examples/llm/llm_quant/gptq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/llm/llm_quant/gptq.py b/src/brevitas_examples/llm/llm_quant/gptq.py index 2e73bdf76..ab5d78195 100644 --- a/src/brevitas_examples/llm/llm_quant/gptq.py +++ b/src/brevitas_examples/llm/llm_quant/gptq.py @@ -5,7 +5,7 @@ import torch -from brevitas.graph.gptq import gptq_mode +from brevitas.graph.gpxq import gptq_mode from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn 
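
The GPFQ implementation introduced in the patches above reduces to a simple per-column update rule: accumulate the float contribution of each input column into a running residual, project that residual onto the corresponding quantized activations, quantize the result, and subtract the quantized contribution so that later columns can compensate the error. Below is a minimal standalone sketch of that rule for a plain linear layer; it is illustrative only and not part of the patch series. The `fake_quantize` helper is a hypothetical stand-in for the layer's Brevitas weight quantizer, and the grouped/convolutional handling of the real `GPFQ.single_layer_update` is omitted.

    import torch


    def fake_quantize(x, scale=2 ** -4):
        # Placeholder uniform quantizer; the real GPFQ class uses the layer's
        # Brevitas weight_quant instead.
        return torch.round(x / scale) * scale


    def gpfq_layer_update(weight, float_input, quant_input):
        # weight:      [OC, IC] float weights of the layer
        # float_input: [N, IC] activations collected with quantization disabled
        # quant_input: [N, IC] activations collected with quantization enabled
        oc, ic = weight.shape
        quant_weight = torch.zeros_like(weight)
        # U accumulates the residual error between the float and quantized paths
        U = torch.zeros(oc, float_input.shape[0], dtype=weight.dtype)
        for t in range(ic):
            # Add the float contribution of input column t to the residual
            U += weight[:, t].unsqueeze(1) * float_input[:, t].unsqueeze(0)
            norm = torch.linalg.norm(quant_input[:, t], 2) ** 2
            if norm > 0:
                # Project the residual onto the quantized activations of column t
                q_arg = U.matmul(quant_input[:, t]) / norm
            else:
                q_arg = torch.zeros(oc, dtype=weight.dtype)
            quant_weight[:, t] = fake_quantize(q_arg)
            # Remove the quantized contribution so later columns compensate the error
            U -= quant_weight[:, t].unsqueeze(1) * quant_input[:, t].unsqueeze(0)
        return quant_weight


    # Toy usage with a random layer and calibration batch
    w = torch.randn(8, 16)
    x_float = torch.randn(32, 16)
    x_quant = x_float + 0.01 * torch.randn_like(x_float)  # pretend quantized activations
    w_q = gpfq_layer_update(w.clone(), x_float, x_quant)

The running residual `U` is what lets later columns compensate the quantization error committed on earlier ones; the same structure appears, per group, in `GPFQ.single_layer_update`, where the float and quantized activations are the two inputs collected by the double forward pass in `gpfq_mode.catch_stopfwd`.
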
From 57a164d9318be5aff67c28eccfc0d4b5fe0e0048 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 5 Sep 2023 17:49:42 +0100 Subject: [PATCH 07/12] File split --- src/brevitas/graph/gpfq.py | 235 +++++++++ src/brevitas/graph/gptq.py | 258 ++++++++++ src/brevitas/graph/gpxq.py | 460 +----------------- .../imagenet_classification/ptq/ptq_common.py | 4 +- src/brevitas_examples/llm/llm_quant/gptq.py | 2 +- 5 files changed, 497 insertions(+), 462 deletions(-) create mode 100644 src/brevitas/graph/gpfq.py create mode 100644 src/brevitas/graph/gptq.py diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py new file mode 100644 index 000000000..b4fd66076 --- /dev/null +++ b/src/brevitas/graph/gpfq.py @@ -0,0 +1,235 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from copy import deepcopy +from typing import List, Optional + +import numpy as np +import torch +import unfoldNd + +from brevitas.graph.gpxq import GPxQ +from brevitas.graph.gpxq import gpxq_mode +from brevitas.graph.gpxq import StopFwdException +from brevitas.graph.gpxq import SUPPORTED_CONV_OP +import brevitas.nn as qnn + + +class gpfq_mode(gpxq_mode): + """ + Apply GPFQ algorithm. + + Args: + model (Module): The model to quantize with GPFQ + inplace (bool): Wheter to apply GPFQ inplace or perform a deepcopy. Default: True + use_quant_activations (bool): Wheter to leave quantize activations enabled while performing + GPFQ. Default: False + + Example: + >>> with torch.no_grad(): + >>> with gpfq_mode(model) as gpfq: + >>> gpfq_model = gpfq.model + >>> for i in tqdm(range(gpfq.num_layers)): + >>> for img, t in calib_loader: + >>> img = img.cuda() + >>> gpfq_model(img) + >>> gpfq.update() + """ + + def __init__( + self, + model, + group_of_parallel_layers: Optional[List[str]] = None, + inplace: bool = True, + use_quant_activations: bool = True, + p: int = 0.25, + return_forward_output: bool = False, + act_order: bool = False) -> None: + if not inplace: + model = deepcopy(model) + super().__init__( + model, + group_of_parallel_layers, + inplace, + use_quant_activations, + act_order, + return_forward_output) + + self.orig_forward = self.model.forward + self.model.forward = self.catch_stopfwd + self.class_implementation = GPFQ + GPFQ.p = p + + def catch_stopfwd(self, *args, **kwargs): + # Collect quant input + try: + self.orig_forward(*args, **kwargs) + except StopFwdException: + pass + # Before collecting float input, restore original float weights if they have been modified + quant_weight = dict() + for module in self.model.modules(): + if hasattr(module, 'weight_orig_data'): + quant_weight[module] = deepcopy(module.weight.data) + module.weight.data = module.weight_orig_data + # Disable quantization + self.disable_quant_inference.disable_param_quantization(self.model, is_training=False) + self.disable_quant_inference.disable_act_quantization(self.model, is_training=False) + # Collect float input + try: + self.orig_forward(*args, **kwargs) + except StopFwdException: + pass + # Restore correct weights + for module in self.model.modules(): + if hasattr(module, 'weight_orig_data'): + module.weight.data = quant_weight[module] + # Re-enable quantization. 
If activation quantization is disabled, + # we also disable bias quantization + self.disable_quant_inference.enable_param_quantization(self.model, is_training=False) + if self.use_quant_activations: + self.disable_quant_inference.enable_act_quantization(self.model, is_training=False) + else: + self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False) + + if self.return_forward_output: + # If we want to return the output of the network, we need to disable all hooks + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = True + out = self.orig_forward(*args, **kwargs) + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = False + return out + + +class GPFQ(GPxQ): + """ + Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main + """ + p = 0.25 + + def __init__(self, layer, name, act_order, parallel_layers=1) -> None: + + if act_order: + raise ValueError("Act_order is not supported in GPFQ") + + super().__init__(layer, name, act_order, parallel_layers) + self.float_input = None + self.quantized_input = None + self.index_computed = False + self.p = GPFQ.p + + def update_batch(self, module, input, current_layer): + if self.disable_pre_forward_hook: + return input + + # Update reference to current layer + current_layer.layer_names.add(self.name) + is_quant_disabled = module.weight_quant.disable_quant + + inp = self.process_input(input) + batch_size = inp.shape[0] + + # Preprocess the input to compute the Hessian + if isinstance(self.layer, qnn.QuantLinear): + if len(inp.shape) > 2: + inp = inp.reshape((-1, sum(inp.shape[2:]))) + # For QuantLinear layer, groups will be 1 + inp_processed = inp.unsqueeze(0) + + if isinstance(self.layer, SUPPORTED_CONV_OP): + # Pick the correct unfoldNd class + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + unfold_impl = unfoldNd.UnfoldTransposeNd + else: + unfold_impl = unfoldNd.UnfoldNd + + unfold = unfold_impl( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.kernel_size) + + # Split input based on how many groups in convolution + inp_by_group = torch.chunk(inp, self.groups, 1) + inp_processed = [] + # Preprocess input by group + for i, inp in enumerate(inp_by_group): + + inp = unfold(inp) + + batch_size, num_blocks = inp.shape[0], inp.shape[-1] + inp = torch.transpose(inp, 1, 2) # shape (B, L, C*kernel_size[0]*kernel_size[1]) + inp = inp.reshape(-1, inp.size(-1)) # shape (B*L, C*kernel_size[0]*kernel_size[1]) + + if not self.index_computed: + self.index_computed = True + self.rand_indices = np.concatenate([ + np.random.choice( + np.arange(num_blocks * i, num_blocks * (i + 1)), + size=int( + self.p * num_blocks + 1 if self.p != 1 else self.p * num_blocks)) + for i in range(batch_size)]) # need to define self.p (probability) + + indexes = self.rand_indices + if np.max(self.rand_indices) > inp.shape[0]: + indexes = self.rand_indices < inp.shape[0] + indexes = self.rand_indices[indexes] + + inp = inp[indexes] + inp_processed.append(inp) + inp_processed = torch.stack(inp_processed) + + if is_quant_disabled: + if self.float_input is None: + self.float_input = inp_processed + else: + self.float_input = torch.cat([self.float_input, inp_processed], dim=1) + else: + if self.quantized_input is None: + self.quantized_input = inp_processed + else: + self.quantized_input = torch.cat([self.quantized_input, inp_processed], dim=1) + # If we are executing GPFQ with 
group of parallel layers, we keep track of how many forward + # we executed. Once we executed as many as the number of parallel_layers, we raise + # StopFwdException + current_layer.forward_count += 1 + if current_layer.forward_count == len(self.parallel_layers): + current_layer.forward_count = 0 + raise StopFwdException + + def single_layer_update(self): + weight = self.layer.weight.data + dev = weight.device + dtype = weight.dtype + if isinstance(self.layer, SUPPORTED_CONV_OP): + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + weight = weight.transpose(1, 0) # This performs a view + weight = weight.flatten(1) + weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] + U = torch.zeros( + weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype) + self.float_input = self.float_input.to(dev) + self.quantized_input = self.quantized_input.to(dev) + permutation_list = [torch.tensor(range(weight.shape[-1]))] + for t in range(weight.shape[-1]): + for group_index in range(self.groups): + U[group_index] += torch.matmul( + weight[group_index, :, t].unsqueeze(1), + self.float_input[group_index, :, + t].unsqueeze(0)) #[OC/Groups, 1] * [1, INSHAPE[1]] + norm = torch.linalg.norm(self.quantized_input[group_index, :, t], 2) ** 2 + if norm > 0: + q_arg = U[group_index].matmul(self.quantized_input[group_index, :, t]) / norm + else: + q_arg = torch.zeros_like(U[group_index, :, 0]) + + weight[group_index, :, t] = q_arg + q = self.get_quant_weights(t, 0, permutation_list) + for group_index in range(self.groups): + U[group_index] -= torch.matmul( + q[group_index].unsqueeze(1), + self.quantized_input[group_index, :, t].unsqueeze(0)) + + del self.float_input + del self.quantized_input diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py new file mode 100644 index 000000000..32e3a7869 --- /dev/null +++ b/src/brevitas/graph/gptq.py @@ -0,0 +1,258 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from copy import deepcopy +import math +from typing import List, Optional, Set +import warnings + +import torch + +try: + from torch.linalg import LinAlgError +except: + LinAlgError = RuntimeError +import unfoldNd + +from brevitas.graph.gpxq import GPxQ +from brevitas.graph.gpxq import gpxq_mode +from brevitas.graph.gpxq import StopFwdException +from brevitas.graph.gpxq import SUPPORTED_CONV_OP +import brevitas.nn as qnn + + +class gptq_mode(gpxq_mode): + """ + Apply GPTQ algorithm https://arxiv.org/abs/2210.17323. + + Args: + model (Module): The model to quantize with GPTQ + inplace (bool): Wheter to apply GPTQ inplace or perform a deepcopy. Default: True + use_quant_activations (bool): Wheter to leave quantize activations enabled while performing + GPTQ. 
Default: False + + Example: + >>> with torch.no_grad(): + >>> with gptq_mode(model) as gptq: + >>> gptq_model = gptq.model + >>> for i in tqdm(range(gptq.num_layers)): + >>> for img, t in calib_loader: + >>> img = img.cuda() + >>> gptq_model(img) + >>> gptq.update() + """ + + def __init__( + self, + model, + group_of_parallel_layers: Optional[List[str]] = None, + inplace: bool = True, + use_quant_activations: bool = True, + num_blocks: int = 100, + return_forward_output: bool = False, + act_order: bool = False) -> None: + if not inplace: + model = deepcopy(model) + super().__init__( + model, + group_of_parallel_layers, + inplace, + use_quant_activations, + act_order, + return_forward_output) + + self.orig_forward = self.model.forward + self.model.forward = self.catch_stopfwd + # How many subblock to use during GPTQ for each layer + self.num_blocks = num_blocks + self.class_implementation = GPTQ + GPTQ.num_blocks = num_blocks + + def catch_stopfwd(self, *args, **kwargs): + try: + self.orig_forward(*args, **kwargs) + except StopFwdException: + pass + finally: + if self.return_forward_output: + # If we want to return the output of the network, we need to disable all hooks + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = True + out = self.orig_forward(*args, **kwargs) + for name, gpxq_class in self.gpxq_layers.items(): + gpxq_class.disable_pre_forward_hook = False + return out + + +class GPTQ(GPxQ): + """ + Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE: + + Copyright 2023 IST-DASLab + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + num_blocks = 100 + + def __init__(self, layer, name, act_order, parallel_layers=1) -> None: + super().__init__(layer, name, act_order, parallel_layers) + + dev = self.layer.weight.device + + # Define how many columns to update in each mini-block + self.blocksize = math.ceil(self.columns / GPTQ.num_blocks) + + # Initialize Hessian matrix and counter. 
We need it in float32 to compute the inverse + self.H = torch.zeros((self.groups, self.columns, self.columns), + device=dev, + dtype=torch.float32) + self.nsamples = 0 + + def update_batch(self, module, input, current_layer): + if self.disable_pre_forward_hook: + return input + + # Update reference to current layer + current_layer.layer_names.add(self.name) + inp = self.process_input(input) + batch_size = inp.shape[0] + + # Preprocess the input to compute the Hessian + if isinstance(self.layer, qnn.QuantLinear): + if len(inp.shape) > 2: + inp = inp.reshape((-1, sum(inp.shape[2:]))) + inp = inp.t() + # For QuantLinear layer, groups will be 1 + inp_processed = inp.unsqueeze(0) + + if isinstance(self.layer, SUPPORTED_CONV_OP): + # Pick the correct unfoldNd class + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + unfold_impl = unfoldNd.UnfoldTransposeNd + else: + unfold_impl = unfoldNd.UnfoldNd + + unfold = unfold_impl( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.stride) + + # Split input based on how many groups in convolution + inp_by_group = torch.chunk(inp, self.groups, 1) + inp_processed = [] + # Preprocess input by group + for i, inp in enumerate(inp_by_group): + inp = unfold(inp) + inp = inp.transpose(1, 0) + inp = inp.flatten(1) + inp_processed.append(inp) + inp_processed = torch.stack(inp_processed) + + # Hessian computation + self.H *= self.nsamples / (self.nsamples + batch_size) + self.nsamples += batch_size + inp_processed = math.sqrt(2 / self.nsamples) * inp_processed.to(torch.float32) + self.H += inp_processed.bmm(inp_processed.transpose(2, 1)) + # If we are executing GPTQ with group of parallel layers, we keep track of how many forward + # we executed. Once we executed as many as the number of parallel_layers, we raise + # StopFwdException + current_layer.forward_count += 1 + if current_layer.forward_count == len(self.parallel_layers): + current_layer.forward_count = 0 + raise StopFwdException + + def single_layer_update(self, percdamp=.01): + weight = self.layer.weight.data + dev = weight.device + + # Store the original dtype of the weights + # During computation, everything is converted to float32. + # When the weights are updated, we cast everything back to the original dtype + dtype = weight.dtype + + if isinstance(self.layer, SUPPORTED_CONV_OP): + if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): + weight = weight.transpose(1, 0) # This performs a view + weight = weight.flatten(1) + + # List with permutation tensors for the Hessian and Weight matrix. + # If act_order is False, the tensors will be ordered indexes. + # For groupwise convolution, we have one tensor per group, + # thus len(permutation_list) is always equal to self.groups. + # We do not explicity permute the weight matrix, only the Hessian. + permutation_list = [] + weight = weight.view(self.groups, -1, weight.shape[-1]) + # For groupwise convolution, these operations are groupwise so we iterate + for i in range(self.groups): + # If a diagonal element on the Hessian is zero, we can set to 0 the corresponding + # column in the weight matrix. 
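The rescaling in update_batch above keeps H equal to 2/nsamples times the sum of x xᵀ over every sample seen so far, so calibration batches can be streamed without storing activations. A small self-contained check of that recurrence (shapes and batch split are illustrative only):

    import math
    import torch

    torch.manual_seed(0)
    X = torch.randn(8, 16)            # 8 calibration samples, 16 input features
    H_direct = 2 / 8 * X.t() @ X      # 2/N * sum_i x_i x_i^T computed in one shot

    H, n = torch.zeros(16, 16), 0
    for batch in X.split(3):          # stream the same samples in batches of 3, 3 and 2
        b = batch.shape[0]
        H *= n / (n + b)
        n += b
        batch = math.sqrt(2 / n) * batch
        H += batch.t() @ batch
    assert torch.allclose(H, H_direct, atol=1e-5)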
+ # The diagonal element is set to 1 to avoid division-by-zero + dead = torch.diag(self.H[i, :, :]) == 0 + self.H[i, dead, dead] = 1 + # If the diagonal of activations is zero, we set the weight to zero + weight[i, :, dead] = 0 + if self.act_order: + # Re-order Hessian so that weights associated to + # higher magnitude activations are quantized first + perm = torch.argsort(torch.diag(self.H[i, :, :]), descending=True) + self.H[i, :, :] = self.H[i, perm, :][:, perm] + else: + # No permutation, permutation tensor is a ordered index + perm = torch.tensor(range(self.H.shape[-1]), device=dev) + permutation_list.append(perm) + + # Try/Except in case the inverse Hessian cannot be computed + try: + for i in range(self.groups): + damp = percdamp * torch.mean(torch.diag(self.H[i, :, :])) + diag = torch.arange(self.columns, device=dev) + self.H[i, diag, diag] += damp + self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :]) + self.H[i, :, :] = torch.cholesky_inverse(self.H[i, :, :]) + self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :], upper=True) + h_inv = self.H + except LinAlgError as e: + warnings.warn( + f'Failed to compute the inverse of the Hessian for layer {self.name} ' + f'GPTQ will not be applied. ' + f'Increasing the number of samples might fix this issue') + return + finally: + del self.H + + for i1 in range(0, self.columns, self.blocksize): + i2 = min(i1 + self.blocksize, self.columns) + count = i2 - i1 + error_block = torch.zeros_like( + weight[:, :, perm[i1:i2]], dtype=torch.float32) # [groups, OC/groups, i2-i1] + + h_inv_block = h_inv[:, i1:i2, i1:i2] + for i in range(count): + q_groups = self.get_quant_weights(i, i1, permutation_list) # [groups, OC/groups] + for group_index in range(self.groups): + perm = permutation_list[group_index] + q = q_groups[group_index] # [OC/groups] + w = weight[group_index, :, perm[i1:i2][i]].to(torch.float32) # [OC/groups] + d = h_inv_block[group_index, i, i] # [1] + error = (w - q) / d # [OC/groups] + error_block[group_index, :, i] = error + # We need to update the original weights + weight[group_index, :, perm[i1:i2][i:]] -= ( + error.unsqueeze(1).matmul(h_inv_block[group_index, i, + i:].unsqueeze(0))).to(dtype) + + for group_index in range(self.groups): + perm = permutation_list[group_index] + weight[group_index, :, perm[i2:]] -= ( + error_block[group_index].matmul(h_inv[group_index, i1:i2, i2:])).to(dtype) diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index c076b2827..9dc11d955 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -7,18 +7,12 @@ from dataclasses import dataclass from dataclasses import field from functools import partial -import math from operator import attrgetter from typing import List, Optional, Set import warnings -import torch - -try: - from torch.linalg import LinAlgError -except: - LinAlgError = RuntimeError import numpy as np +import torch import unfoldNd from brevitas.graph.calibrate import DisableEnableQuantization @@ -139,156 +133,6 @@ def catch_stopfwd(self, *args, **kwargs): pass -class gptq_mode(gpxq_mode): - """ - Apply GPTQ algorithm https://arxiv.org/abs/2210.17323. - - Args: - model (Module): The model to quantize with GPTQ - inplace (bool): Wheter to apply GPTQ inplace or perform a deepcopy. Default: True - use_quant_activations (bool): Wheter to leave quantize activations enabled while performing - GPTQ. 
Default: False - - Example: - >>> with torch.no_grad(): - >>> with gptq_mode(model) as gptq: - >>> gptq_model = gptq.model - >>> for i in tqdm(range(gptq.num_layers)): - >>> for img, t in calib_loader: - >>> img = img.cuda() - >>> gptq_model(img) - >>> gptq.update() - """ - - def __init__( - self, - model, - group_of_parallel_layers: Optional[List[str]] = None, - inplace: bool = True, - use_quant_activations: bool = True, - num_blocks: int = 100, - return_forward_output: bool = False, - act_order: bool = False) -> None: - if not inplace: - model = deepcopy(model) - super().__init__( - model, - group_of_parallel_layers, - inplace, - use_quant_activations, - act_order, - return_forward_output) - - self.orig_forward = self.model.forward - self.model.forward = self.catch_stopfwd - # How many subblock to use during GPTQ for each layer - self.num_blocks = num_blocks - self.class_implementation = GPTQ - GPTQ.num_blocks = num_blocks - - def catch_stopfwd(self, *args, **kwargs): - try: - self.orig_forward(*args, **kwargs) - except StopFwdException: - pass - finally: - if self.return_forward_output: - # If we want to return the output of the network, we need to disable all hooks - for name, gpxq_class in self.gpxq_layers.items(): - gpxq_class.disable_pre_forward_hook = True - out = self.orig_forward(*args, **kwargs) - for name, gpxq_class in self.gpxq_layers.items(): - gpxq_class.disable_pre_forward_hook = False - return out - - -class gpfq_mode(gpxq_mode): - """ - Apply GPFQ algorithm. - - Args: - model (Module): The model to quantize with GPFQ - inplace (bool): Wheter to apply GPFQ inplace or perform a deepcopy. Default: True - use_quant_activations (bool): Wheter to leave quantize activations enabled while performing - GPFQ. Default: False - - Example: - >>> with torch.no_grad(): - >>> with gpfq_mode(model) as gpfq: - >>> gpfq_model = gpfq.model - >>> for i in tqdm(range(gpfq.num_layers)): - >>> for img, t in calib_loader: - >>> img = img.cuda() - >>> gpfq_model(img) - >>> gpfq.update() - """ - - def __init__( - self, - model, - group_of_parallel_layers: Optional[List[str]] = None, - inplace: bool = True, - use_quant_activations: bool = True, - p: int = 0.25, - return_forward_output: bool = False, - act_order: bool = False) -> None: - if not inplace: - model = deepcopy(model) - super().__init__( - model, - group_of_parallel_layers, - inplace, - use_quant_activations, - act_order, - return_forward_output) - - self.orig_forward = self.model.forward - self.model.forward = self.catch_stopfwd - self.class_implementation = GPFQ - GPFQ.p = p - - def catch_stopfwd(self, *args, **kwargs): - # Collect quant input - try: - self.orig_forward(*args, **kwargs) - except StopFwdException: - pass - # Before collecting float input, restore original float weights if they have been modified - quant_weight = dict() - for module in self.model.modules(): - if hasattr(module, 'weight_orig_data'): - quant_weight[module] = deepcopy(module.weight.data) - module.weight.data = module.weight_orig_data - # Disable quantization - self.disable_quant_inference.disable_param_quantization(self.model, is_training=False) - self.disable_quant_inference.disable_act_quantization(self.model, is_training=False) - # Collect float input - try: - self.orig_forward(*args, **kwargs) - except StopFwdException: - pass - # Restore correct weights - for module in self.model.modules(): - if hasattr(module, 'weight_orig_data'): - module.weight.data = quant_weight[module] - # Re-enable quantization. 
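The catch_stopfwd above is the main behavioural difference of gpfq_mode compared to gptq_mode: every calibration batch is run twice, first through the quantized model and then, with quantization disabled, through the float model, because GPFQ needs the float activations to define the target output and the quantized activations to build the replacement weights. Inside each hook the two passes are told apart by checking whether weight quantization is currently disabled, as in this condensed view of GPFQ.update_batch further below (simplified to a single tensor per pass):

    # inside the per-layer pre-forward hook
    if module.weight_quant.disable_quant:
        # second pass, quantization off: these are the float activations
        self.float_input = inp if self.float_input is None else torch.cat([self.float_input, inp], dim=1)
    else:
        # first pass, quantization on: these are the quantized activations
        self.quantized_input = inp if self.quantized_input is None else torch.cat([self.quantized_input, inp], dim=1)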
If activation quantization is disabled, - # we also disable bias quantization - self.disable_quant_inference.enable_param_quantization(self.model, is_training=False) - if self.use_quant_activations: - self.disable_quant_inference.enable_act_quantization(self.model, is_training=False) - else: - self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False) - - if self.return_forward_output: - # If we want to return the output of the network, we need to disable all hooks - for name, gpxq_class in self.gpxq_layers.items(): - gpxq_class.disable_pre_forward_hook = True - out = self.orig_forward(*args, **kwargs) - for name, gpxq_class in self.gpxq_layers.items(): - gpxq_class.disable_pre_forward_hook = False - return out - - class GPxQ(ABC): def __init__(self, layer, name, act_order, parallel_layers=1) -> None: @@ -401,305 +245,3 @@ def get_quant_weights(self, i, i1, permutation_list): # We need to remove the last dim q = q.squeeze(2) # [groups, OC/groups] or [1, OC] return q - - -class GPTQ(GPxQ): - """ - Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE: - - Copyright 2023 IST-DASLab - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - num_blocks = 100 - - def __init__(self, layer, name, act_order, parallel_layers=1) -> None: - super().__init__(layer, name, act_order, parallel_layers) - - dev = self.layer.weight.device - - # Define how many columns to update in each mini-block - self.blocksize = math.ceil(self.columns / GPTQ.num_blocks) - - # Initialize Hessian matrix and counter. 
We need it in float32 to compute the inverse - self.H = torch.zeros((self.groups, self.columns, self.columns), - device=dev, - dtype=torch.float32) - self.nsamples = 0 - - def update_batch(self, module, input, current_layer): - if self.disable_pre_forward_hook: - return input - - # Update reference to current layer - current_layer.layer_names.add(self.name) - inp = self.process_input(input) - batch_size = inp.shape[0] - - # Preprocess the input to compute the Hessian - if isinstance(self.layer, qnn.QuantLinear): - if len(inp.shape) > 2: - inp = inp.reshape((-1, sum(inp.shape[2:]))) - inp = inp.t() - # For QuantLinear layer, groups will be 1 - inp_processed = inp.unsqueeze(0) - - if isinstance(self.layer, SUPPORTED_CONV_OP): - # Pick the correct unfoldNd class - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): - unfold_impl = unfoldNd.UnfoldTransposeNd - else: - unfold_impl = unfoldNd.UnfoldNd - - unfold = unfold_impl( - self.layer.kernel_size, - dilation=self.layer.dilation, - padding=self.layer.padding, - stride=self.layer.stride) - - # Split input based on how many groups in convolution - inp_by_group = torch.chunk(inp, self.groups, 1) - inp_processed = [] - # Preprocess input by group - for i, inp in enumerate(inp_by_group): - inp = unfold(inp) - inp = inp.transpose(1, 0) - inp = inp.flatten(1) - inp_processed.append(inp) - inp_processed = torch.stack(inp_processed) - - # Hessian computation - self.H *= self.nsamples / (self.nsamples + batch_size) - self.nsamples += batch_size - inp_processed = math.sqrt(2 / self.nsamples) * inp_processed.to(torch.float32) - self.H += inp_processed.bmm(inp_processed.transpose(2, 1)) - # If we are executing GPTQ with group of parallel layers, we keep track of how many forward - # we executed. Once we executed as many as the number of parallel_layers, we raise - # StopFwdException - current_layer.forward_count += 1 - if current_layer.forward_count == len(self.parallel_layers): - current_layer.forward_count = 0 - raise StopFwdException - - def single_layer_update(self, percdamp=.01): - weight = self.layer.weight.data - dev = weight.device - - # Store the original dtype of the weights - # During computation, everything is converted to float32. - # When the weights are updated, we cast everything back to the original dtype - dtype = weight.dtype - - if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): - weight = weight.transpose(1, 0) # This performs a view - weight = weight.flatten(1) - - # List with permutation tensors for the Hessian and Weight matrix. - # If act_order is False, the tensors will be ordered indexes. - # For groupwise convolution, we have one tensor per group, - # thus len(permutation_list) is always equal to self.groups. - # We do not explicity permute the weight matrix, only the Hessian. - permutation_list = [] - weight = weight.view(self.groups, -1, weight.shape[-1]) - # For groupwise convolution, these operations are groupwise so we iterate - for i in range(self.groups): - # If a diagonal element on the Hessian is zero, we can set to 0 the corresponding - # column in the weight matrix. 
- # The diagonal element is set to 1 to avoid division-by-zero - dead = torch.diag(self.H[i, :, :]) == 0 - self.H[i, dead, dead] = 1 - # If the diagonal of activations is zero, we set the weight to zero - weight[i, :, dead] = 0 - if self.act_order: - # Re-order Hessian so that weights associated to - # higher magnitude activations are quantized first - perm = torch.argsort(torch.diag(self.H[i, :, :]), descending=True) - self.H[i, :, :] = self.H[i, perm, :][:, perm] - else: - # No permutation, permutation tensor is a ordered index - perm = torch.tensor(range(self.H.shape[-1]), device=dev) - permutation_list.append(perm) - - # Try/Except in case the inverse Hessian cannot be computed - try: - for i in range(self.groups): - damp = percdamp * torch.mean(torch.diag(self.H[i, :, :])) - diag = torch.arange(self.columns, device=dev) - self.H[i, diag, diag] += damp - self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :]) - self.H[i, :, :] = torch.cholesky_inverse(self.H[i, :, :]) - self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :], upper=True) - h_inv = self.H - except LinAlgError as e: - warnings.warn( - f'Failed to compute the inverse of the Hessian for layer {self.name} ' - f'GPTQ will not be applied. ' - f'Increasing the number of samples might fix this issue') - return - finally: - del self.H - - for i1 in range(0, self.columns, self.blocksize): - i2 = min(i1 + self.blocksize, self.columns) - count = i2 - i1 - error_block = torch.zeros_like( - weight[:, :, perm[i1:i2]], dtype=torch.float32) # [groups, OC/groups, i2-i1] - - h_inv_block = h_inv[:, i1:i2, i1:i2] - for i in range(count): - q_groups = self.get_quant_weights(i, i1, permutation_list) # [groups, OC/groups] - for group_index in range(self.groups): - perm = permutation_list[group_index] - q = q_groups[group_index] # [OC/groups] - w = weight[group_index, :, perm[i1:i2][i]].to(torch.float32) # [OC/groups] - d = h_inv_block[group_index, i, i] # [1] - error = (w - q) / d # [OC/groups] - error_block[group_index, :, i] = error - # We need to update the original weights - weight[group_index, :, perm[i1:i2][i:]] -= ( - error.unsqueeze(1).matmul(h_inv_block[group_index, i, - i:].unsqueeze(0))).to(dtype) - - for group_index in range(self.groups): - perm = permutation_list[group_index] - weight[group_index, :, perm[i2:]] -= ( - error_block[group_index].matmul(h_inv[group_index, i1:i2, i2:])).to(dtype) - - -class GPFQ(GPxQ): - """ - Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main - """ - p = 0.25 - - def __init__(self, layer, name, act_order, parallel_layers=1) -> None: - - if act_order: - raise ValueError("Act_order is not supported in GPFQ") - - super().__init__(layer, name, act_order, parallel_layers) - self.float_input = None - self.quantized_input = None - self.index_computed = False - self.p = GPFQ.p - - def update_batch(self, module, input, current_layer): - if self.disable_pre_forward_hook: - return input - - # Update reference to current layer - current_layer.layer_names.add(self.name) - is_quant_disabled = module.weight_quant.disable_quant - - inp = self.process_input(input) - batch_size = inp.shape[0] - - # Preprocess the input to compute the Hessian - if isinstance(self.layer, qnn.QuantLinear): - if len(inp.shape) > 2: - inp = inp.reshape((-1, sum(inp.shape[2:]))) - # For QuantLinear layer, groups will be 1 - inp_processed = inp.unsqueeze(0) - - if isinstance(self.layer, SUPPORTED_CONV_OP): - # Pick the correct unfoldNd class - if isinstance(self.layer, (qnn.QuantConvTranspose1d, 
qnn.QuantConvTranspose2d)): - unfold_impl = unfoldNd.UnfoldTransposeNd - else: - unfold_impl = unfoldNd.UnfoldNd - - unfold = unfold_impl( - self.layer.kernel_size, - dilation=self.layer.dilation, - padding=self.layer.padding, - stride=self.layer.kernel_size) - - # Split input based on how many groups in convolution - inp_by_group = torch.chunk(inp, self.groups, 1) - inp_processed = [] - # Preprocess input by group - for i, inp in enumerate(inp_by_group): - - inp = unfold(inp) - - batch_size, num_blocks = inp.shape[0], inp.shape[-1] - inp = torch.transpose(inp, 1, 2) # shape (B, L, C*kernel_size[0]*kernel_size[1]) - inp = inp.reshape(-1, inp.size(-1)) # shape (B*L, C*kernel_size[0]*kernel_size[1]) - - if not self.index_computed: - self.index_computed = True - self.rand_indices = np.concatenate([ - np.random.choice( - np.arange(num_blocks * i, num_blocks * (i + 1)), - size=int( - self.p * num_blocks + 1 if self.p != 1 else self.p * num_blocks)) - for i in range(batch_size)]) # need to define self.p (probability) - - indexes = self.rand_indices - if np.max(self.rand_indices) > inp.shape[0]: - indexes = self.rand_indices < inp.shape[0] - indexes = self.rand_indices[indexes] - - inp = inp[indexes] - inp_processed.append(inp) - inp_processed = torch.stack(inp_processed) - - if is_quant_disabled: - if self.float_input is None: - self.float_input = inp_processed - else: - self.float_input = torch.cat([self.float_input, inp_processed], dim=1) - else: - if self.quantized_input is None: - self.quantized_input = inp_processed - else: - self.quantized_input = torch.cat([self.quantized_input, inp_processed], dim=1) - # If we are executing GPFQ with group of parallel layers, we keep track of how many forward - # we executed. Once we executed as many as the number of parallel_layers, we raise - # StopFwdException - current_layer.forward_count += 1 - if current_layer.forward_count == len(self.parallel_layers): - current_layer.forward_count = 0 - raise StopFwdException - - def single_layer_update(self): - weight = self.layer.weight.data - dev = weight.device - dtype = weight.dtype - if isinstance(self.layer, SUPPORTED_CONV_OP): - if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)): - weight = weight.transpose(1, 0) # This performs a view - weight = weight.flatten(1) - weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC] - U = torch.zeros( - weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype) - - for t in range(weight.shape[-1]): - for group_index in range(self.groups): - U[group_index] += weight[group_index, :, t].unsqueeze(1) * self.float_input[ - group_index, :, t].unsqueeze(0) #[OC/Groups, 1] * [1, INSHAPE[1]] - norm = torch.linalg.norm(self.quantized_input[group_index, :, t], 2) ** 2 - if norm > 0: - q_arg = U[group_index].matmul(self.quantized_input[group_index, :, t]) / norm - else: - q_arg = torch.zeros_like(U[group_index, :, 0]) - - weight[group_index, :, t] = q_arg - q = self.get_quant_weights(t, 0, [torch.tensor(range(weight.shape[-1]))]) - for group_index in range(self.groups): - U[group_index] -= q[group_index].unsqueeze(1) * self.quantized_input[ - group_index, :, t].unsqueeze(0) - - del self.float_input - del self.quantized_input diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 7e4afc0f8..6e60ecb09 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ 
b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -13,8 +13,8 @@ from brevitas.graph.calibrate import calibration_mode from brevitas.graph.calibrate import norm_correction_mode from brevitas.graph.equalize import activation_equalization_mode -from brevitas.graph.gpxq import gpfq_mode -from brevitas.graph.gpxq import gptq_mode +from brevitas.graph.gpfq import gpfq_mode +from brevitas.graph.gptq import gptq_mode from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.quantize import quantize from brevitas.graph.target.flexml import quantize_flexml diff --git a/src/brevitas_examples/llm/llm_quant/gptq.py b/src/brevitas_examples/llm/llm_quant/gptq.py index ab5d78195..2e73bdf76 100644 --- a/src/brevitas_examples/llm/llm_quant/gptq.py +++ b/src/brevitas_examples/llm/llm_quant/gptq.py @@ -5,7 +5,7 @@ import torch -from brevitas.graph.gpxq import gptq_mode +from brevitas.graph.gptq import gptq_mode from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn From 98bf622427ee43ca934113e3a9cf10ec59de87f4 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sun, 24 Sep 2023 12:52:14 +0100 Subject: [PATCH 08/12] Support for gpfq in benchmark --- src/brevitas/graph/calibrate.py | 8 ++----- src/brevitas/graph/gpfq.py | 12 ++-------- src/brevitas/graph/gpxq.py | 6 +---- src/brevitas/nn/mixin/parameter.py | 7 ++++-- .../benchmark/ptq_benchmark_torchvision.py | 22 +++++++++---------- 5 files changed, 20 insertions(+), 35 deletions(-) diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index a2d235e8a..fddbfd892 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -280,16 +280,12 @@ def forward_hook_wbiol(self, module, inp, output, name): # Compute float reference self.disable_act_quantization(module, is_training=False) self.disable_param_quantization(module, is_training=False) - quant_weight = dict() - if hasattr(module, 'weight_orig_data'): - quant_weight[module] = deepcopy(module.weight.data) - module.weight.data = module.weight_orig_data + out_float = module.forward(*inp) # Required to avoid infinite recursion self.collect_float_mean(module, out_float, name) self.enable_act_quantization(module, is_training=False) self.enable_param_quantization(module, is_training=False) - for module, value in quant_weight.items(): - module.weight.data = value + # Compute quant output # We need to disable output_quant while out_quant is being computed # or we are going to apply bias correction on post quant values instead of pre quant diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index b4fd66076..f6d8db53a 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -66,12 +66,7 @@ def catch_stopfwd(self, *args, **kwargs): self.orig_forward(*args, **kwargs) except StopFwdException: pass - # Before collecting float input, restore original float weights if they have been modified - quant_weight = dict() - for module in self.model.modules(): - if hasattr(module, 'weight_orig_data'): - quant_weight[module] = deepcopy(module.weight.data) - module.weight.data = module.weight_orig_data + # Disable quantization self.disable_quant_inference.disable_param_quantization(self.model, is_training=False) self.disable_quant_inference.disable_act_quantization(self.model, is_training=False) @@ -80,10 +75,7 @@ def catch_stopfwd(self, *args, **kwargs): self.orig_forward(*args, **kwargs) except StopFwdException: pass - # Restore correct weights - for module in self.model.modules(): - if 
hasattr(module, 'weight_orig_data'): - module.weight.data = quant_weight[module] + # Re-enable quantization. If activation quantization is disabled, # we also disable bias quantization self.disable_quant_inference.enable_param_quantization(self.model, is_training=False) diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index 9dc11d955..0ab37b525 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -11,10 +11,6 @@ from typing import List, Optional, Set import warnings -import numpy as np -import torch -import unfoldNd - from brevitas.graph.calibrate import DisableEnableQuantization import brevitas.nn as qnn from brevitas.quant_tensor import QuantTensor @@ -141,7 +137,7 @@ def __init__(self, layer, name, act_order, parallel_layers=1) -> None: self.act_order = act_order weight = layer.weight.data - self.layer.weight_orig_data = deepcopy(weight) + self.layer.weight_orig = deepcopy(layer.weight) # By default, use groups = 1 self.groups = 1 if isinstance(self.layer, SUPPORTED_CONV_OP): diff --git a/src/brevitas/nn/mixin/parameter.py b/src/brevitas/nn/mixin/parameter.py index f65621c3f..095c981f1 100644 --- a/src/brevitas/nn/mixin/parameter.py +++ b/src/brevitas/nn/mixin/parameter.py @@ -61,6 +61,9 @@ def quant_weight( self, quant_input: Optional[QuantTensor] = None, subtensor_slice_list: List[Optional[Tuple[int, int]]] = None): + weights_to_quantize = self.weight + if not self.weight_quant.is_quant_enabled and hasattr(self, 'weight_orig'): + weights_to_quantize = self.weight_orig if subtensor_slice_list is not None: # prepare the quantizer for a subtensor input, if any modifications are required # we set a list of tuples rather than a list of slices so that it's jit friendly @@ -95,9 +98,9 @@ def quant_weight( input_bit_width = None input_is_signed = None out = self.weight_quant( - self.weight[weight_slice_tuple], input_bit_width, input_is_signed) + weights_to_quantize[weight_slice_tuple], input_bit_width, input_is_signed) else: - out = self.weight_quant(self.weight[weight_slice_tuple]) + out = self.weight_quant(weights_to_quantize[weight_slice_tuple]) if subtensor_slice_list is not None: # Restore the quantizer behaviour to full tensor quantization # The modules to slice should have been cached already at this point diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py index 9acafbf58..f3a446fcc 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py +++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py @@ -24,6 +24,7 @@ from brevitas.graph.target.flexml import preprocess_for_flexml_quantize from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_bias_correction +from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gpfq from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gptq from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_learned_round_learning from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate @@ -59,6 +60,8 @@ 'act_equalization': ['fx', 'layerwise', None], # Perform Activation Equalization (Smoothquant) 'learned_round': [False, True], # Enable/Disable Learned Round 'gptq': [False, True], # Enable/Disable GPTQ + 'gpfq': 
[False, True], # Enable/Disable GPFQ + 'gpfq_p': [0.25, 0.75], # GPFQ P 'gptq_act_order': [False, True], # Use act_order euristics for GPTQ 'act_quant_percentile': [99.9, 99.99, 99.999], # Activation Quantization Percentile } @@ -78,6 +81,8 @@ 'act_equalization': [None], # Perform Activation Equalization (Smoothquant) 'learned_round': [False], # Enable/Disable Learned Round 'gptq': [True], # Enable/Disable GPTQ + 'gpfq': [False], # Enable/Disable GPFQ + 'gpfq_p': [0.25], # GPFQ P 'gptq_act_order': [False], # Use act_order euristics for GPTQ 'act_quant_percentile': [99.999], # Activation Quantization Percentile } @@ -115,19 +120,8 @@ def main(): args.gpu = get_gpu_index(args.idx) print("Iter {}, GPU {}".format(args.idx, args.gpu)) - options_names = [k.replace('_', ' ').capitalize() for k in OPTIONS.keys()] - torchvision_df = pd.DataFrame( - columns=options_names + [ - 'Top 1% floating point accuracy', - 'Top 1% quant accuracy', - 'Floating point accuracy - quant accuracy', - 'Quant accuracy / floating point accuracy', - 'Calibration size', - 'Calibration batch size', - 'Torch version', - 'Brevitas version']) try: - ptq_torchvision_models(torchvision_df, args) + ptq_torchvision_models(args) except Exception as E: print("Exception at index {}: {}".format(args.idx, E)) @@ -228,6 +222,10 @@ def ptq_torchvision_models(df, args): print("Starting calibration") calibrate(calib_loader, quant_model) + if config_namespace.gpfq: + print("Performing GPFQ:") + apply_gpfq(calib_loader, quant_model, p=config_namespace.gpfq_p) + if config_namespace.gptq: print("Performing gptq") apply_gptq(calib_loader, quant_model, config_namespace.gptq_act_order) From 7cf9f4d98ae99731e7aae7c03668285678110c6d Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 26 Sep 2023 17:24:06 +0100 Subject: [PATCH 09/12] benchmark scripts updated --- .../benchmark/ptq_benchmark_torchvision.py | 46 ++++++++++++++----- .../imagenet_classification/ptq/ptq_common.py | 4 +- 2 files changed, 37 insertions(+), 13 deletions(-) diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py index f3a446fcc..606041a86 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py +++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py @@ -53,6 +53,7 @@ 'bias_bit_width': [32, 16], # Bias Bit-Width for Po2 scale 'weight_quant_granularity': ['per_tensor', 'per_channel'], # Scaling Per Output Channel 'act_quant_type': ['asym', 'sym'], # Act Quant Type + 'weight_param_method': ['stats', 'mse'], # Weight Quant Type 'act_param_method': ['stats', 'mse'], # Act Param Method 'bias_corr': [True], # Bias Correction 'graph_eq_iterations': [0, 20], # Graph Equalization @@ -60,9 +61,9 @@ 'act_equalization': ['fx', 'layerwise', None], # Perform Activation Equalization (Smoothquant) 'learned_round': [False, True], # Enable/Disable Learned Round 'gptq': [False, True], # Enable/Disable GPTQ + 'gptq_act_order': [False, True], # Use act_order euristics for GPTQ 'gpfq': [False, True], # Enable/Disable GPFQ 'gpfq_p': [0.25, 0.75], # GPFQ P - 'gptq_act_order': [False, True], # Use act_order euristics for GPTQ 'act_quant_percentile': [99.9, 99.99, 99.999], # Activation Quantization Percentile } @@ -74,7 +75,8 @@ 'bias_bit_width': [32], # Bias Bit-Width for Po2 scale 'weight_quant_granularity': ['per_channel'], # Scaling Per Output Channel 'act_quant_type': 
['sym'], # Act Quant Type - 'act_param_method': ['stats'], # Act Param Method + 'act_param_method': ['mse'], # Act Param Method + 'weight_param_method': ['stats'], # Weight Quant Type 'bias_corr': [True], # Bias Correction 'graph_eq_iterations': [20], # Graph Equalization 'graph_eq_merge_bias': [True], # Merge bias for Graph Equalization @@ -119,24 +121,38 @@ def main(): args.gpu = get_gpu_index(args.idx) print("Iter {}, GPU {}".format(args.idx, args.gpu)) - try: ptq_torchvision_models(args) except Exception as E: print("Exception at index {}: {}".format(args.idx, E)) -def ptq_torchvision_models(df, args): +def ptq_torchvision_models(args): # Generate all possible combinations, including invalid ones # Split stats and mse due to the act_quant_percentile value - percentile_options = OPTIONS.copy() - percentile_options['act_param_method'] = ['stats'] - mse_options = OPTIONS.copy() - mse_options['act_param_method'] = ['mse'] - mse_options['act_quant_percentile'] = [None] + + print(OPTIONS['act_param_method']) + if 'stats' in OPTIONS['act_param_method']: + percentile_options = OPTIONS.copy() + percentile_options['act_param_method'] = ['stats'] + else: + percentile_options = None + + if 'mse' in OPTIONS['act_param_method']: + mse_options = OPTIONS.copy() + mse_options['act_param_method'] = ['mse'] + else: + mse_options = None + + if mse_options is not None and percentile_options is not None: + combinations = list(product(*percentile_options.values())) + list( + product(*mse_options.values())) + elif mse_options is not None: + combinations = list(product(*mse_options.values())) + elif percentile_options is not None: + combinations = list(product(*percentile_options.values())) + # Combine the two sets of combinations - combinations = list(product(*percentile_options.values())) + list( - product(*mse_options.values())) # Generate Namespace for each configuration configs = [ SimpleNamespace(**{k: v @@ -146,10 +162,12 @@ def ptq_torchvision_models(df, args): configs = list(map(validate_config, configs)) # Drop invalid configurations configs = list(config for config in configs if config.is_valid) + if args.idx > len(configs): return config_namespace = configs[args.idx] + print(config_namespace) fp_accuracy = TORCHVISION_TOP1_MAP[config_namespace.model_name] # Get model-specific configurations about input shapes and normalization @@ -206,6 +224,8 @@ def ptq_torchvision_models(df, args): backend=config_namespace.target_backend, act_bit_width=config_namespace.act_bit_width, weight_bit_width=config_namespace.weight_bit_width, + weight_param_method=config_namespace.weight_param_method, + act_param_method=config_namespace.act_param_method, bias_bit_width=config_namespace.bias_bit_width, weight_quant_granularity=config_namespace.weight_quant_granularity, act_quant_percentile=config_namespace.act_quant_percentile, @@ -290,6 +310,10 @@ def validate_config(config_namespace): if not config_namespace.gptq and config_namespace.gptq_act_order: is_valid = False + # If GPFQ is disabled, we execute only one configuration for p==0.25 + if not config_namespace.gpfq and config_namespace.gpfq_p == 0.75: + is_valid = False + if config_namespace.act_equalization == 'layerwise' and config_namespace.target_backend == 'fx': is_valid = False if config_namespace.act_bit_width < config_namespace.weight_bit_width: diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 6e60ecb09..f2ae5092c 100644 --- 
a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -204,7 +204,7 @@ def kwargs_prefix(prefix, weight_kwargs): weight_quant = weight_quant.let(zero_point_impl=ParameterFromStatsFromParameterZeroPoint) if act_quant is not None: act_quant = act_quant.let(**{'high_percentile_q': act_quant_percentile, 'dtype': dtype}) - if act_quant_type == 'asym': + if act_quant_type == 'asym' and act_quant_percentile is not None: act_quant = act_quant.let(**{'low_percentile_q': 100 - act_quant_percentile}) if sym_act_quant is not None: sym_act_quant = sym_act_quant.let( @@ -214,7 +214,7 @@ def kwargs_prefix(prefix, weight_kwargs): per_tensor_act_quant = per_tensor_act_quant.let( **{ 'high_percentile_q': act_quant_percentile, 'dtype': dtype}) - if act_quant_type == 'asym': + if act_quant_type == 'asym' and act_quant_percentile is not None: per_tensor_act_quant = per_tensor_act_quant.let( **{'low_percentile_q': 100 - act_quant_percentile}) From 953ed0eda19cbd11d9c5d7fe7c53dd648e906674 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 26 Sep 2023 17:26:57 +0100 Subject: [PATCH 10/12] Updated readme --- src/brevitas_examples/imagenet_classification/ptq/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/brevitas_examples/imagenet_classification/ptq/README.md b/src/brevitas_examples/imagenet_classification/ptq/README.md index 7312b8c2b..29386659b 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/README.md +++ b/src/brevitas_examples/imagenet_classification/ptq/README.md @@ -36,6 +36,7 @@ Furthermore, Brevitas additional PTQ techniques can be enabled: - If Graph equalization is enabled, the _merge\_bias_ technique can be enabled.[2 ] [3 ]. - GPTQ [4 ]. - Learned Round [5 ]. +- GPFQ [6 ]. Internally, when defining a quantized model programmatically, Brevitas leverages `torch.fx` and its `symbolic_trace` functionality, meaning that an input model is required to pass symbolic tracing for it to work. 
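A quick way to check that precondition on a candidate model is to trace it up front; a sketch (resnet18 is only an example, and weights=None assumes a recent torchvision):

    import torch
    import torchvision.models as models

    model = models.resnet18(weights=None)
    try:
        torch.fx.symbolic_trace(model)   # must succeed for the fx-based quantization path
        print("model is symbolically traceable")
    except Exception as exc:
        print(f"model cannot be traced as-is: {exc}")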
@@ -212,3 +213,4 @@ and a `RESULTS_IMGCLSMOB.csv` with the results on manually quantized models star [3 ]: https://github.com/openppl-public/ppq/blob/master/ppq/quantization/algorithm/equalization.py [4 ]: https://arxiv.org/abs/2210.17323 [5 ]: https://arxiv.org/abs/2004.10568 +[6 ]: https://arxiv.org/abs/2201.11113 From f043fcd78347a18d9a39abb58839861646d08750 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 27 Sep 2023 13:19:12 +0100 Subject: [PATCH 11/12] Clean-up --- .../ptq/benchmark/ptq_benchmark_torchvision.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py index 606041a86..9a97b4794 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py +++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py @@ -131,7 +131,6 @@ def ptq_torchvision_models(args): # Generate all possible combinations, including invalid ones # Split stats and mse due to the act_quant_percentile value - print(OPTIONS['act_param_method']) if 'stats' in OPTIONS['act_param_method']: percentile_options = OPTIONS.copy() percentile_options['act_param_method'] = ['stats'] @@ -141,17 +140,16 @@ def ptq_torchvision_models(args): if 'mse' in OPTIONS['act_param_method']: mse_options = OPTIONS.copy() mse_options['act_param_method'] = ['mse'] + mse_options['act_quant_percentile'] = [None] else: mse_options = None - if mse_options is not None and percentile_options is not None: - combinations = list(product(*percentile_options.values())) + list( - product(*mse_options.values())) - elif mse_options is not None: - combinations = list(product(*mse_options.values())) - elif percentile_options is not None: - combinations = list(product(*percentile_options.values())) - + # Combine MSE and Percentile combinations, if they are defined + combinations = [] + if mse_options is not None: + combinations += list(product(*mse_options.values())) + if percentile_options is not None: + combinations += list(product(*percentile_options.values())) # Combine the two sets of combinations # Generate Namespace for each configuration configs = [ From dec588f0cd71c170ab16d14abceb99628185cdcd Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 27 Sep 2023 14:25:49 +0100 Subject: [PATCH 12/12] Update for weight orig --- src/brevitas/graph/gpfq.py | 6 ++++-- src/brevitas/graph/gptq.py | 6 ++++-- src/brevitas/graph/gpxq.py | 15 ++++++++++++--- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index f6d8db53a..01dc11a82 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -41,6 +41,7 @@ def __init__( model, group_of_parallel_layers: Optional[List[str]] = None, inplace: bool = True, + create_weight_orig: bool = True, use_quant_activations: bool = True, p: int = 0.25, return_forward_output: bool = False, @@ -51,6 +52,7 @@ def __init__( model, group_of_parallel_layers, inplace, + create_weight_orig, use_quant_activations, act_order, return_forward_output) @@ -100,12 +102,12 @@ class GPFQ(GPxQ): """ p = 0.25 - def __init__(self, layer, name, act_order, parallel_layers=1) -> None: + def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None: if act_order: raise ValueError("Act_order is not supported in GPFQ") - super().__init__(layer, 
name, act_order, parallel_layers) + super().__init__(layer, name, act_order, parallel_layers, create_weight_orig) self.float_input = None self.quantized_input = None self.index_computed = False diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index 32e3a7869..b224e0a37 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -47,6 +47,7 @@ def __init__( model, group_of_parallel_layers: Optional[List[str]] = None, inplace: bool = True, + create_weight_orig: bool = True, use_quant_activations: bool = True, num_blocks: int = 100, return_forward_output: bool = False, @@ -57,6 +58,7 @@ def __init__( model, group_of_parallel_layers, inplace, + create_weight_orig, use_quant_activations, act_order, return_forward_output) @@ -104,8 +106,8 @@ class GPTQ(GPxQ): """ num_blocks = 100 - def __init__(self, layer, name, act_order, parallel_layers=1) -> None: - super().__init__(layer, name, act_order, parallel_layers) + def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None: + super().__init__(layer, name, act_order, parallel_layers, create_weight_orig) dev = self.layer.weight.device diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index 0ab37b525..b13c46683 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -36,6 +36,7 @@ def __init__( model, group_of_parallel_layers: Optional[List[str]] = None, inplace: bool = True, + create_weight_orig: bool = True, use_quant_activations: bool = True, act_order: bool = False, return_forward_output: bool = False) -> None: @@ -43,6 +44,7 @@ def __init__( if not inplace: model = deepcopy(model) self.model = model + self.create_weight_orig = create_weight_orig self.use_quant_activations = use_quant_activations self.hook_dict = dict() self.gpxq_layers = dict() @@ -97,7 +99,11 @@ def __enter__(self): # Attach hooks for GPTQ if self._is_module_supported(module): gpxq = self.class_implementation( - module, name, act_order=self.act_order, parallel_layers=parallel_layers) + module, + name, + act_order=self.act_order, + parallel_layers=parallel_layers, + create_weight_orig=self.create_weight_orig) hook_fn = partial(gpxq.update_batch, current_layer=self.current_layer) self.hook_dict[name] = module.register_forward_pre_hook(hook_fn) self.gpxq_layers[name] = gpxq @@ -131,13 +137,16 @@ def catch_stopfwd(self, *args, **kwargs): class GPxQ(ABC): - def __init__(self, layer, name, act_order, parallel_layers=1) -> None: + def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None: self.layer = layer self.name = name self.act_order = act_order weight = layer.weight.data - self.layer.weight_orig = deepcopy(layer.weight) + + if create_weight_orig and not hasattr(self.layer, 'weight_orig'): + self.layer.register_buffer('weight_orig', layer.weight.detach().clone()) + # By default, use groups = 1 self.groups = 1 if isinstance(self.layer, SUPPORTED_CONV_OP):
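Registering the float copy with register_buffer, rather than deep-copying layer.weight into an ad-hoc attribute (which, being a Parameter, would also be registered as a trainable parameter), means the copy follows the module through .to()/.cuda() and dtype casts and is serialized with the state_dict. A minimal illustration, outside of Brevitas:

    import torch
    import torch.nn as nn

    layer = nn.Linear(4, 4)
    layer.weight_plain_copy = layer.weight.detach().clone()               # plain attribute
    layer.register_buffer('weight_orig', layer.weight.detach().clone())   # registered buffer

    layer.to(torch.float64)
    print(layer.weight_orig.dtype)        # torch.float64: buffers are converted with the module
    print(layer.weight_plain_copy.dtype)  # torch.float32: plain tensor attributes are left behind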