Feat (QuantTensor)!: QuantTensor cannot be empty (#819)
BREAKING CHANGE: Creating a QuantTensor without metadata is no longer allowed.
Giuseppe5 authored Feb 23, 2024
1 parent 506954c commit 6079b12
Showing 31 changed files with 2,040 additions and 1,496 deletions.
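
Concretely, the breaking change means a QuantTensor can no longer be built from a bare value: code paths that previously produced metadata-less ("empty") QuantTensors now return plain torch.Tensors instead, and constructing a QuantTensor requires its full quantization metadata. A minimal before/after sketch; the argument names follow the constructor calls and attributes visible in the diffs below, while the scale, zero-point, and bit-width values are purely illustrative:

import torch
from brevitas.quant_tensor import QuantTensor

x = torch.randn(4, 8)

# Previously allowed: a metadata-less QuantTensor wrapping only a value.
# qt = QuantTensor(x, training=True)   # no longer supported after this change

# A QuantTensor must now carry its full quantization metadata:
qt = QuantTensor(
    value=x,
    scale=torch.tensor(0.1),
    zero_point=torch.tensor(0.),
    bit_width=torch.tensor(8.),
    signed=True,
    training=False)
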
551 changes: 258 additions & 293 deletions notebooks/01_quant_tensor_quant_conv2d_overview.ipynb

Large diffs are not rendered by default.

169 changes: 48 additions & 121 deletions notebooks/02_quant_activation_overview.ipynb

Large diffs are not rendered by default.

489 changes: 214 additions & 275 deletions notebooks/03_anatomy_of_a_quantizer.ipynb

Large diffs are not rendered by default.

946 changes: 830 additions & 116 deletions notebooks/Brevitas_TVMCon2021.ipynb

Large diffs are not rendered by default.

89 changes: 55 additions & 34 deletions notebooks/ONNX_export_tutorial.ipynb

Large diffs are not rendered by default.

639 changes: 302 additions & 337 deletions notebooks/quantized_recurrent.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/brevitas/core/quant/binary.py
@@ -48,8 +48,9 @@ class BinaryQuant(brevitas.jit.ScriptModule):
Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
"""

def __init__(self, scaling_impl: Module, quant_delay_steps: int = 0):
def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps: int = 0):
super(BinaryQuant, self).__init__()
assert signed, "Unsigned binary quant not supported"
self.scaling_impl = scaling_impl
self.bit_width = BitWidthConst(1)
self.zero_point = StatelessBuffer(torch.tensor(0.0))
4 changes: 2 additions & 2 deletions src/brevitas/core/stats/stats_op.py
@@ -12,6 +12,7 @@
from brevitas import config
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops import max_int
from brevitas.quant_tensor import _unpack_quant_tensor
# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
from brevitas.utils.torch_utils import kthvalue

@@ -478,8 +479,7 @@ def evaluate_loss(self, x, candidate):
# Set to local_loss_mode before calling the proxy
self.set_local_loss_mode(True)
quant_value = self.proxy_forward(x)
if isinstance(quant_value, tuple):
quant_value = quant_value[0]
quant_value = _unpack_quant_tensor(quant_value)
loss = self.mse_loss_fn(x, quant_value)
self.set_local_loss_mode(False)
return loss
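
Several of the changes below swap ad-hoc tuple/None handling for the _unpack_quant_tensor helper imported from brevitas.quant_tensor. A rough sketch of the behaviour implied by these call sites follows; this is inferred, not the actual implementation:

from brevitas.quant_tensor import QuantTensor

# Rough sketch only: behaviour inferred from the call sites in this commit,
# not the actual implementation in brevitas.quant_tensor.
def _unpack_quant_tensor(inp):
    if isinstance(inp, QuantTensor):
        return inp.value                                     # keep the (dequantized) value, drop metadata
    if isinstance(inp, tuple):
        return tuple(_unpack_quant_tensor(t) for t in inp)   # unpack element-wise, e.g. recurrent outputs
    return inp                                               # plain tensors pass through unchanged
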
5 changes: 2 additions & 3 deletions src/brevitas/export/manager.py
@@ -207,10 +207,9 @@ def _cache_fn_dispatcher(cls, fn, input, *args, **kwargs):
if isinstance(input, QuantTensor):
inp_cache = None
out_cache = None
if input.is_not_none:
inp_cache = _CachedIO(input, metadata_only=True)
inp_cache = _CachedIO(input, metadata_only=True)
output = fn(input, *args, **kwargs)
if isinstance(output, QuantTensor) and output.is_not_none:
if isinstance(output, QuantTensor):
out_cache = _CachedIO(output, metadata_only=True)
cached_io = (inp_cache, out_cache)
if fn in cls._cached_io_handler_map:
(Changes to an additional file whose name is not shown in this view.)
@@ -104,7 +104,7 @@ def input_quant_symbolic_kwargs(cls, module):

@classmethod
def input_dequant_symbolic_kwargs(cls, module):
if module._cached_inp.scale is not None:
if module._cached_inp is not None:
return cls.dequant_symbolic_kwargs_from_cached_io(module._cached_inp)
else:
return None
7 changes: 0 additions & 7 deletions src/brevitas/export/onnx/standard/qoperator/manager.py
@@ -1,18 +1,11 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Optional, Tuple, Union

from packaging import version
from torch import Tensor
from torch.nn import functional as F
from torch.nn import Module

from brevitas import torch_version
from brevitas.export.manager import _set_layer_export_handler
from brevitas.export.manager import _set_layer_export_mode
from brevitas.export.onnx.manager import ONNXBaseManager
from brevitas.quant_tensor import QuantTensor

from ..function import DequantizeLinearFn
from ..function import IntClipFn
20 changes: 19 additions & 1 deletion src/brevitas/graph/calibrate.py
@@ -48,6 +48,21 @@
BN_LAYERS = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def disable_return_quant_tensor(model):
previous_state = {}
for module in model.modules():
if hasattr(module, 'return_quant_tensor'):
previous_state[module] = module.return_quant_tensor
module.return_quant_tensor = False
return previous_state


def restore_return_quant_tensor(model, previous_state):
for module in model.modules():
if hasattr(module, 'return_quant_tensor'):
module.return_quant_tensor = previous_state[module]


def extend_collect_stats_steps(module):
if hasattr(module, 'collect_stats_steps'):
# We extend the collect steps in PTQ to match potentially long calibrations
@@ -75,18 +90,21 @@ def __init__(self, model, enabled=True):
self.previous_training_state = model.training
self.disable_quant_inference = DisableEnableQuantization(call_act_quantizer_impl=True)
self.enabled = enabled
self.return_quant_tensor_state = dict()

def __enter__(self):
if self.enabled:
self.model.apply(extend_collect_stats_steps)
self.model.apply(set_collect_stats_to_average)
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.apply(
self.model, is_training=True, quantization_enabled=False)

def __exit__(self, type, value, traceback):
self.model.apply(finalize_collect_stats)
self.disable_quant_inference.apply(
self.model, is_training=self.previous_training_state, quantization_enabled=True)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)


class load_quant_model:
@@ -168,7 +186,7 @@ def disable_act_quant_hook(self, module, inp, output):
if isinstance(module.tracked_module_list[0], QuantHardTanh):
inp = F.hardtanh(
inp, min_val=module.quant_injector.min_val, max_val=module.quant_injector.max_val)
return QuantTensor(value=inp, training=module.training)
return inp

def disable_act_quantization(self, model, is_training):
# If self.call_act_quantizer_impl is set to True, the quantization will be performed but the output
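
The disable_return_quant_tensor / restore_return_quant_tensor helpers introduced above let calibration code temporarily force every quantized module to emit plain tensors and then put each module's original setting back. A minimal usage sketch, where model and calib_loader are hypothetical placeholders:

import torch
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import restore_return_quant_tensor

state = disable_return_quant_tensor(model)       # every quantized module now returns plain tensors
try:
    with torch.no_grad():
        for images, _ in calib_loader:           # calib_loader: a hypothetical calibration DataLoader
            model(images)                        # e.g. run data through to collect statistics
finally:
    restore_return_quant_tensor(model, state)    # restore each module's original setting
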
17 changes: 6 additions & 11 deletions src/brevitas/graph/gpxq.py
@@ -234,19 +234,14 @@ def process_input(self, inp):
inp_training = self.layer.training

# If using quantized activations, inp could be QuantTensor. In
# this case, we overwrite the metadata if it is specified.
# this case, we overwrite the metadata.
if isinstance(inp, QuantTensor):
if self.layer_requires_input_quant and (self.quant_input is None):
if inp.scale is not None:
inp_scale = inp.scale
if inp.zero_point is not None:
inp_zero_point = inp.zero_point
if inp.bit_width is not None:
inp_bit_width = inp.bit_width
if inp.signed is not None:
inp_signed = inp.signed
if inp.training is not None:
inp_training = inp.training
inp_scale = inp.scale
inp_zero_point = inp.zero_point
inp_bit_width = inp.bit_width
inp_signed = inp.signed
inp_training = inp.training
inp = inp.value

# if the layer requires an input quant and the quant input cache has
12 changes: 5 additions & 7 deletions src/brevitas/nn/hadamard_classifier.py
@@ -49,15 +49,13 @@ def forward(self, inp):
out = inp.value / norm
out = nn.functional.linear(out, self.proj[:self.out_channels, :self.in_channels])
out = -self.scale * out
if inp.scale is not None:
if isinstance(inp, QuantTensor):
output_scale = inp.scale * self.scale / norm
if inp.bit_width is not None:
output_bit_width = self.max_output_bit_width(inp.bit_width)
if (self.return_quant_tensor and inp.zero_point is not None and
(inp.zero_point != 0.0).any()):
raise RuntimeError("Computing zero point of output accumulator not supported yet.")
else:
output_zp = inp.zero_point
if (self.return_quant_tensor and inp.zero_point != 0.0).any():
raise RuntimeError("Computing zero point of output accumulator not supported yet.")
else:
output_zp = inp.zero_point
out = QuantTensor(
value=out,
scale=output_scale,
42 changes: 21 additions & 21 deletions src/brevitas/nn/mixin/base.py
@@ -18,6 +18,7 @@
from brevitas.inject import ExtendedInjector
from brevitas.inject import Injector
from brevitas.nn.utils import compute_channel_view_shape
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import QuantTensor

from .utils import filter_kwargs
@@ -154,7 +155,7 @@ def quant_output_bit_width(self):
else:
return None

def unpack_input(self, inp: Union[Tensor, QuantTensor]):
def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
self._set_global_is_quant_layer(True)
# Hack to recognize a QuantTensor that has decayed to a tuple
# when used as input to tracing (e.g. during ONNX export)
@@ -166,25 +167,23 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]):
if not self.training and not self._export_mode and self.cache_inference_quant_inp:
cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
self._cached_inp = cached_inp
else:
inp = QuantTensor(inp, training=self.training)
if not self.training and self.cache_inference_quant_inp:
cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
self._cached_inp = cached_inp
# Remove any naming metadata to avoid downstream errors
# Avoid inplace operations on the input in case of forward hooks
if not torch._C._get_tracing_state():
inp = inp.set(value=inp.value.rename(None))
if isinstance(inp, QuantTensor):
inp = inp.set(value=inp.value.rename(None))
else:
inp = inp.rename(None)
return inp

def pack_output(self, quant_output: QuantTensor):
if not self.training and self.cache_inference_quant_out:
def pack_output(self, quant_output: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
if not self.training and self.cache_inference_quant_out and isinstance(quant_output,
QuantTensor):
self._cached_out = _CachedIO(quant_output.detach(), self.cache_quant_io_metadata_only)
self._set_global_is_quant_layer(False)
if self.return_quant_tensor:
assert isinstance(quant_output, QuantTensor)
return quant_output
else:
return quant_output.value
return _unpack_quant_tensor(quant_output)


class QuantRecurrentLayerMixin(ExportMixin):
@@ -246,9 +245,9 @@ def gate_params_fwd(gate, quant_input):
acc_bit_width = None
quant_weight_ih = gate.input_weight()
quant_weight_hh = gate.hidden_weight()
if quant_input.bit_width is not None:
if isinstance(quant_input, QuantTensor):
acc_bit_width = None # TODO
if quant_input.scale is not None and quant_weight_ih.scale is not None:
if isinstance(quant_input, QuantTensor) and isinstance(quant_weight_ih, QuantTensor):
acc_scale_shape = compute_channel_view_shape(quant_input.value, channel_dim=1)
acc_scale = quant_weight_ih.scale.view(acc_scale_shape)
acc_scale = acc_scale * quant_input.scale.view(acc_scale_shape)
@@ -267,24 +266,23 @@ def maybe_quantize_input(self, inp):
quant_input = inp
if not self.quantize_output_only:
quant_input = self.io_quant(quant_input)
elif not isinstance(inp, QuantTensor):
quant_input = QuantTensor(quant_input)
return quant_input

def maybe_quantize_state(self, inp, state, quant):
if state is None:
batch_size = inp.size(0) if self.cell.batch_first else inp.size(1)
quant_state = torch.zeros(
int(batch_size), self.hidden_size, dtype=inp.dtype, device=inp.device)
quant_state = QuantTensor(quant_state)
else:
quant_state = quant(state)
return quant_state

def pack_quant_outputs(self, quant_outputs):
# In export mode, quant_outputs has the shape of the output concatenated value
# Even though we check that return_quant_tensor can be enabled only with io_quant != None,
# inner layers in a deep network override it, so we check again.
if self.export_mode:
if self.return_quant_tensor:
if self.return_quant_tensor and self.io_quant.is_quant_enabled:
return QuantTensor(
quant_outputs,
self.io_quant.scale(),
@@ -295,7 +293,7 @@ def pack_quant_outputs(self, quant_outputs):
else:
return quant_outputs
seq_dim = 1 if self.cell.batch_first else 0
if self.return_quant_tensor:
if self.return_quant_tensor and self.io_quant.is_quant_enabled:
outputs = [
QuantTensor(
torch.unsqueeze(quant_output[0], dim=seq_dim),
@@ -312,8 +310,10 @@ def pack_quant_state(self, quant_state, quant):
return torch.cat(outputs, dim=seq_dim)

def pack_quant_state(self, quant_state, quant):
# Even though we check that return_quant_tensor can be enabled only with quant != None,
# inner layers in a deep network override it, so we check again.
if self.export_mode:
if self.return_quant_tensor:
if self.return_quant_tensor and quant.is_quant_enabled:
quant_state = QuantTensor(
torch.unsqueeze(quant_state, dim=0),
quant.scale(),
Expand All @@ -324,7 +324,7 @@ def pack_quant_state(self, quant_state, quant):
else:
quant_state = torch.unsqueeze(quant_state, dim=0)
else:
if self.return_quant_tensor:
if self.return_quant_tensor and quant.is_quant_enabled:
quant_state = QuantTensor(
torch.unsqueeze(quant_state[0], dim=0),
quant_state[1],
6 changes: 5 additions & 1 deletion src/brevitas/nn/mixin/parameter.py
@@ -198,7 +198,11 @@ def quant_bias_zero_point(self):
if self.bias is None:
return None
if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width:
return self.bias_quant(self.bias).zero_point
bias_quant = self.bias_quant(self.bias)
if isinstance(bias_quant, QuantTensor):
return bias_quant.zero_point
else:
return None
else:
if self._cached_bias is None:
raise RuntimeError(
52 changes: 33 additions & 19 deletions src/brevitas/nn/quant_avg_pool.py
@@ -13,6 +13,7 @@
from brevitas.function.ops import max_int
from brevitas.function.ops_ste import ceil_ste
from brevitas.inject.defaults import RoundTo8bit
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor import QuantTensor

from .mixin.acc import AccQuantType
@@ -55,16 +56,22 @@ def _avg_scaling(self):

def forward(self, input: Union[Tensor, QuantTensor]):
x = self.unpack_input(input)

if self.export_mode:
return self.export_handler(x.value)
x = x.set(value=super(TruncAvgPool2d, self).forward(x.value))
if self.is_trunc_quant_enabled:
assert x.is_not_none # check input quant tensor is filled with values
# remove avg scaling
rescaled_value = x.value * self._avg_scaling
x = x.set(value=rescaled_value)
x = x.set(bit_width=self.max_acc_bit_width(x.bit_width))
x = self.trunc_quant(x)
return self.export_handler(_unpack_quant_tensor(x))

if isinstance(x, QuantTensor):
x = x.set(value=super(TruncAvgPool2d, self).forward(x.value))
if self.is_trunc_quant_enabled:
# remove avg scaling
rescaled_value = x.value * self._avg_scaling
x = x.set(value=rescaled_value)
x = x.set(bit_width=self.max_acc_bit_width(x.bit_width))
x = self.trunc_quant(x)
else:
assert not self.is_trunc_quant_enabled
x = super(TruncAvgPool2d, self).forward(x)

return self.pack_output(x)

def max_acc_bit_width(self, input_bit_width):
@@ -127,23 +134,30 @@ def compute_kernel_size_stride(self, input_shape, output_shape):

def forward(self, input: Union[Tensor, QuantTensor]):
x = self.unpack_input(input)

# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(x.value)
out = self.export_handler(_unpack_quant_tensor(x))
self._set_global_is_quant_layer(False)
return out
y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value))
k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])

if self.cache_kernel_size_stride:
self._cached_kernel_size = k_size
self._cached_kernel_stride = stride
if self.is_trunc_quant_enabled:
assert y.is_not_none # check input quant tensor is filled with values
reduce_size = reduce(mul, k_size, 1)
rescaled_value = y.value * reduce_size # remove avg scaling
y = y.set(value=rescaled_value)
y = y.set(bit_width=self.max_acc_bit_width(y.bit_width, reduce_size))
y = self.trunc_quant(y)

if isinstance(x, QuantTensor):
y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value))
k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:])
if self.is_trunc_quant_enabled:
reduce_size = reduce(mul, k_size, 1)
rescaled_value = y.value * reduce_size # remove avg scaling
y = y.set(value=rescaled_value)
y = y.set(bit_width=self.max_acc_bit_width(y.bit_width, reduce_size))
y = self.trunc_quant(y)
else:
assert not self.is_trunc_quant_enabled
y = super(TruncAdaptiveAvgPool2d, self).forward(x)

return self.pack_output(y)

def max_acc_bit_width(self, input_bit_width, reduce_size):
(Diffs for the remaining changed files are not rendered here.)
