Feat (quant_tensor): support for FloatQuantTensor #919

Merged — 12 commits, May 21, 2024

Changes from all commits
84 changes: 46 additions & 38 deletions notebooks/01_quant_tensor_quant_conv2d_overview.ipynb
@@ -157,11 +157,21 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 4,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/_tensor.py:1394: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1708025842427/work/c10/core/TensorImpl.h:1908.)\n",
" return super().rename(names)\n",
"/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/conv.py:456: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1708025842427/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n",
" return F.conv2d(input, weight, bias, self.stride,\n"
]
},
{
"data": {
"text/plain": [
@@ -255,7 +265,7 @@
{
"data": {
"text/plain": [
"QuantTensor(value=tensor([[[[-0.0018, 0.1273, -0.1937],\n",
"IntQuantTensor(value=tensor([[[[-0.0018, 0.1273, -0.1937],\n",
" [-0.1734, -0.0904, 0.0627],\n",
" [-0.0055, 0.1863, -0.0203]],\n",
"\n",
@@ -377,8 +387,6 @@
}
],
"source": [
"from brevitas.quant_tensor import QuantTensor\n",
"\n",
"quant_act = QuantIdentity(return_quant_tensor=True)\n",
"\n",
"out_tensor_0 = quant_act(torch.randn(1,2,5,5))\n",
@@ -407,7 +415,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"QuantTensor(value=tensor([[[[-0.1106, 1.1945, -0.4972, -2.0968, 0.7175],\n",
"IntQuantTensor(value=tensor([[[[-0.1106, 1.1945, -0.4972, -2.0968, 0.7175],\n",
" [-2.5901, 0.0588, -0.2014, 2.1486, 1.6435],\n",
" [ 0.9067, -2.5212, 2.2193, 0.2352, -0.8395],\n",
" [-0.8351, 0.6341, -0.5551, 0.1040, -3.3151],\n",
@@ -467,7 +475,7 @@
{
"data": {
"text/plain": [
"QuantTensor(value=tensor([[[[0.5191, 0.6402],\n",
"IntQuantTensor(value=tensor([[[[0.5191, 0.6402],\n",
" [2.1455, 0.5883]],\n",
"\n",
" [[2.0417, 0.5883],\n",
@@ -506,7 +514,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_4048/1377665000.py:1: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1699449183005/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n",
"/tmp/ipykernel_528161/1377665000.py:1: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1708025842427/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n",
" torch.tanh(quant_tensor)\n"
]
},
@@ -555,7 +563,7 @@
{
"data": {
"text/plain": [
"QuantTensor(value=tensor([[[[-0.3568, -0.1883, 0.3589],\n",
"IntQuantTensor(value=tensor([[[[-0.3568, -0.1883, 0.3589],\n",
" [-0.4470, 0.1039, -0.3945],\n",
" [-0.4190, 0.3723, 0.8384]],\n",
"\n",
@@ -565,7 +573,7 @@
"\n",
" [[ 0.2734, 0.7268, -0.0249],\n",
" [-0.1732, 0.5197, 1.1158],\n",
" [ 0.3771, -0.3810, 0.2008]]]], grad_fn=<ConvolutionBackward0>), scale=tensor([[[[3.1958e-05]]]], grad_fn=<MulBackward0>), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))"
" [ 0.3771, -0.3810, 0.2008]]]], grad_fn=<ConvolutionBackward0>), scale=tensor([[[[3.1958e-05]]]], grad_fn=<MulBackward0>), zero_point=tensor([0.]), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))"
]
},
"execution_count": 14,
@@ -618,39 +626,39 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"QuantTensor(value=tensor([[[[ 7.2000e-03, -3.7000e-03, 7.7000e-03, -2.4000e-03, -8.9000e-03],\n",
" [-1.2000e-02, -8.1000e-03, 7.2000e-03, -1.1300e-02, -9.7000e-03],\n",
" [-1.0000e-03, 1.0100e-02, 3.8000e-03, -1.1900e-02, 6.9000e-03],\n",
" [ 8.3000e-03, 1.0000e-04, -6.9000e-03, 3.9000e-03, -5.4000e-03],\n",
" [ 1.1300e-02, -6.0000e-03, 9.7000e-03, 0.0000e+00, 1.0900e-02]],\n",
"IntQuantTensor(value=tensor([[[[-9.9000e-03, -7.1000e-03, -4.7000e-03, 5.0000e-03, -1.2300e-02],\n",
" [-8.2000e-03, 8.5000e-03, -1.2000e-03, -1.2500e-02, 4.4000e-03],\n",
" [ 4.3000e-03, -6.3000e-03, -9.4000e-03, 1.0400e-02, -1.2100e-02],\n",
" [ 1.1700e-02, -3.6000e-03, 5.3000e-03, -1.1700e-02, -4.3000e-03],\n",
" [-8.8000e-03, 1.0900e-02, -8.3000e-03, -2.9000e-03, 1.2400e-02]],\n",
"\n",
" [[-1.0900e-02, 1.1400e-02, -6.4000e-03, 9.2000e-03, 7.1000e-03],\n",
" [-6.0000e-04, 9.2000e-03, -8.5000e-03, 5.0000e-03, 6.5000e-03],\n",
" [-8.3000e-03, -1.2000e-03, 7.4000e-03, 9.2000e-03, -6.0000e-04],\n",
" [-2.1000e-03, 9.5000e-03, 3.0000e-04, -2.9000e-03, -6.5000e-03],\n",
" [-1.1800e-02, -4.8000e-03, 5.4000e-03, -2.5000e-03, 9.0000e-04]]]]), scale=tensor(1.0000e-04), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))"
" [[ 9.3000e-03, -8.5000e-03, 6.5000e-03, -2.7000e-03, -3.4000e-03],\n",
" [-1.0000e-04, -1.1000e-02, 8.3000e-03, 1.9000e-03, -9.8000e-03],\n",
" [ 4.3000e-03, -8.5000e-03, 1.1000e-02, 5.3000e-03, 3.4000e-03],\n",
" [ 8.1000e-03, 9.8000e-03, 6.8000e-03, 1.5000e-03, 6.3000e-03],\n",
" [ 5.7000e-03, -8.5000e-03, 5.2000e-03, -3.0000e-04, 4.9000e-03]]]]), scale=tensor(1.0000e-04), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))"
]
},
"execution_count": 16,
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from brevitas.quant_tensor import QuantTensor\n",
"from brevitas.quant_tensor import IntQuantTensor\n",
"\n",
"scale = 0.0001\n",
"bit_width = 8\n",
"zero_point = 0.\n",
"int_value = torch.randint(low=- 2 ** (bit_width - 1), high=2 ** (bit_width - 1) - 1, size=(1, 2, 5, 5))\n",
"quant_value = (int_value - zero_point) * scale\n",
"quant_tensor_input = QuantTensor(\n",
"quant_tensor_input = IntQuantTensor(\n",
" quant_value, \n",
" scale=torch.tensor(scale), \n",
" zero_point=torch.tensor(zero_point), \n",
@@ -662,7 +670,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -688,7 +696,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -721,7 +729,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -745,7 +753,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -784,7 +792,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -820,7 +828,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -856,7 +864,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": null,
"metadata": {
"tags": [
"raises-exception"
@@ -897,7 +905,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -935,7 +943,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -968,7 +976,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -1007,7 +1015,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": null,
"metadata": {
"tags": [
"raises-exception"
@@ -1049,7 +1057,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -1093,7 +1101,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -1131,7 +1139,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -1155,7 +1163,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": null,
"metadata": {},
"outputs": [
{
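The notebook's central change is the rename from QuantTensor to IntQuantTensor for integer-quantized values. A minimal sketch of the manual construction the updated cell performs — the bit_width, signed, and training keywords are assumptions inferred from the truncated cell and the printed repr, not copied from the diff:

    import torch

    from brevitas.quant_tensor import IntQuantTensor

    scale, zero_point, bit_width = 0.0001, 0., 8
    # Draw random integers in the signed 8-bit range, then fake-quantize them.
    int_value = torch.randint(low=-2 ** (bit_width - 1), high=2 ** (bit_width - 1) - 1, size=(1, 2, 5, 5))
    quant_value = (int_value - zero_point) * scale
    quant_tensor_input = IntQuantTensor(
        quant_value,
        scale=torch.tensor(scale),
        zero_point=torch.tensor(zero_point),
        bit_width=torch.tensor(float(bit_width)),  # assumed kwargs; the notebook cell is truncated here
        signed=True,
        training=True)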
1 change: 1 addition & 0 deletions requirements/requirements-test.txt
@@ -7,3 +7,4 @@ pytest-xdist
pytest_cases
scipy
torchvision
tqdm
2 changes: 1 addition & 1 deletion src/brevitas/core/function_wrapper/clamp.py
@@ -149,4 +149,4 @@ def forward(
"Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified"
)

return x
return x, self.saturating, self.inf_values, self.nan_values
4 changes: 2 additions & 2 deletions src/brevitas/core/quant/float.py
@@ -85,8 +85,8 @@ def dequantize(self, y, scale):
def forward(self, x):
y, scale = self.quantize(x)
# after quantizing, clamp to special cases like NaN/inf if they are set
y = self.float_clamp_impl(
y, saturating, inf_values, nan_values = self.float_clamp_impl(
y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
y = self.dequantize(y, scale)
# This is to respect the current interface of proxies
return y, scale, self.zero_point_impl(), self.bit_width()
return y, scale, self.zero_point_impl(), self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias(), saturating, inf_values, nan_values
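
Callers of FloatQuant.forward now receive the full minifloat format description in one tuple. A sketch of caller-side unpacking, with a stub standing in for the module — the stub's e4m3-style values are illustrative; only the nine-element return order is taken from the diff above:

    import torch

    def float_quant_stub(x):
        # Stand-in mirroring the nine values FloatQuant.forward now returns:
        # (value, scale, zero_point, exponent_bit_width, mantissa_bit_width,
        #  exponent_bias, saturating, inf_values, nan_values)
        return (x, torch.tensor(1.0), torch.tensor(0.0), torch.tensor(4.0),
                torch.tensor(3.0), torch.tensor(7.0), True, None, ('nan',))

    (value, scale, zero_point, exponent_bit_width, mantissa_bit_width,
     exponent_bias, saturating, inf_values, nan_values) = float_quant_stub(torch.randn(2, 2))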
4 changes: 2 additions & 2 deletions src/brevitas/fx/value_tracer.py
@@ -57,7 +57,7 @@
import torch.utils._pytree as pytree

from brevitas import torch_version
from brevitas.quant_tensor import QuantTensorBase
from brevitas.quant_tensor import QuantTensor

from . import *
from . import _assert_is_none
@@ -82,7 +82,7 @@
from . import ScopeContextManager

_UNSET = object()
extended_base_types = base_types + (QuantTensorBase,)
extended_base_types = base_types + (QuantTensor,)

FRAME_FILES = [
'fx/brevitas_tracer.py',
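The tracer change tracks a rename in the class hierarchy: the old abstract QuantTensorBase is now called QuantTensor, with IntQuantTensor (and the new FloatQuantTensor) as concrete variants, so registering the base keeps every variant treated as a leaf value during tracing. A quick sanity check under that assumed hierarchy:

    from brevitas.quant_tensor import IntQuantTensor
    from brevitas.quant_tensor import QuantTensor

    # Assumed hierarchy after the rename: the concrete integer (and float)
    # quant tensors share QuantTensor as their abstract base.
    assert issubclass(IntQuantTensor, QuantTensor)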
8 changes: 4 additions & 4 deletions src/brevitas/graph/calibrate.py
@@ -13,8 +13,8 @@
from brevitas.nn import QuantHardTanh
from brevitas.nn import QuantLinear
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ClampQuantProxyFromInjector
from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector
@@ -29,9 +29,9 @@
'calibration_mode',
'load_quant_model_mode']

_PARAM_PROXIES = (WeightQuantProxyFromInjector, BiasQuantProxyFromInjector)
_PARAM_PROXIES = (WeightQuantProxyFromInjectorBase, BiasQuantProxyFromInjectorBase)

_BIAS_PROXIES = (BiasQuantProxyFromInjector)
_BIAS_PROXIES = (BiasQuantProxyFromInjectorBase)

_ACC_PROXIES = (TruncQuantProxyFromInjector, ClampQuantProxyFromInjector)

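Switching these tuples to the *FromInjectorBase classes means isinstance checks in the calibration passes now match the integer and float flavours of each proxy alike. A minimal sketch of that dispatch — the model walk and helper are illustrative, not calibrate.py's actual code:

    from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase
    from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase

    _PARAM_PROXIES = (WeightQuantProxyFromInjectorBase, BiasQuantProxyFromInjectorBase)

    def iter_param_proxies(model):
        # Hypothetical helper: with the base classes, int and float weight/bias
        # proxies are both caught by a single isinstance check.
        for module in model.modules():
            if isinstance(module, _PARAM_PROXIES):
                yield module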
2 changes: 1 addition & 1 deletion src/brevitas/graph/gpfq.py
@@ -316,7 +316,7 @@ def single_layer_update(self):
# raise error in case no quant-input is here
if self.quant_metadata is None:
raise ValueError('Expected self.quant_metadata to calculate L1-norm upper bound, but received None. ' + \
'Make sure that either the input to the model is a QuantTensor or the layer has an input quant enabled. ' \
'Make sure that either the input to the model is an IntQuantTensor or the layer has an input quant enabled. ' \
'Also, check if `use_quant_activations=True` in `gpfq_mode` when `accumulator_bit_width` is specified. ' + \
'Alternatively, provide a custom `a2q_layer_filter_fnc` to `gpfq_mode` to filter layers without a quant_tensor input.')
weight = self.layer.weight.data
6 changes: 3 additions & 3 deletions src/brevitas/graph/gpxq.py
@@ -18,7 +18,7 @@
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import restore_return_quant_tensor
import brevitas.nn as qnn
from brevitas.quant_tensor import QuantTensor
from brevitas.quant_tensor import IntQuantTensor
from brevitas.utils.quant_utils import _CachedIO

SUPPORTED_CONV_OP = (
@@ -227,9 +227,9 @@ def process_input(self, inp):

is_quant_enabled = self.layer.weight_quant.is_quant_enabled

# If using quantized activations, inp could be QuantTensor. In
# If using quantized activations, inp could be IntQuantTensor. In
# this case, we overwrite the metadata.
if isinstance(inp, QuantTensor):
if isinstance(inp, IntQuantTensor):
if is_quant_enabled and self.quant_metadata is None:
self.quant_metadata = _CachedIO(inp, metadata_only=True)
inp = inp.value
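process_input, shown above, caches quantization metadata only when the activation arriving at the layer is an IntQuantTensor. A condensed sketch of that unwrap step — a paraphrase, not the method itself, with a plain variable standing in for _CachedIO:

    from brevitas.quant_tensor import IntQuantTensor

    def unwrap_quant_input(inp, cached_metadata=None):
        # Paraphrase of process_input: remember the quant metadata the first
        # time an IntQuantTensor arrives, then hand back the raw tensor payload.
        if isinstance(inp, IntQuantTensor):
            if cached_metadata is None:
                cached_metadata = inp  # the real code stores _CachedIO(inp, metadata_only=True)
            inp = inp.value
        return inp, cached_metadata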