Fix (mx): input view during quantization #1005

Merged · 16 commits · Aug 24, 2024
6 changes: 3 additions & 3 deletions notebooks/minifloat_mx_tutorial.ipynb
@@ -233,20 +233,20 @@
"import brevitas.nn as qnn\n",
"import torch\n",
"\n",
"class MXFloat8Weight(MXInt8Weight):\n",
"class MXInt8Weight(MXInt8Weight):\n",
" # The group dimension for the weights it is automatically identified based on the layer type\n",
" # If a new layer type is used, it can be manually specified\n",
" bit_width = 8\n",
"\n",
"class MXFloat8Act(MXInt8Act):\n",
"class MXInt8Act(MXInt8Act):\n",
" # It is necessary to specify the group dimension for the activation quantization\n",
" group_dim = 1\n",
" bit_width = 8\n",
"\n",
"class MXModel(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n",
" self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXInt8Weight, input_quant=MXInt8Act)\n",
" \n",
" def forward(self, x):\n",
" return self.conv(x)\n",
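As a quick sanity check on the renamed quantizers, a minimal usage sketch (assumption: run inside the tutorial notebook right after the cell above, so torch, MXModel and the MX quantizer subclasses are already defined):

model = MXModel()
x = torch.randn(1, 32, 8, 8)   # matches QuantConv2d(32, 64, 3): 32 input channels
y = model(x)                   # weights and input are MX-quantized in 8-bit groups
print(y.shape)                 # expected: torch.Size([1, 64, 6, 6])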
20 changes: 20 additions & 0 deletions src/brevitas/core/function_wrapper/shape.py
@@ -171,6 +171,25 @@ def forward(self, x: torch.Tensor):
return y


class DynamicOverSubChannelBlockView(brevitas.jit.ScriptModule):
__constants__ = ['group_size', 'group_dim']

def __init__(self, group_size, group_dim) -> None:
super(DynamicOverSubChannelBlockView, self).__init__()
self.group_size = group_size
self.group_dim = group_dim

@brevitas.jit.script_method
def forward(self, x):
tensor_shape = x.shape
tensor_shape_list = list(tensor_shape)
tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size)
block_dim = self.group_dim + 1 if self.group_dim != -1 else -1
tensor_shape_list.insert(block_dim, self.group_size)
x = x.view(tensor_shape_list)
return x


class StatsInputViewShapeImpl(object):
"""
Enum-like object to collect pointers to variants of ScriptModules that perform a view on a tensor.
@@ -182,3 +201,4 @@ class StatsInputViewShapeImpl(object):
OVER_BATCH_OVER_OUTPUT_CHANNELS = OverBatchOverOutputChannelView
OVER_OUTPUT_FEATURES = OverOutputFeaturesView
OVER_SUBCHANNEL_BLOCK = OverSubChannelBlockView
DYNAMIC_OVER_SUBCHANNEL_BLOCK = DynamicOverSubChannelBlockView
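The new DynamicOverSubChannelBlockView computes the grouped view from the shape of the incoming tensor at call time (hence "dynamic") and is registered under DYNAMIC_OVER_SUBCHANNEL_BLOCK. A short sketch of its effect, assuming a Brevitas build that includes this PR:

import torch
from brevitas.core.function_wrapper.shape import DynamicOverSubChannelBlockView

view = DynamicOverSubChannelBlockView(group_size=4, group_dim=1)
x = torch.randn(2, 8, 5, 5)   # 8 channels -> 8 / 4 = 2 groups of 4
print(view(x).shape)          # torch.Size([2, 2, 4, 5, 5])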
4 changes: 3 additions & 1 deletion src/brevitas/core/quant/float.py
@@ -24,6 +24,7 @@ def __init__(
mantissa_bit_width: int,
exponent_bias: int,
float_clamp_impl: nn.Module,
input_view_impl: nn.Module,
scaling_impl: Optional[nn.Module] = None,
float_scaling_impl: Optional[nn.Module] = None,
float_to_int_impl: nn.Module = RoundSte(),
@@ -52,6 +53,7 @@ def __init__(
if scaling_impl is None:
scaling_impl = ConstScaling(1., device=device, dtype=dtype)

self.input_view_impl = input_view_impl
# Zero-point is currently hardcoded to 0
self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype))
self.float_scaling_impl = float_scaling_impl
@@ -71,7 +73,7 @@ def quantize(self, x: torch.Tensor):
float_scaling_impl_value = self.float_scaling_impl(
self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
scale = scale / float_scaling_impl_value

x = self.input_view_impl(x)
scaled_x = x / scale
internal_scale = float_internal_scale(
scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps)
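For groupwise (MX) quantization the scale produced by the scaling implementation already has the grouped shape, so the input must be reshaped into the same grouped layout before the division; that is what the new input_view_impl call provides. A shape-only sketch with illustrative sizes (not taken from the PR):

import torch

x = torch.randn(2, 8, 5, 5)                        # ungrouped activation, 8 channels
x_grouped = x.view(2, 2, 4, 5, 5)                  # what a group_size=4, group_dim=1 view produces
scale = x_grouped.abs().amax(dim=2, keepdim=True)  # one scale per group: shape (2, 2, 1, 5, 5)
scaled_x = x_grouped / scale                       # broadcasts correctly only against the grouped view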
8 changes: 7 additions & 1 deletion src/brevitas/core/quant/int.py
@@ -135,6 +135,7 @@ class RescalingIntQuant(brevitas.jit.ScriptModule):
def __init__(
self,
int_quant: Module,
input_view_impl: Module,
scaling_impl: Module,
int_scaling_impl: Module,
zero_point_impl: Module,
@@ -145,6 +146,7 @@ def __init__(
self.int_scaling_impl = int_scaling_impl
self.zero_point_impl = zero_point_impl
self.msb_clamp_bit_width_impl = bit_width_impl
self.input_view_impl = input_view_impl

@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
@@ -153,6 +155,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
int_threshold = self.int_scaling_impl(bit_width)
scale = threshold / int_threshold
zero_point = self.zero_point_impl(x, scale, bit_width)
x = self.input_view_impl(x)
y = self.int_quant(scale, zero_point, bit_width, x)
return y, scale, zero_point, bit_width

@@ -167,7 +170,8 @@ def __init__(
int_scaling_impl: Module,
pre_zero_point_impl: Module,
zero_point_impl: Module,
bit_width_impl: Module):
bit_width_impl: Module,
input_view_impl: Module):
super(DecoupledRescalingIntQuant, self).__init__()
self.decoupled_int_quant = decoupled_int_quant
self.pre_scaling_impl = pre_scaling_impl
@@ -176,6 +180,7 @@ def __init__(
self.pre_zero_point_impl = pre_zero_point_impl
self.zero_point_impl = zero_point_impl
self.msb_clamp_bit_width_impl = bit_width_impl
self.input_view_impl = input_view_impl

@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
@@ -184,6 +189,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
pre_threshold = self.pre_scaling_impl(x)
pre_scale = pre_threshold / int_threshold
pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
x = self.input_view_impl(x)
threshold = self.scaling_impl(x)
scale = threshold / int_threshold
zero_point = self.zero_point_impl(x, scale, bit_width)
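For existing per-tensor and per-channel quantizers the solvers below inject a plain Identity as input_view_impl, so the extra call is a no-op and behaviour is unchanged. A small sanity sketch using the Identity module that the solver changes import:

import torch
from brevitas.core.function_wrapper.misc import Identity

input_view_impl = Identity()                 # what non-groupwise quantizers receive
x = torch.randn(16, 3, 3, 3)
assert torch.equal(input_view_impl(x), x)    # no reshape: per-tensor/per-channel paths are untouched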
13 changes: 3 additions & 10 deletions src/brevitas/core/scaling/runtime.py
@@ -166,6 +166,7 @@ def __init__(
self,
group_size: int,
group_dim: int,
input_view_impl: torch.nn.Module,
scaling_stats_impl: torch.nn.Module,
scaling_min_val: Optional[float],
restrict_scaling_impl: Optional[torch.nn.Module]) -> None:
@@ -174,21 +175,13 @@ def __init__(
self.group_dim = group_dim
self.scaling_stats_impl = scaling_stats_impl
self.scaling_min_val = scaling_min_val
self.input_view_impl = input_view_impl
self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)

@brevitas.jit.script_method
def group_scaling_reshape(self, stats_input):
tensor_shape = stats_input.shape
tensor_shape_list = list(tensor_shape)
tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size)
block_dim = self.group_dim + 1 if self.group_dim != -1 else -1
tensor_shape_list.insert(block_dim, self.group_size)
stats_input = stats_input.view(tensor_shape_list)
return stats_input

@brevitas.jit.script_method
def forward(self, stats_input) -> torch.Tensor:
stats_input_reshaped = self.group_scaling_reshape(stats_input)
stats_input_reshaped = self.input_view_impl(stats_input)
out = self.scaling_stats_impl(stats_input_reshaped)
# Scaling min val
out = self.restrict_clamp_scaling(out)
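The removed group_scaling_reshape is the same logic that now lives in DynamicOverSubChannelBlockView, so injecting that module as input_view_impl keeps the scaling statistics identical. A quick equivalence sketch with illustrative shapes:

import torch
from brevitas.core.function_wrapper.shape import DynamicOverSubChannelBlockView

def old_group_scaling_reshape(stats_input, group_size, group_dim):
    # the method removed above, reproduced only for comparison
    shape = list(stats_input.shape)
    shape[group_dim] = int(shape[group_dim] / group_size)
    block_dim = group_dim + 1 if group_dim != -1 else -1
    shape.insert(block_dim, group_size)
    return stats_input.view(shape)

x = torch.randn(4, 16, 8)
view = DynamicOverSubChannelBlockView(group_size=4, group_dim=1)
assert torch.equal(view(x), old_group_scaling_reshape(x, group_size=4, group_dim=1))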
6 changes: 0 additions & 6 deletions src/brevitas/proxy/groupwise_float_parameter_quant.py
@@ -9,11 +9,6 @@

class GroupwiseWeightFloatQuantProxyFromInjector(WeightFloatQuantProxyFromInjectorBase):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Is this always generated?
self.view_impl = self.quant_injector.scaling_stats_input_view_shape_impl

@property
def group_dim(self):
return self.quant_injector.group_dim
@@ -25,7 +20,6 @@ def group_size(self):
def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseFloatQuantTensor]:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
x = self.view_impl(x)
out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = impl(x)
return GroupwiseFloatQuantTensor(
out,
1 change: 0 additions & 1 deletion src/brevitas/proxy/groupwise_float_runtime_quant.py
@@ -24,7 +24,6 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, GroupwiseFloatQuantTensor]:
y = x
if isinstance(y, QuantTensor):
y = y.value

if self.export_mode:
y = self.fused_activation_quant_proxy.activation_impl(y)
y = self.export_handler(y)
1 change: 0 additions & 1 deletion src/brevitas/proxy/groupwise_int_parameter_quant.py
@@ -25,7 +25,6 @@ def group_size(self):
def forward(self, x: torch.Tensor) -> Union[Tensor, GroupwiseIntQuantTensor]:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
x = self.view_impl(x)
out, scale, zero_point, bit_width = impl(x)
return GroupwiseIntQuantTensor(
out,
13 changes: 12 additions & 1 deletion src/brevitas/quant/solver/act.py
@@ -1,6 +1,8 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from brevitas.core.function_wrapper.misc import Identity
from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl
import torch
from torch import nn
from torch import Tensor
@@ -128,6 +130,14 @@ def update_state_dict_impl(scaling_impl_type):
return None


class SolveInputViewImpl(ExtendedInjector):
@value
def input_view_impl(scaling_per_output):
if scaling_per_output == ScalingPerOutputType.GROUP:
return StatsInputViewShapeImpl.DYNAMIC_OVER_SUBCHANNEL_BLOCK
else:
return Identity

class ActQuantSolver(SolveActTensorQuantFromEnum,
SolveActScalingImplFromEnum,
SolveIntScalingImplFromEnum,
@@ -140,7 +150,8 @@ class ActQuantSolver(SolveActTensorQuantFromEnum,
SolveActScalingShape,
SolveScalingStatsInputViewShapeImplFromEnum,
SolveActScalingPerOutputChannelShape,
SolveUpdateStateDictImplFromEnum):
SolveUpdateStateDictImplFromEnum,
SolveInputViewImpl):
"""
Translate enum directives to activation-specific quantization core modules.
It should be placed last in the list of classes a quantizer inherits from,
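SolveInputViewImpl is what wires the new input_view_impl dependency: group scaling gets the dynamic per-group view, every other granularity gets a pass-through. A plain-Python sketch of the resolution, with the injector machinery omitted (assumption: ScalingPerOutputType and its TENSOR/GROUP members come from brevitas.inject.enum):

from brevitas.core.function_wrapper.misc import Identity
from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl
from brevitas.inject.enum import ScalingPerOutputType  # assumed import path

def resolve_input_view_impl(scaling_per_output):
    # mirrors the @value function above
    if scaling_per_output == ScalingPerOutputType.GROUP:
        return StatsInputViewShapeImpl.DYNAMIC_OVER_SUBCHANNEL_BLOCK
    return Identity

print(resolve_input_view_impl(ScalingPerOutputType.GROUP))   # -> DynamicOverSubChannelBlockView
print(resolve_input_view_impl(ScalingPerOutputType.TENSOR))  # -> Identity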
10 changes: 10 additions & 0 deletions src/brevitas/quant/solver/parameter.py
@@ -4,6 +4,8 @@
import math
from typing import List

from brevitas.core.function_wrapper.misc import Identity
from brevitas.core.function_wrapper.shape import StatsInputViewShapeImpl
from dependencies import this
from dependencies import value
import torch
@@ -139,3 +141,11 @@ def expanded_scaling_shape(module, group_size=None):
def group_dim(module, group_size=None):
if group_size is not None:
return 1

class SolveInputViewImpl(ExtendedInjector):
@value
def input_view_impl(scaling_per_output):
if scaling_per_output == ScalingPerOutputType.GROUP:
return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK
else:
return Identity
6 changes: 5 additions & 1 deletion src/brevitas/quant/solver/weight.py
@@ -8,6 +8,7 @@
from brevitas.inject import this
from brevitas.inject import value
from brevitas.proxy import WeightQuantProxyFromInjector
from brevitas.quant.solver.parameter import SolveInputViewImpl
from brevitas.quant.solver.common import *
from brevitas.quant.solver.parameter import *

@@ -68,6 +69,8 @@ def scaling_stats_input_concat_dim(scaling_per_output):
return 0
elif scaling_per_output == ScalingPerOutputType.CHANNEL:
return 1
else:
raise RuntimeError("Shared groupwise quantization is not supported")

@value
def permute_dims(module, output_channel_dim):
@@ -103,7 +106,8 @@ class WeightQuantSolver(SolveStatsReduceDimFromEnum,
SolveParameterScalingShape,
SolveWeightScalingPerOutputChannelShapeFromModule,
SolveWeightTensorQuantFromEnum,
SolveDtypeDeviceFromTrackedParameterList):
SolveDtypeDeviceFromTrackedParameterList,
SolveInputViewImpl):
"""
Translate enum and shape directives to weight-specific quantization core modules.
It should be placed last in the list of classes a quantizer inherits from,
8 changes: 5 additions & 3 deletions src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
@@ -90,13 +90,15 @@ def __torch_function__(self, func, types, args=(), kwargs=None):

def expand(self):
curr_shape = self.value_.shape
new_value = self.value_.flatten(self.group_dim, self.group_dim + 1)
start_dim = self.group_dim if self.group_dim != -1 else -2
new_value = self.value_.flatten(start_dim, start_dim + 1)
new_value = self.value_.flatten(start_dim, start_dim + 1)
if self.scale_.shape != ():
new_scale = self.scale_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1)
new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
else:
new_scale = self.scale_
if self.zero_point_.shape != ():
new_zp = self.zero_point_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1)
new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1)
else:
new_zp = self.zero_point_

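The expand() fix matters only when group_dim == -1: flatten(-1, 0) is invalid because start_dim would come after end_dim, while flattening the two trailing dimensions restores the original layout. A minimal illustration with made-up shapes:

import torch

grouped = torch.randn(4, 4, 8)    # trailing dims hold the groups when group_dim == -1
flat = grouped.flatten(-2, -1)    # what the fixed code does
print(flat.shape)                 # torch.Size([4, 32])
# grouped.flatten(-1, 0) would raise, since start_dim resolves past end_dim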
7 changes: 4 additions & 3 deletions src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
@@ -59,13 +59,14 @@ def __torch_function__(self, func, types, args=(), kwargs=None):

def expand(self):
curr_shape = self.value_.shape
new_value = self.value_.flatten(self.group_dim, self.group_dim + 1)
start_dim = self.group_dim if self.group_dim != -1 else -2
new_value = self.value_.flatten(start_dim, start_dim + 1)
if self.scale_.shape != ():
new_scale = self.scale_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1)
new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
else:
new_scale = self.scale_
if self.zero_point_.shape != ():
new_zp = self.zero_point_.expand(curr_shape).flatten(self.group_dim, self.group_dim + 1)
new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1)
else:
new_zp = self.zero_point_
