Skip to content

Commit

Permalink
Feat: functionalize QuantTensor
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Mar 25, 2024
1 parent 8d11f52 commit e52db0a
Show file tree
Hide file tree
Showing 14 changed files with 773 additions and 546 deletions.
1 change: 0 additions & 1 deletion src/brevitas/graph/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from brevitas.nn import QuantHardTanh
from brevitas.nn import QuantLinear
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
from brevitas.nn.utils import compute_channel_view_shape
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/nn/mixin/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from brevitas.common import ExportMixin
from brevitas.inject import ExtendedInjector
from brevitas.inject import Injector
from brevitas.nn.utils import compute_channel_view_shape
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.torch_utils import compute_channel_view_shape

from .utils import filter_kwargs

Expand Down
44 changes: 9 additions & 35 deletions src/brevitas/nn/quant_avg_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,38 +47,19 @@ def channelwise_separable(self) -> bool:
def requires_export_handler(self):
return True

@property
def _avg_scaling(self):
if isinstance(self.kernel_size, tuple):
return self.kernel_size[0] * self.kernel_size[1]
else:
return self.kernel_size * self.kernel_size

def forward(self, input: Union[Tensor, QuantTensor]):
x = self.unpack_input(input)

if self.export_mode:
return self.export_handler(_unpack_quant_tensor(x))

if isinstance(x, QuantTensor):
x = x.set(value=super(TruncAvgPool2d, self).forward(x.value))
if self.is_trunc_quant_enabled:
# remove avg scaling
rescaled_value = x.value * self._avg_scaling
x = x.set(value=rescaled_value)
x = x.set(bit_width=self.max_acc_bit_width(x.bit_width))
x = self.trunc_quant(x)
if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
y = AvgPool2d.forward(self, x)
y = self.trunc_quant(y)
else:
assert not self.is_trunc_quant_enabled
x = super(TruncAvgPool2d, self).forward(x)

return self.pack_output(x)
y = AvgPool2d.forward(self, _unpack_quant_tensor(x))

def max_acc_bit_width(self, input_bit_width):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_uint_output = max_uint_input * self._avg_scaling
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
return max_output_bit_width
return self.pack_output(y)


class TruncAdaptiveAvgPool2d(TruncMixin, QuantLayerMixin, AdaptiveAvgPool2d):
Expand Down Expand Up @@ -130,18 +111,11 @@ def forward(self, input: Union[Tensor, QuantTensor]):
self._set_global_is_quant_layer(False)
return out

if isinstance(x, QuantTensor):
y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value))
k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
if self.is_trunc_quant_enabled:
reduce_size = reduce(mul, k_size, 1)
rescaled_value = y.value * reduce_size # remove avg scaling
y = y.set(value=rescaled_value)
y = y.set(bit_width=self.max_acc_bit_width(y.bit_width, reduce_size))
y = self.trunc_quant(y)
if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
y = AdaptiveAvgPool2d.forward(self, x)
y = self.trunc_quant(y)
else:
assert not self.is_trunc_quant_enabled
y = super(TruncAdaptiveAvgPool2d, self).forward(x)
y = AdaptiveAvgPool2d.forward(self, _unpack_quant_tensor(x))

return self.pack_output(y)

Expand Down
8 changes: 0 additions & 8 deletions src/brevitas/nn/quant_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,6 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option
else:
return self._conv_forward(x, quant_weight, quant_bias)

def max_acc_bit_width(self, input_bit_width, weight_bit_width):
max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_kernel_val = self.weight_quant.max_uint_value(weight_bit_width)
group_size = self.in_channels // self.groups
max_uint_output = max_uint_input * max_kernel_val * self.kernel_size[0] * group_size
max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
return max_output_bit_width


class QuantConv2d(QuantWBIOL, Conv2d):

Expand Down
27 changes: 24 additions & 3 deletions src/brevitas/nn/quant_convtranspose.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,14 @@ def compute_output_padding(self, inp, output_size):
def conv_transpose1d_zeros_pad(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding):
out = conv_transpose1d(
x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
x,
weight,
bias,
stride=self.stride,
padding=self.padding,
output_padding=output_padding,
groups=self.groups,
dilation=self.dilation)
return out

def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
Expand Down Expand Up @@ -200,7 +207,14 @@ def compute_output_padding(self, inp, output_size):
def conv_transpose2d_zeros_pad(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding):
out = conv_transpose2d(
x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
x,
weight,
bias,
stride=self.stride,
padding=self.padding,
output_padding=output_padding,
groups=self.groups,
dilation=self.dilation)
return out

def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
Expand Down Expand Up @@ -298,7 +312,14 @@ def compute_output_padding(self, inp, output_size):
def conv_transpose3d_zeros_pad(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding):
out = conv_transpose3d(
x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
x,
weight,
bias,
stride=self.stride,
padding=self.padding,
output_padding=output_padding,
groups=self.groups,
dilation=self.dilation)
return out

def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
Expand Down
59 changes: 9 additions & 50 deletions src/brevitas/nn/quant_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
from torch import Tensor
from torch.nn import Module

from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.torch_utils import compute_channel_view_shape

from .mixin import *
from .utils import compute_channel_view_shape
from .utils import merge_bn
from .utils import rename_state_dict_by_prefix

Expand Down Expand Up @@ -47,7 +46,7 @@ def forward(self, input: Union[Tensor, QuantTensor]):
quant_input = self.input_quant(input)
# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(_unpack_quant_tensor(quant_input))
out = self.export_handler(quant_input)
self._set_global_is_quant_layer(False)
return out
out = self.act_quant(quant_input)
Expand Down Expand Up @@ -121,7 +120,8 @@ def max_acc_bit_width(self, input_bit_width: Tensor, quant_weight_bit_width: Ten

def quant_output_scale_impl(
self, inp: Tensor, quant_input_scale: Tensor, quant_weight_scale: Tensor):
output_scale_shape = compute_channel_view_shape(inp, channel_dim=1)
channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1
output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim)
output_scale = quant_weight_scale.view(output_scale_shape)
output_scale = output_scale * quant_input_scale.view(output_scale_shape)
return output_scale
Expand All @@ -140,16 +140,12 @@ def merge_bn_in(self, bn):
merge_bn(self, bn, output_channel_dim=self.output_channel_dim)

def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
output_scale = None
output_bit_width = None
output_zero_point = None
output_signed = None

inp = self.unpack_input(inp)

# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(_unpack_quant_tensor(inp))
out = self.export_handler(inp)
self._set_global_is_quant_layer(False)
return out

Expand All @@ -163,58 +159,21 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
self.output_quant.is_quant_enabled) and self.return_quant_tensor:
raise RuntimeError("QuantLayer is not correctly configured")

output_scale = None
if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor):
output_bit_width = self.max_acc_bit_width(quant_input.bit_width, quant_weight.bit_width)
output_scale = self.quant_output_scale_impl(inp, quant_input.scale, quant_weight.scale)
output_signed = quant_input.signed or quant_weight.signed

if self.bias is not None:
quant_bias = self.bias_quant(self.bias, output_scale)

output_tensor = self.inner_forward_impl(
_unpack_quant_tensor(quant_input),
_unpack_quant_tensor(quant_weight),
_unpack_quant_tensor(quant_bias))

if output_scale is not None:
if (isinstance(quant_bias, QuantTensor) and
quant_bias.scale.data_ptr() != output_scale.data_ptr()) or not isinstance(
quant_bias, QuantTensor):
channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1
output_scale_broadcast_shape = compute_channel_view_shape(
inp, channel_dim=channel_dim)
output_zero_point = -_unpack_quant_tensor(quant_bias).view(
output_scale_broadcast_shape) / output_scale

if output_bit_width is not None and isinstance(quant_bias, QuantTensor):
output_bit_width = torch.where(
quant_bias.bit_width > output_bit_width, quant_bias.bit_width, output_bit_width)
output_bit_width = output_bit_width + 1
output_tensor = self.inner_forward_impl(quant_input, quant_weight, quant_bias)
else:
output_tensor = self.inner_forward_impl(
_unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), None)
output_tensor = self.inner_forward_impl(quant_input, quant_weight, None)

if not self.output_quant.is_quant_enabled and self.return_quant_tensor:
if compute_output_quant_tensor:
if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any():
raise RuntimeError(
"Computing zero point of output accumulator not supported yet.")
elif output_zero_point is None:
output_zero_point = quant_input.zero_point

elif output_zero_point is None:
output_zero_point = torch.zeros(1).type_as(output_tensor)

if compute_output_quant_tensor:
quant_output = QuantTensor(
output_tensor,
scale=output_scale,
zero_point=output_zero_point,
bit_width=output_bit_width,
signed=output_signed,
training=self.training)
else:
quant_output = output_tensor

quant_output = self.output_quant(quant_output)
quant_output = self.output_quant(output_tensor)
return self.pack_output(quant_output)
7 changes: 0 additions & 7 deletions src/brevitas/nn/quant_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,3 @@ def quant_output_scale_impl(
quant_weight_scale = quant_weight_scale.view(weight_broadcast_shape)
quant_output_scale = linear(quant_input_scale, quant_weight_scale)
return quant_output_scale

def max_acc_bit_width(self, input_bit_width, weight_bit_width):
max_input_val = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_fc_val = self.weight_quant.max_uint_value(weight_bit_width)
max_output_val = max_input_val * max_fc_val * self.in_features
output_bit_width = ceil_ste(torch.log2(max_output_val))
return output_bit_width
24 changes: 21 additions & 3 deletions src/brevitas/nn/quant_scale_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,31 @@ def channelwise_separable(self) -> bool:
def forward(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
return self.forward_impl(inp)

def inner_forward_impl(self, input: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
def inner_forward_impl(
self,
input: Union[Tensor, QuantTensor],
quant_weight: Union[Tensor, QuantTensor],
quant_bias: Optional[Union[Tensor, QuantTensor]]):
quant_weight = quant_weight.view(self.runtime_shape)
quant_bias = quant_bias.view(self.runtime_shape)
output_tensor = input * quant_weight + quant_bias

# TODO: when implementing new types of QuantTensor, this should be revised
if isinstance(input, QuantTensor):
from brevitas.quant_tensor.torch_handler import quant_layer
output_tensor = quant_layer(
torch.mul,
input,
quant_weight,
bias=None,
external_acc_bit_width_fn=self.max_acc_bit_width)
else:
output_tensor = torch.mul(input, quant_weight)

if quant_bias is not None:
output_tensor += quant_bias
return output_tensor

def max_acc_bit_width(self, input_bit_width, weight_bit_width):
def max_acc_bit_width(self, input_bit_width, weight_bit_width, weight_shape):
max_input_val = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
max_weight_val = self.weight_quant.max_uint_value(weight_bit_width)
max_output_val = max_input_val * max_weight_val
Expand Down
12 changes: 3 additions & 9 deletions src/brevitas/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,9 @@
# SPDX-License-Identifier: BSD-3-Clause

import torch
from torch import Tensor
from torch.nn import Parameter

from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector


def compute_channel_view_shape(tensor: Tensor, channel_dim: int):
broadcast_shape = [1] * len(tensor.size())
broadcast_shape[channel_dim] = -1
return tuple(broadcast_shape)
from brevitas.utils.torch_utils import compute_channel_view_shape


def mul_add_from_bn(bn_mean, bn_var, bn_eps, bn_weight, bn_bias):
Expand All @@ -23,6 +15,8 @@ def mul_add_from_bn(bn_mean, bn_var, bn_eps, bn_weight, bn_bias):


def merge_bn(layer, bn, output_channel_dim=0):
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
out = mul_add_from_bn(
bn_mean=bn.running_mean,
bn_var=bn.running_var,
Expand Down
Loading

0 comments on commit e52db0a

Please sign in to comment.