Fix (gpxq): handling empty tensors with GPxQ and adding unit tests (#892)
i-colbert authored Mar 7, 2024
1 parent 4e82c7b commit 5c1932b
Showing 5 changed files with 293 additions and 55 deletions.
46 changes: 38 additions & 8 deletions src/brevitas/graph/gpfq.py
@@ -10,11 +10,14 @@
import unfoldNd

from brevitas.function import get_upper_bound_on_l1_norm
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import restore_return_quant_tensor
from brevitas.graph.gpxq import GPxQ
from brevitas.graph.gpxq import gpxq_mode
from brevitas.graph.gpxq import StopFwdException
from brevitas.graph.gpxq import SUPPORTED_CONV_OP
import brevitas.nn as qnn
from brevitas.quant_tensor import QuantTensor


class gpfq_mode(gpxq_mode):
@@ -89,6 +92,7 @@ def catch_stopfwd(self, *args, **kwargs):
pass

# Disable quantization
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
# Collect float input
@@ -104,6 +108,7 @@ def catch_stopfwd(self, *args, **kwargs):
self.disable_quant_inference.enable_act_quantization(self.model, is_training=False)
else:
self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

if self.return_forward_output:
# If we want to return the output of the network, we need to disable all hooks
@@ -155,7 +160,7 @@ def update_batch(self, module, input, current_layer):

# Update reference to current layer
current_layer.layer_names.add(self.name)
is_quant_disabled = module.weight_quant.disable_quant
is_quant_enabled = module.weight_quant.is_quant_enabled

inp = self.process_input(input)
batch_size = inp.shape[0]
@@ -210,7 +215,7 @@ def update_batch(self, module, input, current_layer):
inp_processed.append(inp)
inp_processed = torch.stack(inp_processed)

if is_quant_disabled:
if not is_quant_enabled:
if self.float_input is None:
self.float_input = inp_processed
else:
@@ -229,6 +234,7 @@ def update_batch(self, module, input, current_layer):
raise StopFwdException

def single_layer_update(self):
assert not self.layer.weight_quant_requires_quant_input, "Error: GPFQ does not support weight quantizers that require quantized inputs."
weight = self.layer.weight.data
dev = weight.device
dtype = weight.dtype
@@ -302,13 +308,36 @@ def __init__(
p=p)
self.accumulator_bit_width = accumulator_bit_width
assert self.accumulator_bit_width is not None
self.requires_quant_input = True # force true

def process_input(self, inp):
inp = super().process_input(inp)
inp = self.layer.input_quant(inp)

is_quant_enabled = self.layer.weight_quant.is_quant_enabled

# If using quantized activations, inp could be QuantTensor. In
# this case, we overwrite the metadata.
if isinstance(inp, QuantTensor):
if is_quant_enabled and self.quant_input is None:
self.quant_input = QuantTensor(
value=torch.empty(
1, dtype=self.layer.weight.dtype, device=self.layer.weight.device),
scale=inp.scale,
zero_point=inp.zero_point,
bit_width=inp.bit_width,
signed=inp.signed,
training=inp.training)
inp = inp.value

return inp

def single_layer_update(self):
# raise error in case no quant-input is here
if self.quant_input is None:
raise ValueError(
'Expected quant input to calculate L1-norm upper bound, but received None')
raise ValueError('Expected self.quant_input to calculate L1-norm upper bound, but received None. ' + \
'Make sure that either the input to the model is a QuantTensor or the layer has an input quant enabled. ' \
'Also, check if `use_quant_activations=True` in `gpfq_mode` when `accumulator_bit_width` is specified. ' + \
'Alternatively, provide a custom `a2q_layer_filter_fnc` to `gpfq_mode` to filter layers without a quant_tensor input.')
weight = self.layer.weight.data
dev = weight.device
dtype = weight.dtype
@@ -328,7 +357,8 @@ def single_layer_update(self):
T = get_upper_bound_on_l1_norm(
torch.tensor(self.accumulator_bit_width), input_bit_width, input_is_signed)
s = self.layer.quant_weight_scale()
s = s.view(self.groups, -1) # [Groups, OC/Groups]
if s.ndim > 1:
s = s.view(self.groups, -1) # [Groups, OC/Groups]

# initialize cumulative l1-norm
z = torch.zeros(weight.shape[:-1], device=dev)
@@ -362,8 +392,8 @@ def single_layer_update(self):
else:
q_arg = torch.zeros_like(U[group_index, :, 0])

max_q_arg = s[group_index, :] * torch.clamp_min(T - z[group_index, :], 0.)
q_arg = q_arg.sign() * torch.clamp_max(q_arg.abs(), max_q_arg)
max_q_arg = s * torch.clamp_min(T - z, 0.)
q_arg = q_arg.sign() * torch.clamp_max(q_arg.abs(), max_q_arg[group_index, :])
weight[group_index, :, permutation_list[group_index][t]] = q_arg
q = self.get_quant_weights(t, 0, permutation_list)
z += q.abs() / s # increment cumulative l1-norm
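For intuition, the clamp in this hunk enforces the accumulator-aware (A2Q) budget per output channel: |q_arg| <= s * max(T - z, 0), where T is the L1-norm bound derived from the accumulator bit width, s the weight scale, and z the L1 norm accumulated over already-quantized columns. A small numeric sketch with made-up values:

```python
import torch

# Illustrative values only: T is the L1-norm bound from the accumulator bit width,
# s the per-channel weight scales, z the L1 norm accumulated so far.
T = torch.tensor(255.)
s = torch.tensor([[0.01, 0.02]])      # [groups=1, OC/groups=2]
z = torch.tensor([[250.0, 100.0]])
q_arg = torch.tensor([[0.30, -0.50]])

max_q_arg = s * torch.clamp_min(T - z, 0.)                      # remaining float-domain budget
q_arg = q_arg.sign() * torch.clamp_max(q_arg.abs(), max_q_arg)
print(q_arg)  # tensor([[ 0.0500, -0.5000]]): channel 0 is clipped, channel 1 fits the budget
```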
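For context, a minimal sketch of the calibration loop that drives gpfq_mode and exercises the quantization/return_quant_tensor toggling changed above. It follows the gpfq.model / gpfq.num_layers / gpfq.update() driver loop used in the Brevitas PTQ examples; the toy model and random calibration data are illustrative assumptions, not part of this commit:

```python
import torch
import torch.nn as nn
import brevitas.nn as qnn
from brevitas.graph.gpfq import gpfq_mode

# Toy two-layer quant model and fake calibration data, for illustration only.
model = nn.Sequential(
    qnn.QuantConv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    qnn.QuantConv2d(8, 8, kernel_size=3))
model.eval()
calib_data = [torch.randn(1, 3, 32, 32) for _ in range(4)]

with torch.no_grad():
    with gpfq_mode(model, use_quant_activations=False) as gpfq:
        gpfq_model = gpfq.model
        for _ in range(gpfq.num_layers):        # one outer pass per (group of) layers
            for images in calib_data:
                gpfq_model(images)              # catch_stopfwd collects float and quant inputs
            gpfq.update()                       # quantize the current layer(s)
# On exit, act/bias quantization and each layer's return_quant_tensor flag are restored.
```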
7 changes: 6 additions & 1 deletion src/brevitas/graph/gptq.py
@@ -3,9 +3,10 @@

from copy import deepcopy
import math
from typing import List, Optional, Set
from typing import List, Optional
import warnings

from packaging import version
import torch

try:
@@ -14,6 +15,7 @@
LinAlgError = RuntimeError
import unfoldNd

from brevitas import torch_version
from brevitas.graph.gpxq import GPxQ
from brevitas.graph.gpxq import gpxq_mode
from brevitas.graph.gpxq import StopFwdException
@@ -133,6 +135,8 @@ def __init__(
dtype=torch.float32)
self.nsamples = 0

assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher"

def update_batch(self, module, input, current_layer):
if self.disable_pre_forward_hook:
return input
@@ -188,6 +192,7 @@ def update_batch(self, module, input, current_layer):
raise StopFwdException

def single_layer_update(self, percdamp=.01):
assert not self.layer.weight_quant_requires_quant_input, "Error: GPTQ does not support weight quantizers that require quantized inputs."
if hasattr(self.layer, 'allocate_params'):
self.layer.allocate_params(self.layer)
weight = self.layer.weight.data
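The new assertion means GPTQ, like GPFQ above, refuses layers whose weight quantizer requires a quantized input (e.g., accumulator-aware A2Q weight quantizers). A hedged sketch of how a caller might pre-screen a model for such layers before entering gptq_mode; the helper name is hypothetical and only reuses the attribute checked in this diff:

```python
import torch.nn as nn

def find_layers_requiring_quant_input(model: nn.Module) -> list:
    """Hypothetical helper: names of modules whose weight quantizer requires a
    quantized input and would therefore trip the new GPTQ/GPFQ assertions."""
    return [
        name for name, module in model.named_modules()
        if getattr(module, 'weight_quant_requires_quant_input', False)]
```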
51 changes: 5 additions & 46 deletions src/brevitas/graph/gpxq.py
@@ -11,13 +11,13 @@
from typing import List, Optional, Set
import warnings

import torch
from torch.fx import GraphModule as TorchGraphModule

from brevitas.fx import GraphModule
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import restore_return_quant_tensor
import brevitas.nn as qnn
from brevitas.quant_tensor import QuantTensor

SUPPORTED_CONV_OP = (
qnn.QuantConv2d, qnn.QuantConv1d, qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)
@@ -87,6 +87,7 @@ def __init__(
# How many subblock to use during GPTQ for each layer

self.disable_quant_inference = DisableEnableQuantization()
self.return_quant_tensor_state = dict()

self.group_of_parallel_layers = group_of_parallel_layers
self.return_forward_output = return_forward_output
@@ -146,6 +147,7 @@ def __enter__(self):
self.gpxq_layers[name] = gpxq_module_optimizer

if not self.use_quant_activations:
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.disable_act_quantization(
self.model, is_training=self.model.training)
self.disable_quant_inference.disable_bias_quantization(
@@ -165,6 +167,7 @@ def __exit__(self, type, value, traceback):
self.model, is_training=self.model.training)
self.disable_quant_inference.enable_bias_quantization(
self.model, is_training=self.model.training)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

def update(self):
for name in self.current_layer.layer_names:
@@ -207,55 +210,11 @@ def __init__(
self.disable_pre_forward_hook = False
# Some layers require knowledge from quant inputs to compute quant weights
self.quant_input = None
self.requires_quant_input = False # For GPFA2Q

@property
def layer_requires_input_quant(self):
# some weight quantizers require a quant input (e.g., A2Q)
check_1 = self.layer.weight_quant_requires_quant_input
# if input_quant is enabled, then we will store its information
check_2 = self.layer.is_input_quant_enabled
# GPFA2Q requires the quantized input to be stored
check_3 = self.requires_quant_input
requires_input_quant = check_1 or check_2 or check_3
return requires_input_quant

def process_input(self, inp):
# Input is a tuple, so we take first element
inp = inp[0]

# if the quant_input is not already cached, then get
# metadata from QuantWBIOL module
if self.quant_input is None:
inp_scale = self.layer.quant_input_scale()
inp_zero_point = self.layer.quant_input_zero_point()
inp_bit_width = self.layer.quant_input_bit_width()
inp_signed = self.layer.is_quant_input_signed
inp_training = self.layer.training

# If using quantized activations, inp could be QuantTensor. In
# this case, we overwrite the metadata.
if isinstance(inp, QuantTensor):
if self.layer_requires_input_quant and (self.quant_input is None):
inp_scale = inp.scale
inp_zero_point = inp.zero_point
inp_bit_width = inp.bit_width
inp_signed = inp.signed
inp_training = inp.training
inp = inp.value

# if the layer requires an input quant and the quant input cache has
# yet to be populated, then populate with the collected metadata
if self.layer_requires_input_quant and (self.quant_input is None):
self.quant_input = QuantTensor(
value=torch.empty(
1, dtype=self.layer.weight.dtype, device=self.layer.weight.device),
scale=inp_scale,
zero_point=inp_zero_point,
bit_width=inp_bit_width,
signed=inp_signed,
training=inp_training)

# If input is unbatched, add batch_size = 1
if len(inp.shape) == 1:
warnings.warn("Found unbatched input, adding batch dimension equal to 1")
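The return_quant_tensor handling added to __enter__/__exit__ follows a snapshot-and-restore pattern around the float passes. A minimal sketch of that pattern in isolation, assuming the calibrate helpers behave as they are used in this diff; the model below is illustrative:

```python
import torch
import torch.nn as nn
import brevitas.nn as qnn
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import restore_return_quant_tensor

# Illustrative model: the QuantReLU would normally propagate QuantTensors downstream.
model = nn.Sequential(
    qnn.QuantConv2d(3, 8, kernel_size=3),
    qnn.QuantReLU(return_quant_tensor=True),
    qnn.QuantConv2d(8, 8, kernel_size=3))

state = disable_return_quant_tensor(model)       # layers now emit plain torch.Tensor
try:
    model(torch.randn(1, 3, 32, 32))             # e.g., the float-collection forward pass
finally:
    restore_return_quant_tensor(model, state)    # put the original flags back
```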
95 changes: 95 additions & 0 deletions tests/brevitas/graph/equalization_fixtures.py
@@ -11,6 +11,8 @@

from brevitas import torch_version
from brevitas.graph.equalize import _cross_layer_equalization
import brevitas.nn as qnn
from brevitas.quant import Int8ActPerTensorFloat

SEED = 123456
ATOL = 1e-3
@@ -26,6 +28,7 @@

IN_SIZE_CONV = (1, 3, 224, 224)
IN_SIZE_LINEAR = (1, 224, 3)
IN_SIZE_CONV_SMALL = (1, 3, 32, 32)


def equalize_test(regions, merge_bias, bias_shrinkage, scale_computation_type):
Expand Down Expand Up @@ -374,3 +377,95 @@ def forward(self, x):
('layer1.0.conv1', 'layer1.1.conv1', 'layer2.0.conv1', 'layer2.0.downsample.0')],
[('layer2.0.bn1',), ('layer2.0.conv2',)],
[('layer4.0.bn2', 'layer4.0.downsample.1', 'layer4.1.bn2'), ('fc', 'layer4.1.conv1')],]


@pytest_cases.fixture
def quant_conv_with_input_quant_model():

class QuantConvModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.conv_0 = qnn.QuantConv2d(
3, 16, kernel_size=3) # gpxq tests assume no quant on first layer
self.conv_1 = qnn.QuantConv2d(16, 32, kernel_size=3, input_quant=Int8ActPerTensorFloat)

def forward(self, x):
x = self.conv_0(x)
x = torch.relu(x)
x = self.conv_1(x)
return x

return QuantConvModel


@pytest_cases.fixture
def quant_convdepthconv_model():

class QuantConvDepthConvModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.conv = qnn.QuantConv2d(3, 16, kernel_size=3)
self.conv_0 = qnn.QuantConv2d(16, 16, kernel_size=1, groups=16)
self.relu = qnn.QuantReLU(return_quant_tensor=True)

def forward(self, x):
x = self.conv(x)
x = self.relu(x)
x = self.conv_0(x)
return x

return QuantConvDepthConvModel


@pytest_cases.fixture
def quant_residual_model():

class QuantResidualModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.conv = qnn.QuantConv2d(3, 16, kernel_size=1)
self.conv_0 = qnn.QuantConv2d(16, 3, kernel_size=1)
self.relu = qnn.QuantReLU(return_quant_tensor=True)

def forward(self, x):
start = x
x = self.conv(x)
x = self.relu(x)
x = self.conv_0(x)
x = start + x
return x

return QuantResidualModel


@pytest_cases.fixture
def quant_convtranspose_model():

class QuantConvTransposeModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.relu = qnn.QuantReLU(return_quant_tensor=True)
self.conv_0 = qnn.QuantConvTranspose2d(in_channels=3, out_channels=8, kernel_size=3)
self.conv_1 = qnn.QuantConvTranspose2d(in_channels=8, out_channels=32, kernel_size=3)

def forward(self, x):
x = self.conv_0(x)
x = self.relu(x)
x = self.conv_1(x)
return x

return QuantConvTransposeModel


list_of_quant_fixtures = [
'quant_conv_with_input_quant_model',
'quant_convdepthconv_model',
'quant_residual_model',
'quant_convtranspose_model']

toy_quant_model = fixture_union(
'toy_quant_model', list_of_quant_fixtures, ids=list_of_quant_fixtures)
(Diff for the fifth changed file did not load on this page.)
