Feat (mx): gptq compatibility and quant tests (#1013)
---------

Co-authored-by: Nick Fraser <icanlosh@gmail.com>
Giuseppe5 and nickfraser authored Sep 5, 2024
1 parent 9693ff5 commit d4834bd
Showing 19 changed files with 196 additions and 105 deletions.
24 changes: 12 additions & 12 deletions src/brevitas/core/function_wrapper/shape.py
@@ -16,6 +16,7 @@
from brevitas.function.shape import over_output_channels
from brevitas.function.shape import over_output_features
from brevitas.function.shape import over_tensor
+ from brevitas.utils.torch_utils import padding


class PermuteDims(brevitas.jit.ScriptModule):
@@ -154,17 +155,19 @@ def forward(self, x: torch.Tensor):


class OverSubChannelBlockView(brevitas.jit.ScriptModule):
- __constants__ = ['expanded_scaling_shape']
+ __constants__ = ['expanded_groupwise_shape', 'group_size', 'group_dim']

- def __init__(self, expanded_scaling_shape, padding) -> None:
+ def __init__(self, expanded_groupwise_shape, group_size, group_dim) -> None:
super(OverSubChannelBlockView, self).__init__()
- self.expanded_scaling_shape = expanded_scaling_shape
- self.padding = padding
+ self.expanded_groupwise_shape = expanded_groupwise_shape
+ self.group_dim = group_dim
+ self.group_size = group_size

@brevitas.jit.script_method
def forward(self, x: torch.Tensor):
- y = torch.nn.functional.pad(x, self.padding, mode='constant', value=0)
- y = y.view(self.expanded_scaling_shape)
+ y = torch.nn.functional.pad(
+     x, padding(x, self.group_size, self.group_dim), mode='constant', value=0.)
+ y = y.view(self.expanded_groupwise_shape)
return y


@@ -181,12 +184,9 @@ def forward(self, x):

tensor_shape = x.shape
tensor_shape_list = list(tensor_shape)
- padding = [0, 0] * len(tensor_shape_list)
- if tensor_shape_list[self.group_dim] % self.group_size != 0:
-     padding[2 * self.group_dim] = self.group_size - tensor_shape_list[
-         self.group_dim] % self.group_size
- padding = list(reversed(padding))
- x = torch.nn.functional.pad(x, padding, mode='constant', value=0)
+ pad = padding(x, self.group_size, self.group_dim)
+
+ x = torch.nn.functional.pad(x, pad, mode='constant', value=0.)

tensor_shape = x.shape
tensor_shape_list = list(tensor_shape)
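Both rewritten forward passes above now compute pad amounts through a shared helper and then view the padded tensor groupwise. A minimal self-contained sketch of the resulting behaviour, where pad_for_groups is a hypothetical stand-in for the new brevitas.utils.torch_utils.padding:

import torch
import torch.nn.functional as F

def pad_for_groups(x, group_size, group_dim):
    # One (left, right) pair per dim; F.pad consumes pairs innermost-dim first,
    # hence the final reversal.
    pad = [0, 0] * x.ndim
    rem = x.shape[group_dim] % group_size
    if rem != 0:
        pad[2 * group_dim] = group_size - rem
    return list(reversed(pad))

x = torch.randn(64, 50)  # e.g. an [OC, IC] weight, IC not a multiple of 32
y = F.pad(x, pad_for_groups(x, 32, 1), mode='constant', value=0.)
y = y.view(64, 2, 32)    # expanded groupwise shape: [OC, ceil(IC / 32), 32]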
2 changes: 1 addition & 1 deletion src/brevitas/function/ops.py
@@ -189,7 +189,7 @@ def min_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor:
return value


- @brevitas.jit.script
+ @brevitas.jit.ignore
def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor):
max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias
max_mantissa = torch.sum((
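The decorator swap keeps max_float usable in eager mode while exempting it from TorchScript compilation. For reference, the quantity it computes can be sketched as follows (a hedged reconstruction for illustration, not the verbatim Brevitas body):

import torch

def max_float_sketch(exponent_bit_width, mantissa_bit_width, exponent_bias):
    max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias
    # Largest mantissa: 1 + 1/2 + ... + 2 ** -mantissa_bit_width
    max_mantissa = torch.sum(2. ** (-torch.arange(0, mantissa_bit_width + 1)))
    return max_mantissa * (2. ** max_exponent)

# e4m3 with bias 7: (1 + 1/2 + 1/4 + 1/8) * 2 ** 8 = 480
print(max_float_sketch(torch.tensor(4.), 3, torch.tensor(7.)))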
44 changes: 29 additions & 15 deletions src/brevitas/graph/gpxq.py
@@ -11,14 +11,17 @@
from typing import List, Optional, Set
import warnings

+ import torch
from torch.fx import GraphModule as TorchGraphModule

from brevitas.fx import GraphModule
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import restore_return_quant_tensor
+ from brevitas.graph.utils import is_conv_transposed
import brevitas.nn as qnn
from brevitas.quant_tensor import IntQuantTensor
+ from brevitas.quant_tensor.base_quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIO

SUPPORTED_CONV_OP = (
@@ -194,26 +197,29 @@ def __init__(
self.layer = layer
self.name = name
self.act_order = act_order
+ if self.layer.weight_quant.is_groupwise:
+     weight = self.layer.weight_quant.apply_input_view(self.layer.weight)
+     weight = weight.view(self.layer.weight_quant.quant_injector.reshaped_groupwise_shape)
+     self.layer.weight.data = weight.data
+     self.layer.in_channels = weight.shape[1] if is_conv_transposed(
+         self.layer) else weight.shape[0]

- weight = layer.weight.data
+ weight_shape = torch.tensor(layer.weight.shape)

if create_weight_orig and not hasattr(self.layer, 'weight_orig'):
self.layer.register_buffer('weight_orig', layer.weight.detach().clone())

# By default, use groups = 1
self.groups = 1
if isinstance(self.layer, SUPPORTED_CONV_OP):
- if isinstance(
-     self.layer,
-     (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)):
-     weight = weight.transpose(1, 0)  # This performs a view
- weight = weight.flatten(1)
+ if is_conv_transposed(self.layer):
+     weight_shape[1], weight_shape[0] = weight_shape[0], weight_shape[1]
self.groups = self.layer.groups

# Number of rows is equal to the output channels (OC)
- self.rows = weight.shape[0]
+ self.rows = weight_shape[0]
# Number of columns is equal to the input channels (IC)
- self.columns = weight.shape[1]
+ self.columns = torch.prod(weight_shape[1:])
self.len_parallel_layers = len_parallel_layers

self.disable_pre_forward_hook = False
@@ -262,17 +268,25 @@ def get_quant_weights(self, i, i1, permutation_list):
# For QuantLinear and for some QuantConvolutional layers, we exploit the possibility
# of quantizing only a subset of the entire matrix, speeding up the computation of GPxQ
if isinstance(self.layer, qnn.QuantLinear):
- index = permutation_list[0][i]
- subtensor_slice_list = [None, (index, index + 1)]
- q = self.layer.quant_weight(
-     subtensor_slice_list=subtensor_slice_list,
-     quant_input=self.quant_metadata).value.unsqueeze(0)  # [1, OC, 1]
+ if self.layer.weight_quant.is_groupwise:
+     # No slicing, not optimized
+     index = permutation_list[0][i]
+     q = self.layer.quant_weight(quant_input=self.quant_metadata).value.unsqueeze(
+         0)  # [1, OC, 1]
+     q = q[:, :, i:i + 1]  # [groups, OC/groups, 1]
+ else:
+     index = permutation_list[0][i]
+     subtensor_slice_list = [None, (index, index + 1)]
+     q = self.layer.quant_weight(
+         subtensor_slice_list=subtensor_slice_list,
+         quant_input=self.quant_metadata).value.unsqueeze(0)  # [1, OC, 1]
elif isinstance(self.layer, SUPPORTED_CONV_OP):
# For depthwise and ConvTranspose we fall back to quantizing the entire matrix.
# For all other cases, we create a mask that represents the slicing we will perform on the weight matrix
# and we quantize only the selected dimensions.
- if self.groups > 1 or (self.groups == 1 and isinstance(
-     self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d))):
+ if self.layer.weight_quant.is_groupwise or self.groups > 1 or (
+     self.groups == 1 and
+     isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d))):

quant_weight = self.layer.quant_weight(quant_input=self.quant_metadata)
quant_weight = quant_weight.value
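In the groupwise branch above, the whole weight is quantized and a single column is sliced out afterwards, since the per-group scales are computed over the reshaped (and possibly padded) weight; as the comment notes, this path is not optimized. A toy illustration of the tensor bookkeeping, with hypothetical shapes:

import torch

q = torch.randn(64, 128).unsqueeze(0)  # stand-in for quant_weight(...).value, [1, OC, IC]
i = 5
col = q[:, :, i:i + 1]                 # [1, OC, 1]: the single column GPxQ consumes
print(col.shape)                       # torch.Size([1, 64, 1])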
6 changes: 3 additions & 3 deletions src/brevitas/graph/utils.py
@@ -26,13 +26,13 @@
'get_output_channels',
'get_output_channel_dim']

- CONV_TRANSPOSED = [
+ CONV_TRANSPOSED = (
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
qnn.QuantConvTranspose1d,
qnn.QuantConvTranspose2d,
- qnn.QuantConvTranspose3d]
+ qnn.QuantConvTranspose3d)


def module_class_name(m: torch.nn.Module):
@@ -146,7 +146,7 @@ def matches_module_pattern(pattern: Iterable, node: Node, modules: Dict[str, Any


def is_conv_transposed(module):
- return isinstance(module, tuple(CONV_TRANSPOSED))
+ return isinstance(module, CONV_TRANSPOSED)


def get_output_channel_dim(module):
2 changes: 2 additions & 0 deletions src/brevitas/jit.py
@@ -14,12 +14,14 @@ def _disabled(fn):

script_method = torch.jit.script_method
script = torch.jit.script
+ ignore = torch.jit.ignore
ScriptModule = torch.jit.ScriptModule
Attribute = torch.jit.Attribute

else:

script_method = _disabled
script = _disabled
+ ignore = _disabled
ScriptModule = torch.nn.Module
Attribute = lambda val, type: val
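The new ignore alias follows the module's existing dispatch: the real torch.jit.ignore when JIT is enabled, a no-op passthrough otherwise. A minimal sketch of the pattern, with JIT_ENABLED standing in for the config flag that gates the real module:

import torch

JIT_ENABLED = False  # stand-in for brevitas' config flag

def _disabled(fn):
    return fn

ignore = torch.jit.ignore if JIT_ENABLED else _disabled

@ignore
def not_scriptable(x):
    return x  # excluded from TorchScript compilation when JIT is on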
5 changes: 5 additions & 0 deletions src/brevitas/proxy/groupwise_float_parameter_quant.py
@@ -22,6 +22,11 @@ def group_dim(self):
def group_size(self):
return self.quant_injector.group_size

+ def apply_input_view(self, x):
+     x = super().apply_input_view(x)
+     start_dim = self.group_dim if self.group_dim != -1 else -2
+     return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor:
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
return GroupwiseFloatQuantTensor(
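The same apply_input_view override recurs in the three sibling groupwise proxies below; it folds the (num_groups, group_size) axis pair introduced by the groupwise view back into a single dimension. A quick illustration:

import torch

x = torch.randn(64, 2, 32)  # expanded groupwise view, group_dim = 1
start_dim = 1               # group_dim, or -2 when group_dim == -1
print(x.flatten(start_dim, start_dim + 1).shape)  # torch.Size([64, 64])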
5 changes: 5 additions & 0 deletions src/brevitas/proxy/groupwise_float_runtime_quant.py
@@ -21,6 +21,11 @@ def group_dim(self):
def group_size(self):
return self.quant_injector.group_size

+ def apply_input_view(self, x):
+     x = super().apply_input_view(x)
+     start_dim = self.group_dim if self.group_dim != -1 else -2
+     return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(
self,
qt_args: Union[torch.Tensor, Tuple[Any]],
5 changes: 5 additions & 0 deletions src/brevitas/proxy/groupwise_int_parameter_quant.py
@@ -22,6 +22,11 @@ def group_dim(self):
def group_size(self):
return self.quant_injector.group_size

+ def apply_input_view(self, x):
+     x = super().apply_input_view(x)
+     start_dim = self.group_dim if self.group_dim != -1 else -2
+     return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseIntQuantTensor:
out, scale, zero_point, bit_width = qt_args
return GroupwiseIntQuantTensor(
5 changes: 5 additions & 0 deletions src/brevitas/proxy/groupwise_int_runtime_quant.py
@@ -21,6 +21,11 @@ def group_dim(self):
def group_size(self):
return self.quant_injector.group_size

+ def apply_input_view(self, x):
+     x = super().apply_input_view(x)
+     start_dim = self.group_dim if self.group_dim != -1 else -2
+     return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(
self,
qt_args: Union[torch.Tensor, Tuple[Any]],
2 changes: 1 addition & 1 deletion src/brevitas/proxy/parameter_quant.py
@@ -132,7 +132,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
self._cached_weight = self.cache_class(
out.detach(), metadata_only=self.cache_inference_quant_weight_metadata_only)
else: # quantization disabled
- out = x
+ out = self.apply_input_view(x)
return out


9 changes: 5 additions & 4 deletions src/brevitas/proxy/quant_proxy.py
@@ -11,6 +11,7 @@

from brevitas import config
from brevitas.common import ExportMixin
+ from brevitas.core.scaling import ScalingPerOutputType
from brevitas.core.utils import StatelessBuffer
from brevitas.inject import BaseInjector as Injector
from brevitas.utils.quant_utils import float_to_int_impl_to_enum
@@ -21,10 +22,7 @@


def _is_groupwise(quant_injector):
- if 'group_size' in quant_injector:
-     return True
- else:
-     return False
+ return 'scaling_per_output' in quant_injector and quant_injector.scaling_per_output == ScalingPerOutputType.GROUP


def _is_narrow_range(quant_injector):
@@ -123,6 +121,9 @@ def add_tracked_module(self, module: nn.Module) -> None:
else:
raise RuntimeError("Trying to add None as a parent module.")

+ def apply_input_view(self, x):
+     return self.quant_injector.input_view_impl(x)

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
4 changes: 3 additions & 1 deletion src/brevitas/proxy/runtime_quant.py
@@ -168,7 +168,9 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
elif not self.is_quant_enabled:
# A tuple helps later with control flows
# The second None value is used later
- y = (self.fused_activation_quant_proxy.activation_impl(y), None)
+ # If quant is not enabled, we still apply input_view in the case of groupwise + padding
+ y = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y))
+ y = (y, None)
else:
y = self.fused_activation_quant_proxy(y)
# If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
54 changes: 29 additions & 25 deletions src/brevitas/quant/solver/parameter.py
@@ -111,45 +111,49 @@ def scaling_impl(scaling_impl_type):
class SolveParameterScalingShape(ExtendedInjector):

@value
- def scaling_shape(module, group_dim, group_size=None, scaling_per_output=None):
+ def scaling_shape(scaling_per_output, expanded_groupwise_shape=None, group_dim=None):
if scaling_per_output == ScalingPerOutputType.TENSOR:
return SCALAR_SHAPE
elif scaling_per_output == ScalingPerOutputType.CHANNEL:
return this.scaling_per_output_channel_shape
elif scaling_per_output == ScalingPerOutputType.GROUP:
- assert group_size is not None, "Per Group scaling requires group size"
- assert group_dim is not None, "Per Group scaling requires group dim"
- size = list(module.weight.shape)
- size[group_dim] = (size[group_dim] + group_size - 1) // group_size
- size.insert(group_dim + 1, 1)
- return size
+ # Scaling shape is like expanded_groupwise_shape but has 1 in position group_dim + 1
+ assert expanded_groupwise_shape is not None, "Per Group scaling not correctly configured"
+ assert group_dim is not None, "Per Group scaling not correctly configured"
+ size = list(expanded_groupwise_shape)
+ size[group_dim + 1] = 1
+ return tuple(size)

@value
- def reshaped_scaling_shape(module):
-     return module.weight.shape
+ def reshaped_groupwise_shape(expanded_groupwise_shape, group_dim, group_size):
+     new_shape = list(expanded_groupwise_shape)
+     del new_shape[group_dim + 1]  # delete the group_size shape
+     # Expand the group_dim shape, accounting for padding
+     new_shape[group_dim] = new_shape[group_dim] * group_size
+     return new_shape

@value
- def expanded_scaling_shape(module, group_dim, group_size=None):
-     assert group_size is not None, "Per Group scaling requires group size"
-     size = list(module.weight.shape)
+ def expanded_groupwise_shape(tracked_parameter_list, group_dim, group_size=None):
+     # expanded_groupwise_shape will always be called to create scaling_shape, but it is only needed
+     # for groupwise quantization. All other groupwise shape infos are derived from this.
+
+     # If conditions do not allow for groupwise quantization, early exit and return None
+     if group_size is None:
+         return
+
+     # If group_size is specified and shared quantization is used, raise an error.
+     assert len(tracked_parameter_list) == 1, "Shared groupwise quantization is not currently supported"
+
+     weight_shape = tracked_parameter_list[0].shape
+     size = list(weight_shape)
      size[group_dim] = (size[group_dim] + group_size - 1) // group_size
      size.insert(group_dim + 1, group_size)
-     return size
-
- @value
- def padding(module, group_dim, group_size):
-     padding = [0, 0] * len(module.weight.shape)
-     size = list(module.weight.shape)
-     if size[group_dim] % group_size != 0:
-         # Padding is done on the left side
-         padding[2 * group_dim] = group_size - size[group_dim] % group_size
-     # Padding takes a list of 2 values per dim in reverse order (N_DIM, N_DIM-1,...,0)
-     # so we need to reverse the order
-     padding = list(reversed(padding))
-     return padding
+     return tuple(size)

@value
def group_dim(module, group_size=None):
+ # group_dim will always be called to create scaling_shape, but it is only needed
+ # for groupwise quantization.
if group_size is not None:
return 1 if not hasattr(module, 'transposed') or not module.transposed else 0

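A worked example of the shape arithmetic above, assuming a (64, 50) weight with group_size=32 and group_dim=1: expanded_groupwise_shape ceil-divides the grouped dim and inserts the group axis, scaling_shape collapses that axis to 1, and reshaped_groupwise_shape folds it back, padding included:

group_dim, group_size = 1, 32

def expanded(weight_shape):
    size = list(weight_shape)
    size[group_dim] = (size[group_dim] + group_size - 1) // group_size  # ceil-div
    size.insert(group_dim + 1, group_size)
    return tuple(size)

exp = expanded((64, 50))
print(exp)            # (64, 2, 32)
scale = list(exp)
scale[group_dim + 1] = 1
print(tuple(scale))   # (64, 2, 1) -- one scale per 32-wide group
reshaped = list(exp)
del reshaped[group_dim + 1]
reshaped[group_dim] *= group_size
print(reshaped)       # [64, 64] -- the padded weight, groups folded back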
13 changes: 12 additions & 1 deletion src/brevitas/utils/torch_utils.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import copy
- from typing import Optional, Tuple
+ from typing import List, Optional, Tuple

import torch
from torch.nn import Sequential
@@ -102,3 +102,14 @@ def float_internal_scale(
internal_scale = torch.clamp_min(internal_scale, fp_internal_scale_min)
internal_scale = torch.exp2(internal_scale)
return internal_scale


+ @brevitas.jit.ignore
+ def padding(x: torch.Tensor, group_size: int, group_dim: int) -> List[int]:
+     # Given a tensor x, compute the padding along group_dim so that groupwise shaping is possible
+     padding = [0, 0] * len(x.shape)
+     size = x.shape
+     if size[group_dim] % group_size != 0:
+         padding[2 * group_dim] = group_size - size[group_dim] % group_size
+     padding = list(reversed(padding))
+     return padding
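Example of the new helper on a tensor whose grouped dim is not a multiple of group_size:

import torch

x = torch.randn(64, 50)
print(padding(x, group_size=32, group_dim=1))  # [0, 14, 0, 0]
# F.pad consumes pairs innermost-dim first, so this pads dim 1 on the right by 14.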
(5 more changed files not shown)
