From 5c1932b88d123062c131c27c81b748d4fc8fd6f1 Mon Sep 17 00:00:00 2001 From: Ian Colbert <88047104+i-colbert@users.noreply.github.com> Date: Thu, 7 Mar 2024 01:09:02 -0800 Subject: [PATCH] Fix (gpxq): handling empty tensors with GPxQ and adding unit tests (#892) --- src/brevitas/graph/gpfq.py | 46 +++++- src/brevitas/graph/gptq.py | 7 +- src/brevitas/graph/gpxq.py | 51 +----- tests/brevitas/graph/equalization_fixtures.py | 95 +++++++++++ tests/brevitas/graph/test_gpxq.py | 149 ++++++++++++++++++ 5 files changed, 293 insertions(+), 55 deletions(-) create mode 100644 tests/brevitas/graph/test_gpxq.py diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index e255660a0..7e80f61cb 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -10,11 +10,14 @@ import unfoldNd from brevitas.function import get_upper_bound_on_l1_norm +from brevitas.graph.calibrate import disable_return_quant_tensor +from brevitas.graph.calibrate import restore_return_quant_tensor from brevitas.graph.gpxq import GPxQ from brevitas.graph.gpxq import gpxq_mode from brevitas.graph.gpxq import StopFwdException from brevitas.graph.gpxq import SUPPORTED_CONV_OP import brevitas.nn as qnn +from brevitas.quant_tensor import QuantTensor class gpfq_mode(gpxq_mode): @@ -89,6 +92,7 @@ def catch_stopfwd(self, *args, **kwargs): pass # Disable quantization + self.return_quant_tensor_state = disable_return_quant_tensor(self.model) self.disable_quant_inference.disable_param_quantization(self.model, is_training=False) self.disable_quant_inference.disable_act_quantization(self.model, is_training=False) # Collect float input @@ -104,6 +108,7 @@ def catch_stopfwd(self, *args, **kwargs): self.disable_quant_inference.enable_act_quantization(self.model, is_training=False) else: self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False) + restore_return_quant_tensor(self.model, self.return_quant_tensor_state) if self.return_forward_output: # If we want to return the output of the network, we need to disable all hooks @@ -155,7 +160,7 @@ def update_batch(self, module, input, current_layer): # Update reference to current layer current_layer.layer_names.add(self.name) - is_quant_disabled = module.weight_quant.disable_quant + is_quant_enabled = module.weight_quant.is_quant_enabled inp = self.process_input(input) batch_size = inp.shape[0] @@ -210,7 +215,7 @@ def update_batch(self, module, input, current_layer): inp_processed.append(inp) inp_processed = torch.stack(inp_processed) - if is_quant_disabled: + if not is_quant_enabled: if self.float_input is None: self.float_input = inp_processed else: @@ -229,6 +234,7 @@ def update_batch(self, module, input, current_layer): raise StopFwdException def single_layer_update(self): + assert not self.layer.weight_quant_requires_quant_input, "Error: GPFQ does not support weight quantizers that require quantized inputs." weight = self.layer.weight.data dev = weight.device dtype = weight.dtype @@ -302,13 +308,36 @@ def __init__( p=p) self.accumulator_bit_width = accumulator_bit_width assert self.accumulator_bit_width is not None - self.requires_quant_input = True # force true + + def process_input(self, inp): + inp = super().process_input(inp) + inp = self.layer.input_quant(inp) + + is_quant_enabled = self.layer.weight_quant.is_quant_enabled + + # If using quantized activations, inp could be QuantTensor. In + # this case, we overwrite the metadata. + if isinstance(inp, QuantTensor): + if is_quant_enabled and self.quant_input is None: + self.quant_input = QuantTensor( + value=torch.empty( + 1, dtype=self.layer.weight.dtype, device=self.layer.weight.device), + scale=inp.scale, + zero_point=inp.zero_point, + bit_width=inp.bit_width, + signed=inp.signed, + training=inp.training) + inp = inp.value + + return inp def single_layer_update(self): # raise error in case no quant-input is here if self.quant_input is None: - raise ValueError( - 'Expected quant input to calculate L1-norm upper bound, but received None') + raise ValueError('Expected self.quant_input to calculate L1-norm upper bound, but recevied None. ' + \ + 'Make sure that either the input to the model is a QuantTensor or the layer has an input quant enabled. ' \ + 'Also, check if `use_quant_activations=True` in `gpfq_mode` when `accumulator_bit_width` is specified. ' + \ + 'Alternatively, provide a custom `a2q_layer_filter_fnc` to `gpfq_mode` to filter layers without a quant_tensor input.') weight = self.layer.weight.data dev = weight.device dtype = weight.dtype @@ -328,7 +357,8 @@ def single_layer_update(self): T = get_upper_bound_on_l1_norm( torch.tensor(self.accumulator_bit_width), input_bit_width, input_is_signed) s = self.layer.quant_weight_scale() - s = s.view(self.groups, -1) # [Groups, OC/Groups] + if s.ndim > 1: + s = s.view(self.groups, -1) # [Groups, OC/Groups] # initialize cumulative l1-norm z = torch.zeros(weight.shape[:-1], device=dev) @@ -362,8 +392,8 @@ def single_layer_update(self): else: q_arg = torch.zeros_like(U[group_index, :, 0]) - max_q_arg = s[group_index, :] * torch.clamp_min(T - z[group_index, :], 0.) - q_arg = q_arg.sign() * torch.clamp_max(q_arg.abs(), max_q_arg) + max_q_arg = s * torch.clamp_min(T - z, 0.) + q_arg = q_arg.sign() * torch.clamp_max(q_arg.abs(), max_q_arg[group_index, :]) weight[group_index, :, permutation_list[group_index][t]] = q_arg q = self.get_quant_weights(t, 0, permutation_list) z += q.abs() / s # increment cumulative l1-norm diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index 28cb12cd6..9c466f0eb 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -3,9 +3,10 @@ from copy import deepcopy import math -from typing import List, Optional, Set +from typing import List, Optional import warnings +from packaging import version import torch try: @@ -14,6 +15,7 @@ LinAlgError = RuntimeError import unfoldNd +from brevitas import torch_version from brevitas.graph.gpxq import GPxQ from brevitas.graph.gpxq import gpxq_mode from brevitas.graph.gpxq import StopFwdException @@ -133,6 +135,8 @@ def __init__( dtype=torch.float32) self.nsamples = 0 + assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher" + def update_batch(self, module, input, current_layer): if self.disable_pre_forward_hook: return input @@ -188,6 +192,7 @@ def update_batch(self, module, input, current_layer): raise StopFwdException def single_layer_update(self, percdamp=.01): + assert not self.layer.weight_quant_requires_quant_input, "Error: GPTQ does not support weight quantizers that require quantized inputs." if hasattr(self.layer, 'allocate_params'): self.layer.allocate_params(self.layer) weight = self.layer.weight.data diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index e9641a5a8..dc9d6e19b 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -11,13 +11,13 @@ from typing import List, Optional, Set import warnings -import torch from torch.fx import GraphModule as TorchGraphModule from brevitas.fx import GraphModule +from brevitas.graph.calibrate import disable_return_quant_tensor from brevitas.graph.calibrate import DisableEnableQuantization +from brevitas.graph.calibrate import restore_return_quant_tensor import brevitas.nn as qnn -from brevitas.quant_tensor import QuantTensor SUPPORTED_CONV_OP = ( qnn.QuantConv2d, qnn.QuantConv1d, qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d) @@ -87,6 +87,7 @@ def __init__( # How many subblock to use during GPTQ for each layer self.disable_quant_inference = DisableEnableQuantization() + self.return_quant_tensor_state = dict() self.group_of_parallel_layers = group_of_parallel_layers self.return_forward_output = return_forward_output @@ -146,6 +147,7 @@ def __enter__(self): self.gpxq_layers[name] = gpxq_module_optimizer if not self.use_quant_activations: + self.return_quant_tensor_state = disable_return_quant_tensor(self.model) self.disable_quant_inference.disable_act_quantization( self.model, is_training=self.model.training) self.disable_quant_inference.disable_bias_quantization( @@ -165,6 +167,7 @@ def __exit__(self, type, value, traceback): self.model, is_training=self.model.training) self.disable_quant_inference.enable_bias_quantization( self.model, is_training=self.model.training) + restore_return_quant_tensor(self.model, self.return_quant_tensor_state) def update(self): for name in self.current_layer.layer_names: @@ -207,55 +210,11 @@ def __init__( self.disable_pre_forward_hook = False # Some layers require knowledge from quant inputs to compute quant weights self.quant_input = None - self.requires_quant_input = False # For GPFA2Q - - @property - def layer_requires_input_quant(self): - # some weight quantizers require a quant input (e.g., A2Q) - check_1 = self.layer.weight_quant_requires_quant_input - # if input_quant is enabled, then we will store its information - check_2 = self.layer.is_input_quant_enabled - # GPFA2Q requires the quantized input to be stored - check_3 = self.requires_quant_input - requires_input_quant = check_1 or check_2 or check_3 - return requires_input_quant def process_input(self, inp): # Input is a tuple, so we take first element inp = inp[0] - # if the quant_input is not already cached, then get - # metadata from QuantWBIOL module - if self.quant_input is None: - inp_scale = self.layer.quant_input_scale() - inp_zero_point = self.layer.quant_input_zero_point() - inp_bit_width = self.layer.quant_input_bit_width() - inp_signed = self.layer.is_quant_input_signed - inp_training = self.layer.training - - # If using quantized activations, inp could be QuantTensor. In - # this case, we overwrite the metadata. - if isinstance(inp, QuantTensor): - if self.layer_requires_input_quant and (self.quant_input is None): - inp_scale = inp.scale - inp_zero_point = inp.zero_point - inp_bit_width = inp.bit_width - inp_signed = inp.signed - inp_training = inp.training - inp = inp.value - - # if the layer requires an input quant and the quant input cache has - # yet to be populated, then populate with the collected metadata - if self.layer_requires_input_quant and (self.quant_input is None): - self.quant_input = QuantTensor( - value=torch.empty( - 1, dtype=self.layer.weight.dtype, device=self.layer.weight.device), - scale=inp_scale, - zero_point=inp_zero_point, - bit_width=inp_bit_width, - signed=inp_signed, - training=inp_training) - # If input is unbatched, add batch_size = 1 if len(inp.shape) == 1: warnings.warn("Found unbatched input, adding batch dimension equal to 1") diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 4750fc96d..985986789 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -11,6 +11,8 @@ from brevitas import torch_version from brevitas.graph.equalize import _cross_layer_equalization +import brevitas.nn as qnn +from brevitas.quant import Int8ActPerTensorFloat SEED = 123456 ATOL = 1e-3 @@ -26,6 +28,7 @@ IN_SIZE_CONV = (1, 3, 224, 224) IN_SIZE_LINEAR = (1, 224, 3) +IN_SIZE_CONV_SMALL = (1, 3, 32, 32) def equalize_test(regions, merge_bias, bias_shrinkage, scale_computation_type): @@ -374,3 +377,95 @@ def forward(self, x): ('layer1.0.conv1', 'layer1.1.conv1', 'layer2.0.conv1', 'layer2.0.downsample.0')], [('layer2.0.bn1',), ('layer2.0.conv2',)], [('layer4.0.bn2', 'layer4.0.downsample.1', 'layer4.1.bn2'), ('fc', 'layer4.1.conv1')],] + + +@pytest_cases.fixture +def quant_conv_with_input_quant_model(): + + class QuantConvModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.conv_0 = qnn.QuantConv2d( + 3, 16, kernel_size=3) # gpxq tests assume no quant on first layer + self.conv_1 = qnn.QuantConv2d(16, 32, kernel_size=3, input_quant=Int8ActPerTensorFloat) + + def forward(self, x): + x = self.conv_0(x) + x = torch.relu(x) + x = self.conv_1(x) + return x + + return QuantConvModel + + +@pytest_cases.fixture +def quant_convdepthconv_model(): + + class QuantConvDepthConvModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.conv = qnn.QuantConv2d(3, 16, kernel_size=3) + self.conv_0 = qnn.QuantConv2d(16, 16, kernel_size=1, groups=16) + self.relu = qnn.QuantReLU(return_quant_tensor=True) + + def forward(self, x): + x = self.conv(x) + x = self.relu(x) + x = self.conv_0(x) + return x + + return QuantConvDepthConvModel + + +@pytest_cases.fixture +def quant_residual_model(): + + class QuantResidualModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.conv = qnn.QuantConv2d(3, 16, kernel_size=1) + self.conv_0 = qnn.QuantConv2d(16, 3, kernel_size=1) + self.relu = qnn.QuantReLU(return_quant_tensor=True) + + def forward(self, x): + start = x + x = self.conv(x) + x = self.relu(x) + x = self.conv_0(x) + x = start + x + return x + + return QuantResidualModel + + +@pytest_cases.fixture +def quant_convtranspose_model(): + + class QuantConvTransposeModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.relu = qnn.QuantReLU(return_quant_tensor=True) + self.conv_0 = qnn.QuantConvTranspose2d(in_channels=3, out_channels=8, kernel_size=3) + self.conv_1 = qnn.QuantConvTranspose2d(in_channels=8, out_channels=32, kernel_size=3) + + def forward(self, x): + x = self.conv_0(x) + x = self.relu(x) + x = self.conv_1(x) + return x + + return QuantConvTransposeModel + + +list_of_quant_fixtures = [ + 'quant_conv_with_input_quant_model', + 'quant_convdepthconv_model', + 'quant_residual_model', + 'quant_convtranspose_model'] + +toy_quant_model = fixture_union( + 'toy_quant_model', list_of_quant_fixtures, ids=list_of_quant_fixtures) diff --git a/tests/brevitas/graph/test_gpxq.py b/tests/brevitas/graph/test_gpxq.py new file mode 100644 index 000000000..49d470402 --- /dev/null +++ b/tests/brevitas/graph/test_gpxq.py @@ -0,0 +1,149 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from functools import partial + +import pytest +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from torch.utils.data import TensorDataset + +from brevitas.graph.gpfq import gpfq_mode +from brevitas.graph.gptq import gptq_mode + +from .equalization_fixtures import * + + +def apply_gpfq( + calib_loader: DataLoader, + model: nn.Module, + act_order: bool, + use_quant_activations: bool = True, + accumulator_bit_width: int = 32, + a2q_layer_filter_fnc=lambda x: True): + model.eval() + dtype = next(model.parameters()).dtype + device = next(model.parameters()).device + with torch.no_grad(): + # use A2GPFQ if accumulator is less than 32 is specified + with gpfq_mode( + model, + use_quant_activations=use_quant_activations, + act_order=act_order, + use_gpfa2q=accumulator_bit_width < 32, + accumulator_bit_width=accumulator_bit_width, + a2q_layer_filter_fnc=a2q_layer_filter_fnc, + ) as gpfq: + gpfq_model = gpfq.model + for _ in range(gpfq.num_layers): + for _, (images, _) in enumerate(calib_loader): + images = images.to(device) + images = images.to(dtype) + gpfq_model(images) + gpfq.update() + + +def apply_gptq( + calib_loader: DataLoader, model: nn.Module, act_order: bool, use_quant_activations: bool): + model.eval() + dtype = next(model.parameters()).dtype + device = next(model.parameters()).device + with torch.no_grad(): + with gptq_mode( + model, + use_quant_activations=use_quant_activations, + act_order=act_order, + ) as gptq: + gptq_model = gptq.model + for _ in range(gptq.num_layers): + for _, (images, _) in enumerate(calib_loader): + images = images.to(device) + images = images.to(dtype) + gptq_model(images) + gptq.update() + + +def custom_layer_filter_fnc(layer: nn.Module) -> bool: + if isinstance(layer, nn.Conv2d) and layer.in_channels == 3: + return False + elif isinstance(layer, nn.ConvTranspose2d) and layer.in_channels == 3: + return False + return True + + +def identity_layer_filter_func(layer: nn.Module) -> bool: + return True + + +filter_func_dict = {"identity": identity_layer_filter_func, "ignore_input": custom_layer_filter_fnc} + +apply_gpxq_func_map = {"gpfq": apply_gpfq, "gptq": apply_gptq} + + +@pytest.mark.parametrize("act_order", [True, False]) +@pytest.mark.parametrize("use_quant_activations", [True, False]) +@pytest.mark.parametrize("acc_bit_width", [32, 24, 16, 12]) +@pytest.mark.parametrize("filter_func_str", filter_func_dict.keys()) +@pytest.mark.parametrize("apply_gpxq_tuple", apply_gpxq_func_map.items()) +def test_toymodels( + toy_quant_model, + act_order, + use_quant_activations, + acc_bit_width, + filter_func_str, + apply_gpxq_tuple, + request): + + test_id = request.node.callspec.id + + torch.manual_seed(SEED) + + name, apply_gpxq = apply_gpxq_tuple + + if (name == 'gptq' and acc_bit_width < 32): + pytest.skip("GPTQ does not support accumulator-aware quantization.") + + if name == 'gpfq': + filter_func = filter_func_dict[filter_func_str] + apply_gpxq = partial( + apply_gpxq, accumulator_bit_width=acc_bit_width, a2q_layer_filter_fnc=filter_func) + + model_class = toy_quant_model + model = model_class() + if 'mha' in test_id: + inp = torch.randn(32, *IN_SIZE_LINEAR[1:]) + else: + inp = torch.randn(32, *IN_SIZE_CONV_SMALL[1:]) + model.eval() + model(inp) # test forward pass and collect scaling factors + dataset = TensorDataset(inp, inp) + calib_loader = DataLoader(dataset, batch_size=16, num_workers=0, pin_memory=True, shuffle=True) + + if (name == 'gptq' and torch_version < version.parse('1.10')): + # GPTQ usage of linalg_cholesky() is not compatible with torch 1.9.1 and below + with pytest.raises(AssertionError): + apply_gpxq( + calib_loader=calib_loader, + model=model, + act_order=act_order, + use_quant_activations=use_quant_activations) + + elif (name == 'gpfq') and (acc_bit_width < 32) and (not use_quant_activations or + filter_func_str == 'identity'): + # GPFA2Q requires that the quant activations are used. GPFA2Q.single_layer_update will + # raise a ValueError if GPFA2Q.quant_input is None (also see GPxQ.process_input). This will + # happen when `use_quant_activations=False` or when the input to a model is not quantized + # and `a2q_layer_filter_fnc` does not properly handle it. + with pytest.raises(ValueError): + apply_gpxq( + calib_loader=calib_loader, + model=model, + act_order=act_order, + use_quant_activations=use_quant_activations) + else: + apply_gpxq( + calib_loader=calib_loader, + model=model, + act_order=act_order, + use_quant_activations=use_quant_activations)