From 6079b12224621942984f1fb85ea8eb68ff505468 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 23 Feb 2024 17:06:43 +0100 Subject: [PATCH] Feat (QuantTensor)!: QuantTensor cannot be empty (#819) BREAKING CHANGE: Creating a QuantTensor without metadata is no longer allowed. --- ...1_quant_tensor_quant_conv2d_overview.ipynb | 551 +++++----- notebooks/02_quant_activation_overview.ipynb | 169 +--- notebooks/03_anatomy_of_a_quantizer.ipynb | 489 ++++----- notebooks/Brevitas_TVMCon2021.ipynb | 946 +++++++++++++++--- notebooks/ONNX_export_tutorial.ipynb | 89 +- notebooks/quantized_recurrent.ipynb | 639 ++++++------ src/brevitas/core/quant/binary.py | 3 +- src/brevitas/core/stats/stats_op.py | 4 +- src/brevitas/export/manager.py | 5 +- .../onnx/standard/qoperator/handler/base.py | 2 +- .../export/onnx/standard/qoperator/manager.py | 7 - src/brevitas/graph/calibrate.py | 20 +- src/brevitas/graph/gpxq.py | 17 +- src/brevitas/nn/hadamard_classifier.py | 12 +- src/brevitas/nn/mixin/base.py | 42 +- src/brevitas/nn/mixin/parameter.py | 6 +- src/brevitas/nn/quant_avg_pool.py | 52 +- src/brevitas/nn/quant_layer.py | 97 +- src/brevitas/nn/quant_rnn.py | 71 +- src/brevitas/nn/quant_upsample.py | 4 +- src/brevitas/nn/target/flexml.py | 8 +- src/brevitas/proxy/parameter_quant.py | 19 +- src/brevitas/proxy/runtime_quant.py | 25 +- src/brevitas/quant_tensor/__init__.py | 131 ++- src/brevitas/quant_tensor/torch_handler.py | 5 +- .../melgan/res_stack_brevitas.py | 7 +- tests/brevitas/fx/test_tracer.py | 4 +- tests/brevitas/nn/nn_quantizers_fixture.py | 2 + tests/brevitas/nn/test_linear.py | 4 +- tests/brevitas/nn/test_nn_quantizers.py | 87 +- tests/brevitas_ort/common.py | 19 +- 31 files changed, 2040 insertions(+), 1496 deletions(-) diff --git a/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb b/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb index 2e9ef9179..c9a6d052d 100644 --- a/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb +++ b/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb @@ -18,14 +18,6 @@ "execution_count": 1, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/user/.local/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, { "data": { "text/markdown": [ @@ -39,14 +31,22 @@ " padding: Union[int, Tuple[int, int]] = 0,\n", " dilation: Union[int, Tuple[int, int]] = 1,\n", " groups: int = 1,\n", + " padding_mode: str = 'zeros',\n", " bias: bool = True,\n", - " padding_type: str = 'standard',\n", " weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat,\n", " bias_quant: Optional[BiasQuantType] = None,\n", " input_quant: Optional[ActQuantType] = None,\n", " output_quant: Optional[ActQuantType] = None,\n", " return_quant_tensor: bool = False,\n", + " device: Optional[torch.device] = None,\n", + " dtype: Optional[torch.dtype] = None,\n", " **kwargs) -> None:\n", + " # avoid an init error in the super class by setting padding to 0\n", + " if padding_mode == 'zeros' and padding == 'same' and stride > 1:\n", + " padding = 0\n", + " is_same_padded_strided = True\n", + " else:\n", + " is_same_padded_strided = False\n", " Conv2d.__init__(\n", " self,\n", " in_channels=in_channels,\n", @@ -54,9 +54,12 @@ " kernel_size=kernel_size,\n", " stride=stride,\n", " padding=padding,\n", + " padding_mode=padding_mode,\n", " dilation=dilation,\n", " groups=groups,\n", - " bias=bias)\n", + " bias=bias,\n", + " device=device,\n", + " dtype=dtype)\n", " QuantWBIOL.__init__(\n", " self,\n", " weight_quant=weight_quant,\n", @@ -65,9 +68,7 @@ " output_quant=output_quant,\n", " return_quant_tensor=return_quant_tensor,\n", " **kwargs)\n", - " assert self.padding_mode == 'zeros'\n", - " assert not (padding_type == 'same' and padding != 0)\n", - " self.padding_type = padding_type\n", + " self.is_same_padded_strided = is_same_padded_strided\n", "\n", "```" ], @@ -149,20 +150,28 @@ "scrolled": true }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/_tensor.py:1362: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/c10/core/TensorImpl.h:1900.)\n", + " return super().rename(names)\n" + ] + }, { "data": { "text/plain": [ - "tensor([[[[-0.2594, 0.5392, 0.5916],\n", - " [ 0.3493, 0.6813, 0.2499],\n", - " [ 1.3732, 0.1229, -0.0084]],\n", + "tensor([[[[ 1.0093, 0.4820, 0.0156],\n", + " [-0.1535, -0.2748, -0.9393],\n", + " [-1.0662, 0.2397, 0.0932]],\n", "\n", - " [[ 0.0031, -0.1702, 0.1069],\n", - " [-0.8181, -0.8056, 0.0385],\n", - " [-0.4738, 0.0589, 0.1278]],\n", + " [[ 0.6932, -0.2772, 0.0703],\n", + " [ 0.2536, 0.1734, -0.3745],\n", + " [-0.5633, 0.2231, -0.6844]],\n", "\n", - " [[-0.1718, -0.1162, -0.1526],\n", - " [-0.9903, -0.3541, 0.1645],\n", - " [ 0.0557, -0.4458, -0.2080]]]], grad_fn=)" + " [[-0.2607, 0.2174, -0.0522],\n", + " [ 0.1215, -0.3744, -0.5880],\n", + " [-0.3104, -0.6930, 0.5322]]]], grad_fn=)" ] }, "execution_count": 4, @@ -234,31 +243,31 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.0790, 0.0503, -0.0934],\n", - " [-0.1149, -0.1903, -0.1329],\n", - " [-0.1813, 0.0108, 0.0593]],\n", + "QuantTensor(value=tensor([[[[ 0.0236, 0.1599, 0.1799],\n", + " [-0.0545, 0.2144, 0.2126],\n", + " [-0.1363, -0.2271, -0.1526]],\n", "\n", - " [[ 0.0970, -0.0215, -0.0144],\n", - " [ 0.2280, 0.1239, -0.0090],\n", - " [ 0.1957, -0.2011, -0.0108]]],\n", + " [[-0.0872, -0.0091, -0.1090],\n", + " [ 0.0690, -0.0327, 0.2289],\n", + " [ 0.2307, 0.0073, -0.1326]]],\n", "\n", "\n", - " [[[-0.0018, -0.1957, 0.1993],\n", - " [-0.0359, 0.1778, -0.1400],\n", - " [ 0.0916, 0.1059, 0.2173]],\n", + " [[[-0.0254, 0.0418, -0.0363],\n", + " [-0.2053, 0.2071, -0.1163],\n", + " [-0.1163, -0.1653, 0.0109]],\n", "\n", - " [[-0.1670, 0.1939, -0.2191],\n", - " [-0.0215, 0.1688, -0.1383],\n", - " [-0.0449, -0.1185, 0.1742]]],\n", + " [[-0.2107, -0.1199, 0.0799],\n", + " [ 0.0200, 0.0218, 0.1817],\n", + " [-0.1199, -0.0963, -0.0600]]],\n", "\n", "\n", - " [[[-0.0808, -0.1652, -0.0233],\n", - " [-0.0700, 0.0467, -0.0485],\n", - " [ 0.1059, 0.1418, 0.1077]],\n", + " [[[-0.0709, -0.0908, 0.1544],\n", + " [-0.0236, -0.2235, 0.2180],\n", + " [-0.0799, -0.0200, 0.0273]],\n", "\n", - " [[-0.0593, 0.0108, 0.0036],\n", - " [-0.1508, 0.0808, 0.1616],\n", - " [ 0.0144, -0.0287, -0.1365]]]], grad_fn=), scale=tensor(0.0018, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" + " [[-0.1998, 0.1126, 0.1435],\n", + " [ 0.0818, 0.1399, 0.1181],\n", + " [ 0.1762, -0.1726, -0.2216]]]], grad_fn=), scale=tensor(0.0018, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" ] }, "execution_count": 6, @@ -325,15 +334,15 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "tensor(0.0173, grad_fn=)\n", - "tensor(0.0307, grad_fn=)\n" + "tensor(0.0211, grad_fn=)\n", + "tensor(0.0162, grad_fn=)\n" ] } ], @@ -361,34 +370,31 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.9489, -0.9111, -0.0536, 0.5788, 0.3645],\n", - " [ 0.3401, 1.4325, 0.6498, 0.6411, -1.4390],\n", - " [-1.9029, 0.7012, 0.1591, 1.9235, 0.5883],\n", - " [-2.7258, 2.5330, 0.9165, -0.0820, 3.4148],\n", - " [-0.3651, 1.0164, 0.9567, -0.2758, -1.1376]],\n", - "\n", - " [[-0.2414, 2.2111, -1.9124, -2.3814, -0.8805],\n", - " [ 1.3191, -0.8965, -0.2048, -3.8113, 1.1142],\n", - " [-0.3381, -0.2238, 1.2661, 0.0068, 0.2567],\n", - " [ 0.0731, -0.4280, 0.0909, 0.0875, -1.6851],\n", - " [-0.7744, -1.4127, -0.8143, 1.3557, -0.2802]]]],\n", - " grad_fn=), scale=tensor(0.0240, grad_fn=), zero_point=tensor(0.), bit_width=tensor(9.), signed_t=tensor(True), training_t=tensor(True))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "QuantTensor(value=tensor([[[[-0.1106, 1.1945, -0.4972, -2.0968, 0.7175],\n", + " [-2.5901, 0.0588, -0.2014, 2.1486, 1.6435],\n", + " [ 0.9067, -2.5212, 2.2193, 0.2352, -0.8395],\n", + " [-0.8351, 0.6341, -0.5551, 0.1040, -3.3151],\n", + " [-0.8979, -0.7092, 3.8232, 1.0875, 0.3954]],\n", + "\n", + " [[ 1.4363, -1.3973, 1.3249, 2.6914, 0.3660],\n", + " [ 1.5057, 1.8094, 0.5100, -1.6874, 1.9981],\n", + " [ 1.2472, -1.7813, 0.0334, -1.2880, -2.9333],\n", + " [ 0.0180, -1.4298, -2.9978, 0.5494, -1.4548],\n", + " [ 1.6738, -0.3177, -0.3721, -0.1650, -1.1871]]]],\n", + " grad_fn=), scale=0.018651068210601807, zero_point=0.0, bit_width=9.0, signed_t=True, training_t=True)\n" + ] } ], "source": [ "out_tensor = out_tensor_0 + out_tensor_1\n", - "out_tensor" + "print(out_tensor)" ] }, { @@ -401,7 +407,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -417,23 +423,23 @@ }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[1.5800, 1.0157],\n", - " [1.4445, 0.8577]],\n", + "QuantTensor(value=tensor([[[[0.5191, 0.6402],\n", + " [2.1455, 0.5883]],\n", "\n", - " [[0.5643, 1.2414],\n", - " [1.0383, 0.9028]],\n", + " [[2.0417, 0.5883],\n", + " [1.2631, 0.3980]],\n", "\n", - " [[0.5191, 0.6546],\n", - " [2.1442, 0.5868]]]], grad_fn=), scale=tensor(0.0226, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" + " [[0.7959, 0.5191],\n", + " [0.8132, 1.3496]]]], grad_fn=), scale=tensor(0.0173, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 108, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -455,29 +461,37 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": 13, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2482988/1377665000.py:1: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", + " torch.tanh(quant_tensor)\n" + ] + }, { "data": { "text/plain": [ - "tensor([[[[-0.4943, -0.9938, -0.9073, 0.7681],\n", - " [-0.3262, 0.9186, 0.1786, 0.3659],\n", - " [ 0.7489, 0.8946, -0.0451, -0.5594],\n", - " [-0.1346, -0.4943, -0.4770, 0.6951]],\n", + "tensor([[[[ 0.4770, 0.2212, 0.0691, 0.5650],\n", + " [-0.0346, -0.6618, -0.4635, -0.3482],\n", + " [ 0.9730, -0.7245, -0.5881, -0.5287],\n", + " [-0.0863, 0.8857, 0.5287, -0.4498]],\n", "\n", - " [[ 0.0676, 0.5111, 0.4943, 0.8459],\n", - " [-0.8990, -0.9426, 0.0676, -0.7945],\n", - " [-0.9220, 0.0676, -0.5594, 0.6321],\n", - " [-0.0676, 0.7772, 0.7177, -0.4414]],\n", + " [[ 0.9669, 0.5650, -0.6211, -0.4498],\n", + " [-0.2376, 0.6103, 0.5287, 0.2700],\n", + " [-0.6808, 0.8519, 0.2700, -0.5531],\n", + " [-0.0173, 0.8264, 0.3782, -0.1881]],\n", "\n", - " [[ 0.4770, 0.2220, 0.0676, 0.5747],\n", - " [-0.0451, -0.6710, -0.4594, -0.3462],\n", - " [ 0.9729, -0.7177, -0.5896, -0.5276],\n", - " [-0.0900, 0.8852, 0.5276, -0.4414]]]], grad_fn=)" + " [[-0.6211, -0.9764, -0.5993, 0.4770],\n", + " [ 0.5033, 0.6618, -0.1881, -0.6211],\n", + " [-0.8031, 0.1375, 0.5287, 0.8740],\n", + " [-0.6714, 0.6714, -0.5650, 0.8611]]]], grad_fn=)" ] }, - "execution_count": 109, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -497,26 +511,26 @@ }, { "cell_type": "code", - "execution_count": 110, + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.9693, -0.9431, 0.2459],\n", - " [ 0.5416, 0.9037, -0.5278],\n", - " [-0.6207, -1.3578, -0.4815]],\n", + "QuantTensor(value=tensor([[[[-0.3568, -0.1883, 0.3589],\n", + " [-0.4470, 0.1039, -0.3945],\n", + " [-0.4190, 0.3723, 0.8384]],\n", "\n", - " [[ 0.4551, -1.4065, 0.8889],\n", - " [-0.3393, 0.0803, -0.1748],\n", - " [-0.0977, 0.6284, -0.7193]],\n", + " [[-0.0510, 0.5514, -0.2751],\n", + " [-0.5668, 0.5824, 0.2328],\n", + " [ 0.1316, -0.2518, 1.0418]],\n", "\n", - " [[ 0.3655, 0.7626, -0.2634],\n", - " [-0.3453, 0.3349, 0.1923],\n", - " [ 0.5993, -0.9579, 0.3557]]]], grad_fn=), scale=tensor([[[[3.2208e-05]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" + " [[ 0.2734, 0.7268, -0.0249],\n", + " [-0.1732, 0.5197, 1.1158],\n", + " [ 0.3771, -0.3810, 0.2008]]]], grad_fn=), scale=tensor([[[[3.1958e-05]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 110, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -533,20 +547,9 @@ }, { "cell_type": "code", - "execution_count": 111, + "execution_count": 15, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 111, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "assert out_tensor.is_valid" ] @@ -569,26 +572,26 @@ }, { "cell_type": "code", - "execution_count": 112, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 5.7000e-03, 2.5000e-03, -1.2400e-02, -7.2000e-03, 3.7000e-03],\n", - " [-2.3000e-03, 7.0000e-04, -1.2700e-02, 5.2000e-03, 4.0000e-04],\n", - " [-7.9000e-03, 9.5000e-03, 6.6000e-03, 5.4000e-03, 2.5000e-03],\n", - " [ 1.1100e-02, 2.4000e-03, 1.0000e-02, -3.7000e-03, 7.2000e-03],\n", - " [-1.1500e-02, -5.8000e-03, -9.3000e-03, 1.0000e-02, 3.5000e-03]],\n", + "QuantTensor(value=tensor([[[[ 7.2000e-03, -3.7000e-03, 7.7000e-03, -2.4000e-03, -8.9000e-03],\n", + " [-1.2000e-02, -8.1000e-03, 7.2000e-03, -1.1300e-02, -9.7000e-03],\n", + " [-1.0000e-03, 1.0100e-02, 3.8000e-03, -1.1900e-02, 6.9000e-03],\n", + " [ 8.3000e-03, 1.0000e-04, -6.9000e-03, 3.9000e-03, -5.4000e-03],\n", + " [ 1.1300e-02, -6.0000e-03, 9.7000e-03, 0.0000e+00, 1.0900e-02]],\n", "\n", - " [[-6.8000e-03, 1.1500e-02, -1.0600e-02, -1.5000e-03, -1.9000e-03],\n", - " [ 2.9000e-03, 9.5000e-03, 7.2000e-03, -3.7000e-03, 7.7000e-03],\n", - " [-2.4000e-03, -8.9000e-03, -1.2000e-02, -8.1000e-03, 7.2000e-03],\n", - " [-1.1300e-02, -9.7000e-03, -1.0000e-03, 1.0100e-02, 3.8000e-03],\n", - " [-1.1900e-02, 6.9000e-03, 8.3000e-03, 1.0000e-04, -6.9000e-03]]]]), scale=tensor(1.0000e-04), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" + " [[-1.0900e-02, 1.1400e-02, -6.4000e-03, 9.2000e-03, 7.1000e-03],\n", + " [-6.0000e-04, 9.2000e-03, -8.5000e-03, 5.0000e-03, 6.5000e-03],\n", + " [-8.3000e-03, -1.2000e-03, 7.4000e-03, 9.2000e-03, -6.0000e-04],\n", + " [-2.1000e-03, 9.5000e-03, 3.0000e-04, -2.9000e-03, -6.5000e-03],\n", + " [-1.1800e-02, -4.8000e-03, 5.4000e-03, -2.5000e-03, 9.0000e-04]]]]), scale=tensor(1.0000e-04), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 112, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -613,20 +616,9 @@ }, { "cell_type": "code", - "execution_count": 113, + "execution_count": 17, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 113, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "assert quant_tensor_input.is_valid" ] @@ -642,26 +634,26 @@ }, { "cell_type": "code", - "execution_count": 114, + "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.0085, 0.0066, 0.0050],\n", - " [-0.0038, -0.0009, -0.0115],\n", - " [-0.0055, -0.0037, 0.0009]],\n", + "QuantTensor(value=tensor([[[[-0.0019, 0.0049, -0.0012],\n", + " [-0.0012, 0.0050, -0.0074],\n", + " [-0.0023, -0.0035, -0.0033]],\n", "\n", - " [[ 0.0015, -0.0027, -0.0079],\n", - " [-0.0034, -0.0060, 0.0043],\n", - " [-0.0008, 0.0052, -0.0033]],\n", + " [[-0.0031, 0.0028, 0.0116],\n", + " [ 0.0079, 0.0046, 0.0022],\n", + " [ 0.0021, -0.0004, 0.0011]],\n", "\n", - " [[-0.0015, 0.0082, -0.0038],\n", - " [-0.0021, 0.0004, -0.0054],\n", - " [-0.0021, -0.0079, 0.0013]]]], grad_fn=), scale=tensor([[[[1.8448e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" + " [[-0.0045, -0.0010, 0.0002],\n", + " [-0.0044, 0.0027, 0.0025],\n", + " [-0.0009, 0.0040, -0.0044]]]], grad_fn=), scale=tensor([[[[1.8307e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 114, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -675,20 +667,9 @@ }, { "cell_type": "code", - "execution_count": 115, + "execution_count": 19, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 115, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "assert out_tensor.is_valid" ] @@ -702,26 +683,26 @@ }, { "cell_type": "code", - "execution_count": 116, + "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.0035, -0.0037, -0.0050],\n", - " [ 0.0010, -0.0051, -0.0027],\n", - " [-0.0010, 0.0047, 0.0017]],\n", + "QuantTensor(value=tensor([[[[-0.0073, 0.0040, -0.0011],\n", + " [-0.0033, 0.0078, -0.0028],\n", + " [ 0.0005, -0.0025, -0.0008]],\n", "\n", - " [[ 0.0021, 0.0002, 0.0027],\n", - " [ 0.0028, 0.0002, -0.0044],\n", - " [ 0.0008, -0.0052, -0.0024]],\n", + " [[ 0.0021, -0.0021, 0.0035],\n", + " [ 0.0012, -0.0016, -0.0023],\n", + " [-0.0010, -0.0015, 0.0040]],\n", "\n", - " [[ 0.0010, -0.0052, -0.0011],\n", - " [-0.0018, 0.0024, 0.0011],\n", - " [-0.0001, 0.0039, 0.0035]]]], grad_fn=), scale=tensor([[[[1.7410e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" + " [[-0.0010, 0.0047, 0.0025],\n", + " [-0.0014, 0.0021, -0.0039],\n", + " [ 0.0036, -0.0003, 0.0026]]]], grad_fn=), scale=tensor([[[[1.7393e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 116, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -741,26 +722,26 @@ }, { "cell_type": "code", - "execution_count": 117, + "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.2111, 0.4060, 0.3654],\n", - " [-0.7876, 0.8119, -0.9825],\n", - " [-0.5115, 0.3979, -0.3248]],\n", + "QuantTensor(value=tensor([[[[-0.2117, -0.4811, 0.0385],\n", + " [-0.5100, -0.2502, -0.2213],\n", + " [-0.5773, 0.0192, -0.5485]],\n", "\n", - " [[ 0.3816, 0.0568, -0.0812],\n", - " [ 1.0312, -0.7876, 0.8038],\n", - " [-0.3491, -0.4141, 0.0650]],\n", + " [[ 0.1347, 0.8179, -1.2316],\n", + " [-0.6062, 0.4426, -0.3849],\n", + " [ 0.1732, -0.5100, -0.1251]],\n", "\n", - " [[-0.5846, -0.4222, -0.0731],\n", - " [-0.7389, 0.5034, -0.2517],\n", - " [-0.1624, -0.4385, 0.7308]]]], grad_fn=), scale=tensor(0.0081, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" + " [[ 1.0873, 0.2406, -0.2887],\n", + " [-0.4330, -0.4907, -0.2021],\n", + " [ 0.6447, 0.4811, 0.1347]]]], grad_fn=), scale=tensor(0.0096, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 117, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -777,20 +758,9 @@ }, { "cell_type": "code", - "execution_count": 118, + "execution_count": 22, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 118, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "assert out_tensor.is_valid" ] @@ -816,7 +786,7 @@ }, { "cell_type": "code", - "execution_count": 119, + "execution_count": 24, "metadata": { "tags": [ "raises-exception" @@ -825,18 +795,17 @@ "outputs": [ { "ename": "RuntimeError", - "evalue": "Input scale required", + "evalue": "QuantLayer is not correctly configured", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_48365/2280634207.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0min_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkernel_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m bias_quant=Int8Bias, return_quant_tensor=True)\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mbias_quant_conv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/opt/conda/envs/torch_1.10/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1103\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/scratch/git/fork_brevitas/src/brevitas/nn/quant_conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 192\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 193\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 194\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minner_forward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_weight\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_bias\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/scratch/git/fork_brevitas/src/brevitas/nn/quant_layer.py\u001b[0m in \u001b[0;36mforward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 330\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 331\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 332\u001b[0;31m \u001b[0mquant_bias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias_quant\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_scale\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_bit_width\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 333\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_inference_quant_bias\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 334\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cached_bias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_CachedIO\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquant_bias\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata_only\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/envs/torch_1.10/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1103\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/scratch/git/fork_brevitas/src/brevitas/proxy/parameter_quant.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, input_scale, input_bit_width)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0mimpl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_input_scale\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0minput_scale\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Input scale required\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_input_bit_width\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0minput_bit_width\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Input bit-width required\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mRuntimeError\u001b[0m: Input scale required" + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb Cell 46\u001b[0m line \u001b[0;36m6\n\u001b[1;32m 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mbrevitas\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mquant\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mscaled_int\u001b[39;00m \u001b[39mimport\u001b[39;00m Int8Bias\n\u001b[1;32m 3\u001b[0m bias_quant_conv \u001b[39m=\u001b[39m QuantConv2d(\n\u001b[1;32m 4\u001b[0m in_channels\u001b[39m=\u001b[39m\u001b[39m2\u001b[39m, out_channels\u001b[39m=\u001b[39m\u001b[39m3\u001b[39m, kernel_size\u001b[39m=\u001b[39m(\u001b[39m3\u001b[39m,\u001b[39m3\u001b[39m), bias\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m,\n\u001b[1;32m 5\u001b[0m bias_quant\u001b[39m=\u001b[39mInt8Bias, return_quant_tensor\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[0;32m----> 6\u001b[0m bias_quant_conv(torch\u001b[39m.\u001b[39;49mrandn(\u001b[39m1\u001b[39;49m, \u001b[39m2\u001b[39;49m, \u001b[39m5\u001b[39;49m, \u001b[39m5\u001b[39;49m))\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs) \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_conv.py:198\u001b[0m, in \u001b[0;36mQuantConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Union[Tensor, QuantTensor]) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Union[Tensor, QuantTensor]:\n\u001b[0;32m--> 198\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mforward_impl(\u001b[39minput\u001b[39;49m)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:320\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 316\u001b[0m compute_output_quant_tensor \u001b[39m=\u001b[39m \u001b[39misinstance\u001b[39m(quant_input, QuantTensor) \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(\n\u001b[1;32m 317\u001b[0m quant_weight, QuantTensor)\n\u001b[1;32m 318\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (compute_output_quant_tensor \u001b[39mor\u001b[39;00m\n\u001b[1;32m 319\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_output_quant_enabled) \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreturn_quant_tensor:\n\u001b[0;32m--> 320\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mQuantLayer is not correctly configured\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 322\u001b[0m \u001b[39mif\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreturn_quant_tensor \u001b[39mor\u001b[39;00m\n\u001b[1;32m 323\u001b[0m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_bias_quant_enabled \u001b[39mand\u001b[39;00m\n\u001b[1;32m 324\u001b[0m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias_quant\u001b[39m.\u001b[39mrequires_input_scale \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias_quant\u001b[39m.\u001b[39mrequires_input_bit_width))):\n\u001b[1;32m 325\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(quant_input, QuantTensor) \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(quant_weight, QuantTensor):\n", + "\u001b[0;31mRuntimeError\u001b[0m: QuantLayer is not correctly configured" ] } ], @@ -858,26 +827,26 @@ }, { "cell_type": "code", - "execution_count": 120, + "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.0005, 0.0043, -0.0004],\n", - " [ 0.0005, 0.0106, 0.0012],\n", - " [ 0.0021, 0.0007, -0.0050]],\n", + "QuantTensor(value=tensor([[[[-0.0058, 0.0030, 0.0030],\n", + " [-0.0013, -0.0001, 0.0043],\n", + " [-0.0061, 0.0033, -0.0001]],\n", "\n", - " [[-0.0067, -0.0035, -0.0059],\n", - " [-0.0050, -0.0015, -0.0039],\n", - " [ 0.0015, 0.0028, -0.0008]],\n", + " [[ 0.0013, -0.0008, -0.0015],\n", + " [ 0.0011, 0.0012, -0.0012],\n", + " [-0.0013, -0.0020, 0.0002]],\n", "\n", - " [[-0.0051, -0.0050, 0.0060],\n", - " [-0.0015, 0.0037, 0.0071],\n", - " [ 0.0067, 0.0035, -0.0071]]]], grad_fn=), scale=tensor([[[[1.8108e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" + " [[-0.0061, 0.0053, -0.0004],\n", + " [ 0.0028, 0.0031, -0.0038],\n", + " [ 0.0026, -0.0048, -0.0044]]]], grad_fn=), scale=tensor([[[[1.8528e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 120, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } @@ -895,26 +864,26 @@ }, { "cell_type": "code", - "execution_count": 121, + "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.3825, 0.1371, 0.9135],\n", - " [-0.2016, 0.7495, -0.4071],\n", - " [-0.0755, 0.5283, 0.2388]],\n", + "QuantTensor(value=tensor([[[[-0.4300, 0.1726, -0.3396],\n", + " [ 0.0307, -0.0052, -1.1685],\n", + " [-0.3160, 0.1334, -0.4459]],\n", "\n", - " [[ 0.0788, -0.3802, -0.2234],\n", - " [ 0.8678, -0.5546, 0.4408],\n", - " [-0.6788, 0.4422, 0.3007]],\n", + " [[ 1.0135, 0.7129, -0.3874],\n", + " [ 0.4858, -0.6205, 0.1563],\n", + " [-0.1631, -0.2198, 0.1444]],\n", "\n", - " [[ 0.4412, -0.3205, 1.0033],\n", - " [-0.0083, -0.3295, -0.2076],\n", - " [ 0.4417, -0.1046, -0.3493]]]], grad_fn=), scale=tensor([[[[3.8610e-05]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" + " [[-1.4600, 0.9106, 0.6328],\n", + " [ 0.6669, -0.1814, -0.0169],\n", + " [ 0.6581, 0.7420, -0.4884]]]], grad_fn=), scale=tensor([[[[2.9050e-05]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 121, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -928,26 +897,26 @@ }, { "cell_type": "code", - "execution_count": 122, + "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.0036, 0.0024, -0.0033],\n", - " [ 0.0050, 0.0080, -0.0014],\n", - " [-0.0036, -0.0080, -0.0029]],\n", + "QuantTensor(value=tensor([[[[-0.0015, -0.0035, 0.0003],\n", + " [-0.0054, 0.0047, 0.0055],\n", + " [ 0.0043, 0.0054, -0.0050]],\n", "\n", - " [[ 0.0083, -0.0093, 0.0048],\n", - " [ 0.0035, 0.0015, -0.0011],\n", - " [-0.0003, 0.0067, 0.0013]],\n", + " [[-0.0004, 0.0013, -0.0018],\n", + " [ 0.0055, -0.0073, 0.0023],\n", + " [-0.0053, 0.0009, 0.0032]],\n", "\n", - " [[-0.0009, -0.0019, 0.0039],\n", - " [ 0.0010, 0.0056, -0.0037],\n", - " [ 0.0091, -0.0095, 0.0054]]]], grad_fn=), scale=tensor([[[[1.8384e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" + " [[ 0.0015, -0.0002, -0.0068],\n", + " [ 0.0015, -0.0040, -0.0046],\n", + " [-0.0033, -0.0009, 0.0079]]]], grad_fn=), scale=tensor([[[[1.7377e-07]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 122, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } @@ -967,7 +936,7 @@ }, { "cell_type": "code", - "execution_count": 123, + "execution_count": 28, "metadata": { "tags": [ "raises-exception" @@ -981,12 +950,14 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_48365/2990591641.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0min_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkernel_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m output_quant=Int8ActPerTensorFloat, bias_quant=Int8Bias, return_quant_tensor=True)\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0moutput_bias_quant_conv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/opt/conda/envs/torch_1.10/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1103\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/scratch/git/fork_brevitas/src/brevitas/nn/quant_conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 192\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 193\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 194\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minner_forward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_weight\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquant_bias\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/scratch/git/fork_brevitas/src/brevitas/nn/quant_layer.py\u001b[0m in \u001b[0;36mforward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 330\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 331\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 332\u001b[0;31m \u001b[0mquant_bias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias_quant\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_scale\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_bit_width\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 333\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_inference_quant_bias\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 334\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cached_bias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_CachedIO\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquant_bias\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata_only\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/opt/conda/envs/torch_1.10/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1103\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/scratch/git/fork_brevitas/src/brevitas/proxy/parameter_quant.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, input_scale, input_bit_width)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0mimpl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_input_scale\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0minput_scale\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Input scale required\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_input_bit_width\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0minput_bit_width\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Input bit-width required\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb Cell 53\u001b[0m line \u001b[0;36m4\n\u001b[1;32m 1\u001b[0m output_bias_quant_conv \u001b[39m=\u001b[39m QuantConv2d(\n\u001b[1;32m 2\u001b[0m in_channels\u001b[39m=\u001b[39m\u001b[39m2\u001b[39m, out_channels\u001b[39m=\u001b[39m\u001b[39m3\u001b[39m, kernel_size\u001b[39m=\u001b[39m(\u001b[39m3\u001b[39m,\u001b[39m3\u001b[39m), bias\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m,\n\u001b[1;32m 3\u001b[0m output_quant\u001b[39m=\u001b[39mInt8ActPerTensorFloat, bias_quant\u001b[39m=\u001b[39mInt8Bias, return_quant_tensor\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[0;32m----> 4\u001b[0m output_bias_quant_conv(torch\u001b[39m.\u001b[39;49mrandn(\u001b[39m1\u001b[39;49m, \u001b[39m2\u001b[39;49m, \u001b[39m5\u001b[39;49m, \u001b[39m5\u001b[39;49m))\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs) \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_conv.py:198\u001b[0m, in \u001b[0;36mQuantConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Union[Tensor, QuantTensor]) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Union[Tensor, QuantTensor]:\n\u001b[0;32m--> 198\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mforward_impl(\u001b[39minput\u001b[39;49m)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:334\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 331\u001b[0m output_signed \u001b[39m=\u001b[39m quant_input\u001b[39m.\u001b[39msigned \u001b[39mor\u001b[39;00m quant_weight\u001b[39m.\u001b[39msigned\n\u001b[1;32m 333\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 334\u001b[0m quant_bias \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias_quant(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias, output_scale, output_bit_width)\n\u001b[1;32m 335\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtraining \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcache_inference_quant_bias \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(quant_bias,\n\u001b[1;32m 336\u001b[0m QuantTensor):\n\u001b[1;32m 337\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_cached_bias \u001b[39m=\u001b[39m _CachedIO(quant_bias\u001b[39m.\u001b[39mdetach(), metadata_only\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs) \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/parameter_quant.py:206\u001b[0m, in \u001b[0;36mBiasQuantProxyFromInjector.forward\u001b[0;34m(self, x, input_scale, input_bit_width)\u001b[0m\n\u001b[1;32m 204\u001b[0m impl \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mexport_handler \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mexport_mode \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtensor_quant\n\u001b[1;32m 205\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mrequires_input_scale \u001b[39mand\u001b[39;00m input_scale \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 206\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mInput scale required\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 207\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mrequires_input_bit_width \u001b[39mand\u001b[39;00m input_bit_width \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 208\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mInput bit-width required\u001b[39m\u001b[39m\"\u001b[39m)\n", "\u001b[0;31mRuntimeError\u001b[0m: Input scale required" ] } @@ -1007,26 +978,26 @@ }, { "cell_type": "code", - "execution_count": 124, + "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[[ 0.2152, 0.8346, 0.0746],\n", - " [-0.0738, -0.5212, 0.1019],\n", - " [-0.6004, 0.1500, -0.1453]],\n", + "tensor([[[[-0.6938, 0.0069, 0.1652],\n", + " [-0.4801, -0.8120, 0.5233],\n", + " [ 0.4159, 0.4662, 0.2565]],\n", "\n", - " [[-1.1551, -1.3458, -0.1312],\n", - " [ 0.2502, -0.5267, 0.2412],\n", - " [-0.3556, -0.3289, -0.2276]],\n", + " [[ 0.3206, -0.5500, -0.5254],\n", + " [ 0.1864, 1.0210, -0.3706],\n", + " [-0.1159, 0.6967, -0.0437]],\n", "\n", - " [[-0.4599, -0.6094, 0.4682],\n", - " [-0.5064, -0.6768, -0.6638],\n", - " [ 0.0066, -0.3581, 0.2359]]]], grad_fn=)" + " [[-0.6209, -0.5257, -0.6592],\n", + " [ 0.6389, 0.2658, 0.4542],\n", + " [-0.3761, -0.7776, -0.2897]]]], grad_fn=)" ] }, - "execution_count": 124, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -1051,30 +1022,30 @@ }, { "cell_type": "code", - "execution_count": 125, + "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.6879, -0.6632, -0.2411],\n", - " [ 0.2064, -0.7371, 0.3910],\n", - " [ 0.9533, 0.2994, 0.6546]],\n", + "QuantTensor(value=tensor([[[[-0.4005, 0.7588, 0.4616],\n", + " [-0.0777, -0.0651, -0.2405],\n", + " [-0.7292, 0.4504, 0.3716]],\n", "\n", - " [[-0.4684, -0.4495, -0.5021],\n", - " [ 0.5738, 0.4199, -0.3380],\n", - " [ 0.6218, -0.0408, -0.8483]],\n", + " [[ 0.4868, -0.4495, -0.1327],\n", + " [ 0.2079, -0.3236, -0.5482],\n", + " [ 0.5471, 0.1503, 0.6813]],\n", "\n", - " [[-0.5625, 0.1837, -1.0575],\n", - " [-1.2816, -0.4993, -0.3409],\n", - " [ 0.4556, -1.4269, 0.5369]]]], grad_fn=), scale=tensor([[[[3.0975e-05]]]], grad_fn=), zero_point=tensor([[[[ 1276.0774]],\n", + " [[ 0.4356, -0.2319, 1.0867],\n", + " [ 0.0126, 0.7646, 0.3627],\n", + " [-0.4466, 0.5150, 0.1176]]]], grad_fn=), scale=tensor([[[[2.7130e-05]]]], grad_fn=), zero_point=tensor([[[[ 6313.4204]],\n", "\n", - " [[-3152.4585]],\n", + " [[-2667.2593]],\n", "\n", - " [[ 7320.2324]]]], grad_fn=), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" + " [[-5507.9629]]]], grad_fn=), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 125, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -1089,20 +1060,9 @@ }, { "cell_type": "code", - "execution_count": 126, + "execution_count": 31, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 126, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "assert out_tensor.is_valid" ] @@ -1116,26 +1076,26 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[[ 0.8357, 0.0733, 0.9527],\n", - " [ 0.1803, 0.2154, 0.7598],\n", - " [ 1.1121, -0.8728, 1.0039]],\n", + "tensor([[[[ 0.0650, 0.2496, -1.2857],\n", + " [ 1.0231, 0.0516, 0.7592],\n", + " [ 0.5882, -0.7619, 0.7604]],\n", "\n", - " [[ 0.7917, 1.0063, 0.6516],\n", - " [-0.1852, -0.7263, 0.0956],\n", - " [-0.1876, 0.2747, -0.1617]],\n", + " [[-0.6307, 0.1476, 1.0949],\n", + " [-0.1488, 0.0472, 0.0097],\n", + " [-0.2861, 0.0266, -0.2970]],\n", "\n", - " [[ 0.8299, 0.9934, -0.3821],\n", - " [ 0.4865, 0.9309, -0.7924],\n", - " [-0.4201, 0.2343, 0.1532]]]], grad_fn=)" + " [[ 0.0580, 1.2994, 0.3841],\n", + " [ 0.2056, 0.0496, -0.7915],\n", + " [ 0.4698, -0.8724, -0.0405]]]], grad_fn=)" ] }, - "execution_count": 127, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } @@ -1153,6 +1113,11 @@ "source": [ "Altough not obvious, the output is actually implicitly quantized." ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] } ], "metadata": { @@ -1171,7 +1136,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.15" + "version": "3.11.5" }, "vscode": { "interpreter": { diff --git a/notebooks/02_quant_activation_overview.ipynb b/notebooks/02_quant_activation_overview.ipynb index 4d2ac73d1..39a0cfc14 100644 --- a/notebooks/02_quant_activation_overview.ipynb +++ b/notebooks/02_quant_activation_overview.ipynb @@ -26,14 +26,12 @@ }, "outputs": [ { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/_tensor.py:1362: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/c10/core/TensorImpl.h:1900.)\n", + " return super().rename(names)\n" + ] } ], "source": [ @@ -68,18 +66,7 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "torch.manual_seed(0)\n", "input_output_quant_conv = QuantConv2d(\n", @@ -178,18 +165,7 @@ "cell_type": "code", "execution_count": 5, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "assert out_tensor.is_valid" ] @@ -220,7 +196,7 @@ " [-0.7990, -1.2936, -0.7419, -1.3127, -0.2283],\n", " [-2.4351, -0.0761, 0.2283, 0.7990, -0.1902],\n", " [-0.3615, -1.2175, -0.6278, -0.4566, 1.9214]]]],\n", - " grad_fn=)" + " grad_fn=)" ] }, "execution_count": 6, @@ -252,7 +228,7 @@ " [-0.7990, -1.2936, -0.7419, -1.3127, -0.2283],\n", " [-2.4351, -0.0761, 0.2283, 0.7990, -0.1902],\n", " [-0.3615, -1.2175, -0.6278, -0.4566, 1.9214]]]],\n", - " grad_fn=), scale=tensor(0.0190, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" + " grad_fn=), scale=tensor(0.0190, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))" ] }, "execution_count": 7, @@ -337,12 +313,16 @@ { "cell_type": "code", "execution_count": 10, - "metadata": {}, + "metadata": { + "tags": [ + "raises-exception" + ] + }, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=(tensor([[[[0.3878, 0.3611, 0.3655, 0.6433, 0.8236],\n", + "tensor([[[[0.3878, 0.3611, 0.3655, 0.6433, 0.8236],\n", " [0.6257, 0.3567, 0.3611, 0.5474, 0.4810],\n", " [0.3788, 0.1820, 0.4526, 0.6077, 0.7911],\n", " [0.1630, 0.8883, 0.8471, 0.9151, 0.2456],\n", @@ -353,10 +333,10 @@ " [0.3102, 0.2152, 0.3226, 0.2120, 0.4432],\n", " [0.0805, 0.4810, 0.5568, 0.6898, 0.4526],\n", " [0.4106, 0.2284, 0.3480, 0.3878, 0.8723]]]],\n", - " grad_fn=), None, None, None), scale=None, zero_point=None, bit_width=None, signed_t=None, training_t=tensor(True))" + " grad_fn=)" ] }, - "execution_count": 10, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -369,26 +349,6 @@ "sigmoid_out_tensor" ] }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "assert not sigmoid_out_tensor.is_valid" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -400,7 +360,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -416,10 +376,10 @@ " [0.6421, 0.0000, 0.0000, 1.1708, 0.4343],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.2266, 0.7931, 0.0000],\n", - " [0.0000, 0.0000, 0.0000, 0.0000, 1.9262]]]], grad_fn=), scale=tensor(0.0189, grad_fn=), zero_point=tensor(129., grad_fn=), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))" + " [0.0000, 0.0000, 0.0000, 0.0000, 1.9262]]]], grad_fn=), scale=tensor(0.0189, grad_fn=), zero_point=tensor(129., grad_fn=), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))" ] }, - "execution_count": 12, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -442,7 +402,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -461,7 +421,7 @@ " [0.0000, 0.0000, 0.4907]]]], grad_fn=)" ] }, - "execution_count": 13, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -482,7 +442,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -501,7 +461,7 @@ " [0.0000, 0.0000, 0.4839]]]], grad_fn=)" ] }, - "execution_count": 14, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -535,7 +495,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -553,20 +513,9 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "out1_train = quant_identity(inp1)\n", "out2_train = quant_identity(inp2)\n", @@ -575,20 +524,9 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 19, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "quant_identity.eval()\n", "out1_eval = quant_identity(inp1)\n", @@ -605,7 +543,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "metadata": { "tags": [ "raises-exception" @@ -617,19 +555,19 @@ "evalue": "'Int8ActPerTensorFloatMinMaxInit' can not resolve attribute 'max_val' while building 'scaling_init_impl'", "output_type": "error", "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mDependencyError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mbrevitas\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnn\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mQuantHardTanh\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m \u001b[0mQuantHardTanh\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\quant_activation.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_quant, input_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[0;32m 117\u001b[0m \u001b[0mact_quant\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mact_quant\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 118\u001b[0m \u001b[0mreturn_quant_tensor\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mreturn_quant_tensor\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 119\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 120\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 121\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\quant_layer.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_impl, passthrough_act, input_quant, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[0;32m 77\u001b[0m \u001b[0mpassthrough_act\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 78\u001b[0m \u001b[0mact_quant\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 79\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 80\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 81\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\mixin\\act.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_impl, passthrough_act, act_quant, **kwargs)\u001b[0m\n\u001b[0;32m 157\u001b[0m \u001b[0mproxy_prefix\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'act_'\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 158\u001b[0m \u001b[0mkwargs_prefix\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m''\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 159\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 160\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 161\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\mixin\\base.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant, proxy_protocol, none_quant_injector, proxy_prefix, kwargs_prefix, **kwargs)\u001b[0m\n\u001b[0;32m 98\u001b[0m \u001b[0mquant_injector\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[0mquant_injector\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlet\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0mfilter_kwargs\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkwargs_prefix\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 100\u001b[1;33m \u001b[0mquant\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mproxy_class\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 101\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 102\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mproxy_protocol\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\runtime_quant.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[0;32m 108\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 109\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_layer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 110\u001b[1;33m \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mActQuantProxyFromInjector\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_layer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 111\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_passthrough_act\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_is_passthrough_act\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\quant_proxy.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant_layer, quant_injector, export_mode, export_handler)\u001b[0m\n\u001b[0;32m 74\u001b[0m \u001b[1;31m# Use a normal list and not a ModuleList since this is a pointer to parent modules\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 75\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtracked_module_list\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 76\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd_tracked_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_layer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 77\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mexport_handler\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 78\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mexport_mode\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\quant_proxy.py\u001b[0m in \u001b[0;36madd_tracked_module\u001b[1;34m(self, module)\u001b[0m\n\u001b[0;32m 130\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtracked_module_list\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 131\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mupdate_tracked_modules\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 132\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minit_tensor_quant\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 133\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 134\u001b[0m \u001b[1;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Trying to add None as a parent module.\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\runtime_quant.py\u001b[0m in \u001b[0;36minit_tensor_quant\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 120\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 121\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0minit_tensor_quant\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 122\u001b[1;33m \u001b[0mtensor_quant\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 123\u001b[0m \u001b[0mact_impl\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mact_impl\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 124\u001b[0m \u001b[0mis_act_enabled\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_is_act_enabled\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mact_impl\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtensor_quant\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - " \u001b[1;31m[... skipping hidden 1 frame]\u001b[0m\n", - "\u001b[1;31mDependencyError\u001b[0m: 'Int8ActPerTensorFloatMinMaxInit' can not resolve attribute 'max_val' while building 'scaling_init_impl'" + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mDependencyError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/02_quant_activation_overview.ipynb Cell 35\u001b[0m line \u001b[0;36m3\n\u001b[1;32m 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mbrevitas\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mnn\u001b[39;00m \u001b[39mimport\u001b[39;00m QuantHardTanh\n\u001b[0;32m----> 3\u001b[0m QuantHardTanh()\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_activation.py:96\u001b[0m, in \u001b[0;36mQuantHardTanh.__init__\u001b[0;34m(self, act_quant, input_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\n\u001b[1;32m 91\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 92\u001b[0m act_quant: Optional[ActQuantType] \u001b[39m=\u001b[39m Int8ActPerTensorFloatMinMaxInit,\n\u001b[1;32m 93\u001b[0m input_quant: Optional[ActQuantType] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 94\u001b[0m return_quant_tensor: \u001b[39mbool\u001b[39m \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m,\n\u001b[1;32m 95\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[0;32m---> 96\u001b[0m QuantNLAL\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\n\u001b[1;32m 97\u001b[0m \u001b[39mself\u001b[39;49m,\n\u001b[1;32m 98\u001b[0m act_impl\u001b[39m=\u001b[39;49mnn\u001b[39m.\u001b[39;49mHardtanh,\n\u001b[1;32m 99\u001b[0m passthrough_act\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 100\u001b[0m input_quant\u001b[39m=\u001b[39;49minput_quant,\n\u001b[1;32m 101\u001b[0m act_quant\u001b[39m=\u001b[39;49mact_quant,\n\u001b[1;32m 102\u001b[0m return_quant_tensor\u001b[39m=\u001b[39;49mreturn_quant_tensor,\n\u001b[1;32m 103\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:40\u001b[0m, in \u001b[0;36mQuantNonLinearActLayer.__init__\u001b[0;34m(self, act_impl, passthrough_act, input_quant, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[1;32m 38\u001b[0m QuantLayerMixin\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, return_quant_tensor)\n\u001b[1;32m 39\u001b[0m QuantInputMixin\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, input_quant, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[0;32m---> 40\u001b[0m QuantNonLinearActMixin\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\u001b[39mself\u001b[39;49m, act_impl, passthrough_act, act_quant, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/act.py:118\u001b[0m, in \u001b[0;36mQuantNonLinearActMixin.__init__\u001b[0;34m(self, act_impl, passthrough_act, act_quant, act_proxy_prefix, act_kwargs_prefix, **kwargs)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\n\u001b[1;32m 108\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 109\u001b[0m act_impl: Optional[Type[Module]],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 113\u001b[0m act_kwargs_prefix\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39m'\u001b[39m,\n\u001b[1;32m 114\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m prefixed_kwargs \u001b[39m=\u001b[39m {\n\u001b[1;32m 116\u001b[0m act_kwargs_prefix \u001b[39m+\u001b[39m \u001b[39m'\u001b[39m\u001b[39mact_impl\u001b[39m\u001b[39m'\u001b[39m: act_impl,\n\u001b[1;32m 117\u001b[0m act_kwargs_prefix \u001b[39m+\u001b[39m \u001b[39m'\u001b[39m\u001b[39mpassthrough_act\u001b[39m\u001b[39m'\u001b[39m: passthrough_act}\n\u001b[0;32m--> 118\u001b[0m QuantProxyMixin\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\n\u001b[1;32m 119\u001b[0m \u001b[39mself\u001b[39;49m,\n\u001b[1;32m 120\u001b[0m quant\u001b[39m=\u001b[39;49mact_quant,\n\u001b[1;32m 121\u001b[0m proxy_prefix\u001b[39m=\u001b[39;49mact_proxy_prefix,\n\u001b[1;32m 122\u001b[0m kwargs_prefix\u001b[39m=\u001b[39;49mact_kwargs_prefix,\n\u001b[1;32m 123\u001b[0m proxy_protocol\u001b[39m=\u001b[39;49mActQuantProxyProtocol,\n\u001b[1;32m 124\u001b[0m none_quant_injector\u001b[39m=\u001b[39;49mNoneActQuant,\n\u001b[1;32m 125\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mprefixed_kwargs,\n\u001b[1;32m 126\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:71\u001b[0m, in \u001b[0;36mQuantProxyMixin.__init__\u001b[0;34m(self, quant, proxy_protocol, none_quant_injector, proxy_prefix, kwargs_prefix, **kwargs)\u001b[0m\n\u001b[1;32m 69\u001b[0m quant_injector \u001b[39m=\u001b[39m quant\n\u001b[1;32m 70\u001b[0m quant_injector \u001b[39m=\u001b[39m quant_injector\u001b[39m.\u001b[39mlet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mfilter_kwargs(kwargs_prefix, kwargs))\n\u001b[0;32m---> 71\u001b[0m quant \u001b[39m=\u001b[39m quant_injector\u001b[39m.\u001b[39;49mproxy_class(\u001b[39mself\u001b[39;49m, quant_injector)\n\u001b[1;32m 72\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 73\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(quant, proxy_protocol):\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/runtime_quant.py:89\u001b[0m, in \u001b[0;36mActQuantProxyFromInjector.__init__\u001b[0;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, quant_layer, quant_injector):\n\u001b[0;32m---> 89\u001b[0m QuantProxyFromInjector\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\u001b[39mself\u001b[39;49m, quant_layer, quant_injector)\n\u001b[1;32m 90\u001b[0m ActQuantProxyProtocol\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m)\n\u001b[1;32m 91\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_passthrough_act \u001b[39m=\u001b[39m _is_passthrough_act(quant_injector)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/quant_proxy.py:82\u001b[0m, in \u001b[0;36mQuantProxyFromInjector.__init__\u001b[0;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[39m# Use a normal list and not a ModuleList since this is a pointer to parent modules\u001b[39;00m\n\u001b[1;32m 81\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtracked_module_list \u001b[39m=\u001b[39m []\n\u001b[0;32m---> 82\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49madd_tracked_module(quant_layer)\n\u001b[1;32m 83\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdisable_quant \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/quant_proxy.py:120\u001b[0m, in \u001b[0;36mQuantProxyFromInjector.add_tracked_module\u001b[0;34m(self, module)\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtracked_module_list\u001b[39m.\u001b[39mappend(module)\n\u001b[1;32m 119\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mupdate_tracked_modules()\n\u001b[0;32m--> 120\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49minit_tensor_quant()\n\u001b[1;32m 121\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 122\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mTrying to add None as a parent module.\u001b[39m\u001b[39m\"\u001b[39m)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/runtime_quant.py:102\u001b[0m, in \u001b[0;36mActQuantProxyFromInjector.init_tensor_quant\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39minit_tensor_quant\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[0;32m--> 102\u001b[0m tensor_quant \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mquant_injector\u001b[39m.\u001b[39;49mtensor_quant\n\u001b[1;32m 103\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m'\u001b[39m\u001b[39mact_impl\u001b[39m\u001b[39m'\u001b[39m \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mquant_injector:\n\u001b[1;32m 104\u001b[0m act_impl \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mquant_injector\u001b[39m.\u001b[39mact_impl\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/inject/__init__.py:129\u001b[0m, in \u001b[0;36m_ExtendedInjectorType.__getattr__\u001b[0;34m(cls, attrname)\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 127\u001b[0m message \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m{!r}\u001b[39;00m\u001b[39m can not resolve attribute \u001b[39m\u001b[39m{!r}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 128\u001b[0m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, current_attr)\n\u001b[0;32m--> 129\u001b[0m \u001b[39mraise\u001b[39;00m DependencyError(message)\n\u001b[1;32m 131\u001b[0m marker, attribute, args, have_defaults \u001b[39m=\u001b[39m spec\n\u001b[1;32m 133\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mset\u001b[39m(args)\u001b[39m.\u001b[39missubset(cached):\n", + "\u001b[0;31mDependencyError\u001b[0m: 'Int8ActPerTensorFloatMinMaxInit' can not resolve attribute 'max_val' while building 'scaling_init_impl'" ] } ], @@ -648,7 +586,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -664,20 +602,9 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 22, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "out1_train = quant_hard_tanh(inp1)\n", "quant_hard_tanh.eval()\n", @@ -711,7 +638,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/notebooks/03_anatomy_of_a_quantizer.ipynb b/notebooks/03_anatomy_of_a_quantizer.ipynb index 2055a1714..21a0b54f4 100644 --- a/notebooks/03_anatomy_of_a_quantizer.ipynb +++ b/notebooks/03_anatomy_of_a_quantizer.ipynb @@ -181,8 +181,9 @@ " Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.\n", " \"\"\"\n", "\n", - " def __init__(self, scaling_impl: Module, quant_delay_steps: int = 0):\n", + " def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps: int = 0):\n", " super(BinaryQuant, self).__init__()\n", + " assert signed, \"Unsigned binary quant not supported\"\n", " self.scaling_impl = scaling_impl\n", " self.bit_width = BitWidthConst(1)\n", " self.zero_point = StatelessBuffer(torch.tensor(0.0))\n", @@ -247,10 +248,10 @@ { "data": { "text/plain": [ - "(tensor([[ 0.1000, 0.1000, 0.1000, 0.1000],\n", + "(tensor([[-0.1000, -0.1000, 0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, 0.1000, 0.1000],\n", " [-0.1000, -0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, 0.1000, -0.1000]], grad_fn=),\n", + " [ 0.1000, -0.1000, -0.1000, -0.1000]], grad_fn=),\n", " tensor(0.1000, grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -292,10 +293,10 @@ { "data": { "text/plain": [ - "(tensor([[-0.1000, 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, -0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000, 0.1000],\n", - " [-0.1000, 0.1000, 0.1000, 0.1000]], grad_fn=),\n", + "(tensor([[ 0.1000, -0.1000, -0.1000, -0.1000],\n", + " [ 0.1000, -0.1000, -0.1000, 0.1000],\n", + " [-0.1000, -0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, 0.1000, 0.1000]], grad_fn=),\n", " tensor(0.1000, grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -342,10 +343,10 @@ { "data": { "text/plain": [ - "(tensor([[ 1., -1., 1., 1.],\n", - " [ 1., 1., -1., 1.],\n", - " [ 1., 1., 1., -1.],\n", - " [-1., 1., -1., -1.]], grad_fn=),\n", + "(tensor([[-1., -1., -1., 1.],\n", + " [-1., 1., -1., -1.],\n", + " [-1., -1., -1., -1.],\n", + " [-1., 1., -1., 1.]], grad_fn=),\n", " tensor(1., grad_fn=),\n", " tensor(0.),\n", " tensor(1.))" @@ -379,9 +380,9 @@ { "data": { "text/plain": [ - "(tensor([[ 0.1000, -0.1000, -0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, 0.1000, 0.1000],\n", + "(tensor([[-0.1000, 0.1000, -0.1000, -0.1000],\n", + " [-0.1000, -0.1000, 0.1000, -0.1000],\n", + " [-0.1000, -0.1000, -0.1000, 0.1000],\n", " [-0.1000, 0.1000, -0.1000, 0.1000]], grad_fn=),\n", " tensor(0.1000, grad_fn=),\n", " tensor(0.),\n", @@ -444,63 +445,22 @@ "cell_type": "code", "execution_count": 11, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, -0.1000]],\n", - "\n", - " [[-0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, 0.1000],\n", - " [-0.1000, -0.1000, 0.1000]],\n", - "\n", - " [[ 0.1000, -0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, -0.1000, 0.1000]]],\n", - "\n", - "\n", - " [[[ 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, 0.1000]],\n", - "\n", - " [[-0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, 0.1000],\n", - " [-0.1000, -0.1000, -0.1000]],\n", - "\n", - " [[ 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, 0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=None, training_t=tensor(True))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from brevitas.nn import QuantConv2d\n", "\n", "binary_weight_quant_conv = QuantConv2d(3, 2, (3,3), weight_quant=MyBinaryWeightQuantizer)\n", - "quant_weight = binary_weight_quant_conv.quant_weight()\n", - "quant_weight" + "try:\n", + " quant_weight = binary_weight_quant_conv.quant_weight()\n", + "except TypeError:\n", + " pass\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Note however how the `QuantTensor` is not properly formed, as the `signed` attribute is `None`. This means that `quant_weight` is not considered valid, as the affine quantization invariant cannot be computed:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "assert not quant_weight.is_valid" + "Note however that we cannot compute the quantized weight, as the `signed` attribute is `None`." ] }, { @@ -512,39 +472,39 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.1000, 0.1000, -0.1000],\n", - " [-0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, -0.1000]],\n", + "QuantTensor(value=tensor([[[[-0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, 0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, 0.1000]],\n", + "\n", + " [[-0.1000, 0.1000, -0.1000],\n", + " [-0.1000, -0.1000, -0.1000],\n", + " [-0.1000, 0.1000, -0.1000]],\n", "\n", " [[ 0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, -0.1000]],\n", + " [-0.1000, 0.1000, 0.1000],\n", + " [-0.1000, -0.1000, 0.1000]]],\n", "\n", - " [[-0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, -0.1000],\n", - " [-0.1000, -0.1000, -0.1000]]],\n", "\n", + " [[[-0.1000, 0.1000, -0.1000],\n", + " [-0.1000, -0.1000, 0.1000],\n", + " [-0.1000, 0.1000, 0.1000]],\n", "\n", - " [[[ 0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, -0.1000],\n", + " [[-0.1000, -0.1000, -0.1000],\n", + " [ 0.1000, 0.1000, 0.1000],\n", " [-0.1000, -0.1000, 0.1000]],\n", "\n", - " [[-0.1000, -0.1000, 0.1000],\n", - " [-0.1000, 0.1000, 0.1000],\n", - " [-0.1000, -0.1000, -0.1000]],\n", - "\n", - " [[-0.1000, 0.1000, -0.1000],\n", + " [[ 0.1000, -0.1000, -0.1000],\n", " [-0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, -0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [ 0.1000, 0.1000, -0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 13, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -560,11 +520,11 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ - "assert signed_quant_weight.is_valid == True" + "assert signed_quant_weight.is_valid" ] }, { @@ -578,39 +538,39 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.1000, -0.1000, 0.1000],\n", - " [-0.1000, -0.1000, -0.1000],\n", - " [ 0.1000, -0.1000, 0.1000]],\n", + "QuantTensor(value=tensor([[[[ 0.1000, 0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, 0.1000],\n", + " [-0.1000, 0.1000, -0.1000]],\n", "\n", - " [[-0.1000, 0.1000, -0.1000],\n", - " [-0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, -0.1000]],\n", + " [[-0.1000, -0.1000, -0.1000],\n", + " [ 0.1000, -0.1000, 0.1000],\n", + " [-0.1000, 0.1000, -0.1000]],\n", "\n", - " [[-0.1000, 0.1000, -0.1000],\n", - " [-0.1000, -0.1000, 0.1000],\n", - " [-0.1000, -0.1000, -0.1000]]],\n", + " [[-0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, -0.1000]]],\n", "\n", "\n", - " [[[-0.1000, -0.1000, -0.1000],\n", - " [-0.1000, -0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, -0.1000]],\n", + " [[[ 0.1000, -0.1000, 0.1000],\n", + " [ 0.1000, 0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, 0.1000]],\n", "\n", - " [[-0.1000, -0.1000, 0.1000],\n", - " [-0.1000, 0.1000, -0.1000],\n", + " [[-0.1000, 0.1000, 0.1000],\n", + " [ 0.1000, -0.1000, 0.1000],\n", " [ 0.1000, -0.1000, 0.1000]],\n", "\n", - " [[ 0.1000, 0.1000, -0.1000],\n", - " [ 0.1000, 0.1000, 0.1000],\n", - " [ 0.1000, -0.1000, 0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[-0.1000, -0.1000, -0.1000],\n", + " [-0.1000, 0.1000, -0.1000],\n", + " [-0.1000, 0.1000, -0.1000]]]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 16, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -640,19 +600,27 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 15, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/_tensor.py:1362: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/c10/core/TensorImpl.h:1900.)\n", + " return super().rename(names)\n" + ] + }, { "data": { "text/plain": [ - "QuantTensor(value=tensor([[-0.1000, 0.1000, -0.1000, 0.1000],\n", - " [ 0.1000, 0.1000, 0.1000, 0.1000],\n", - " [-0.1000, 0.1000, 0.1000, 0.1000],\n", - " [-0.1000, -0.1000, 0.1000, -0.1000]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + "QuantTensor(value=tensor([[ 0.1000, -0.1000, 0.1000, -0.1000],\n", + " [-0.1000, -0.1000, 0.1000, -0.1000],\n", + " [-0.1000, -0.1000, 0.1000, -0.1000],\n", + " [ 0.1000, -0.1000, 0.1000, 0.1000]], grad_fn=), scale=tensor(0.1000, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 17, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -678,19 +646,19 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[ 0.0010, 0.0010, 0.0010, -0.0010],\n", - " [ 0.0010, -0.0010, 0.0010, -0.0010],\n", - " [-0.0010, -0.0010, -0.0010, -0.0010],\n", - " [ 0.0010, 0.0010, 0.0010, 0.0010]], grad_fn=), scale=tensor(0.0010, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + "QuantTensor(value=tensor([[-0.0010, 0.0010, 0.0010, -0.0010],\n", + " [-0.0010, 0.0010, -0.0010, -0.0010],\n", + " [ 0.0010, -0.0010, -0.0010, -0.0010],\n", + " [ 0.0010, -0.0010, 0.0010, -0.0010]], grad_fn=), scale=tensor(0.0010, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 18, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -716,7 +684,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -740,7 +708,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 18, "metadata": { "scrolled": true }, @@ -748,33 +716,33 @@ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.1876, -0.1876, -0.1876],\n", - " [ 0.1876, 0.1876, 0.1876],\n", - " [-0.1876, -0.1876, 0.1876]],\n", + "QuantTensor(value=tensor([[[[ 0.1904, -0.1904, -0.1904],\n", + " [-0.1904, 0.1904, -0.1904],\n", + " [-0.1904, 0.1904, 0.1904]],\n", "\n", - " [[-0.1876, -0.1876, 0.1876],\n", - " [ 0.1876, 0.1876, -0.1876],\n", - " [-0.1876, 0.1876, 0.1876]],\n", + " [[-0.1904, 0.1904, -0.1904],\n", + " [ 0.1904, -0.1904, -0.1904],\n", + " [ 0.1904, 0.1904, -0.1904]],\n", "\n", - " [[-0.1876, -0.1876, -0.1876],\n", - " [ 0.1876, 0.1876, 0.1876],\n", - " [-0.1876, 0.1876, -0.1876]]],\n", + " [[-0.1904, -0.1904, 0.1904],\n", + " [-0.1904, -0.1904, -0.1904],\n", + " [ 0.1904, 0.1904, -0.1904]]],\n", "\n", "\n", - " [[[-0.1876, -0.1876, -0.1876],\n", - " [ 0.1876, 0.1876, -0.1876],\n", - " [ 0.1876, -0.1876, -0.1876]],\n", + " [[[-0.1904, 0.1904, 0.1904],\n", + " [ 0.1904, -0.1904, -0.1904],\n", + " [ 0.1904, 0.1904, 0.1904]],\n", "\n", - " [[-0.1876, 0.1876, -0.1876],\n", - " [ 0.1876, -0.1876, -0.1876],\n", - " [-0.1876, -0.1876, 0.1876]],\n", + " [[ 0.1904, -0.1904, 0.1904],\n", + " [ 0.1904, 0.1904, 0.1904],\n", + " [ 0.1904, -0.1904, -0.1904]],\n", "\n", - " [[-0.1876, 0.1876, 0.1876],\n", - " [ 0.1876, -0.1876, 0.1876],\n", - " [-0.1876, -0.1876, -0.1876]]]], grad_fn=), scale=tensor(0.1876, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[-0.1904, 0.1904, 0.1904],\n", + " [-0.1904, -0.1904, -0.1904],\n", + " [-0.1904, -0.1904, 0.1904]]]], grad_fn=), scale=tensor(0.1904, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 20, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -793,7 +761,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -802,7 +770,7 @@ "True" ] }, - "execution_count": 21, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -820,16 +788,16 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor(0.1897, grad_fn=)" + "tensor(0.1876, grad_fn=)" ] }, - "execution_count": 22, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -850,7 +818,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 22, "metadata": { "tags": [ "raises-exception" @@ -862,11 +830,11 @@ "evalue": "Error(s) in loading state_dict for QuantConv2d:\n\tMissing key(s) in state_dict: \"weight_quant.tensor_quant.scaling_impl.value\". ", "output_type": "error", "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mparam_from_max_quant_conv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfloat_conv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[1;32mC:\\ProgramData\\Miniconda3\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36mload_state_dict\u001b[1;34m(self, state_dict, strict)\u001b[0m\n\u001b[0;32m 1405\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0merror_msgs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m>\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1406\u001b[0m raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n\u001b[1;32m-> 1407\u001b[1;33m self.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n\u001b[0m\u001b[0;32m 1408\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0m_IncompatibleKeys\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmissing_keys\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0munexpected_keys\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1409\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for QuantConv2d:\n\tMissing key(s) in state_dict: \"weight_quant.tensor_quant.scaling_impl.value\". " + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/03_anatomy_of_a_quantizer.ipynb Cell 45\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m param_from_max_quant_conv\u001b[39m.\u001b[39;49mload_state_dict(float_conv\u001b[39m.\u001b[39;49mstate_dict())\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:2152\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict, assign)\u001b[0m\n\u001b[1;32m 2147\u001b[0m error_msgs\u001b[39m.\u001b[39minsert(\n\u001b[1;32m 2148\u001b[0m \u001b[39m0\u001b[39m, \u001b[39m'\u001b[39m\u001b[39mMissing key(s) in state_dict: \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m. \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 2149\u001b[0m \u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mjoin(\u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00mk\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m'\u001b[39m \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m missing_keys)))\n\u001b[1;32m 2151\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(error_msgs) \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[0;32m-> 2152\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m'\u001b[39m\u001b[39mError(s) in loading state_dict for \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 2153\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mjoin(error_msgs)))\n\u001b[1;32m 2154\u001b[0m \u001b[39mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n", + "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for QuantConv2d:\n\tMissing key(s) in state_dict: \"weight_quant.tensor_quant.scaling_impl.value\". " ] } ], @@ -916,39 +884,39 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.1897, -0.1897, 0.1897],\n", - " [-0.1897, 0.1897, -0.1897],\n", - " [-0.1897, 0.1897, -0.1897]],\n", + "QuantTensor(value=tensor([[[[ 0.1876, 0.1876, -0.1876],\n", + " [ 0.1876, -0.1876, 0.1876],\n", + " [-0.1876, 0.1876, -0.1876]],\n", "\n", - " [[-0.1897, 0.1897, 0.1897],\n", - " [ 0.1897, -0.1897, -0.1897],\n", - " [ 0.1897, -0.1897, 0.1897]],\n", + " [[-0.1876, 0.1876, 0.1876],\n", + " [-0.1876, 0.1876, -0.1876],\n", + " [ 0.1876, -0.1876, -0.1876]],\n", "\n", - " [[-0.1897, 0.1897, -0.1897],\n", - " [-0.1897, 0.1897, 0.1897],\n", - " [-0.1897, 0.1897, 0.1897]]],\n", + " [[-0.1876, -0.1876, -0.1876],\n", + " [ 0.1876, 0.1876, 0.1876],\n", + " [ 0.1876, -0.1876, -0.1876]]],\n", "\n", "\n", - " [[[ 0.1897, 0.1897, 0.1897],\n", - " [-0.1897, 0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, -0.1897]],\n", + " [[[-0.1876, -0.1876, -0.1876],\n", + " [-0.1876, 0.1876, -0.1876],\n", + " [-0.1876, -0.1876, -0.1876]],\n", "\n", - " [[ 0.1897, -0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, 0.1897]],\n", + " [[-0.1876, -0.1876, -0.1876],\n", + " [-0.1876, 0.1876, 0.1876],\n", + " [-0.1876, -0.1876, -0.1876]],\n", "\n", - " [[-0.1897, 0.1897, -0.1897],\n", - " [-0.1897, 0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, 0.1897]]]], grad_fn=), scale=tensor(0.1897, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[ 0.1876, 0.1876, -0.1876],\n", + " [-0.1876, -0.1876, 0.1876],\n", + " [ 0.1876, 0.1876, 0.1876]]]], grad_fn=), scale=tensor(0.1876, grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 25, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -979,7 +947,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 25, "metadata": {}, "outputs": [ { @@ -1013,18 +981,7 @@ "cell_type": "code", "execution_count": 26, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "quant_conv1 = QuantConv2d(3, 2, (3, 3), weight_quant=MySignedBinaryWeightQuantizer)\n", "quant_conv2 = QuantConv2d(3, 2, (3, 3), weight_quant=quant_conv1.weight_quant)\n", @@ -1036,19 +993,7 @@ "cell_type": "code", "execution_count": 27, "metadata": {}, - "outputs": [ - { - "ename": "AssertionError", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_58415/1066539094.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mquant_conv1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquant_weight_scale\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mquant_conv2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquant_weight_scale\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m: " - ] - } - ], + "outputs": [], "source": [ "assert (quant_conv1.quant_weight_scale() == quant_conv2.quant_weight_scale()).item()" ] @@ -1065,18 +1010,7 @@ "cell_type": "code", "execution_count": 28, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "class SharedParamFromMeanWeightQuantizer(MySignedBinaryWeightQuantizer):\n", " \n", @@ -1097,7 +1031,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 29, "metadata": {}, "outputs": [], "source": [ @@ -1140,7 +1074,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ @@ -1159,7 +1093,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 31, "metadata": {}, "outputs": [], "source": [ @@ -1260,42 +1194,42 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[-0.1842, 0.1842, -0.1842],\n", - " [-0.1842, -0.1842, 0.1842],\n", - " [-0.1842, -0.1842, 0.1842]],\n", + "QuantTensor(value=tensor([[[[-0.1903, 0.1903, -0.1903],\n", + " [ 0.1903, 0.1903, -0.1903],\n", + " [-0.1903, -0.1903, -0.1903]],\n", "\n", - " [[-0.1842, -0.1842, 0.1842],\n", - " [ 0.1842, -0.1842, 0.1842],\n", - " [ 0.1842, 0.1842, -0.1842]],\n", + " [[ 0.1903, -0.1903, -0.1903],\n", + " [ 0.1903, 0.1903, -0.1903],\n", + " [ 0.1903, -0.1903, 0.1903]],\n", "\n", - " [[-0.1842, -0.1842, 0.1842],\n", - " [ 0.1842, 0.1842, 0.1842],\n", - " [-0.1842, 0.1842, -0.1842]]],\n", + " [[-0.1903, -0.1903, -0.1903],\n", + " [-0.1903, -0.1903, 0.1903],\n", + " [-0.1903, 0.1903, -0.1903]]],\n", "\n", "\n", - " [[[ 0.1838, 0.1838, 0.1838],\n", - " [-0.1838, -0.1838, -0.1838],\n", - " [ 0.1838, 0.1838, -0.1838]],\n", + " [[[ 0.1870, 0.1870, -0.1870],\n", + " [ 0.1870, 0.1870, -0.1870],\n", + " [-0.1870, 0.1870, -0.1870]],\n", "\n", - " [[ 0.1838, -0.1838, 0.1838],\n", - " [ 0.1838, 0.1838, 0.1838],\n", - " [-0.1838, 0.1838, -0.1838]],\n", + " [[-0.1870, 0.1870, 0.1870],\n", + " [ 0.1870, 0.1870, 0.1870],\n", + " [ 0.1870, 0.1870, 0.1870]],\n", "\n", - " [[-0.1838, 0.1838, -0.1838],\n", - " [ 0.1838, -0.1838, -0.1838],\n", - " [ 0.1838, -0.1838, 0.1838]]]], grad_fn=), scale=tensor([[[[0.1842]]],\n", + " [[-0.1870, -0.1870, -0.1870],\n", + " [ 0.1870, -0.1870, -0.1870],\n", + " [-0.1870, -0.1870, 0.1870]]]], grad_fn=), scale=tensor([[[[0.1903]]],\n", "\n", "\n", - " [[[0.1838]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[[0.1870]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 35, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } @@ -1318,42 +1252,42 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[[[ 0.1875, -0.1875, 0.1875],\n", - " [-0.1875, 0.1875, -0.1875],\n", - " [-0.1875, 0.1875, -0.1875]],\n", + "QuantTensor(value=tensor([[[[ 0.1873, 0.1873, -0.1873],\n", + " [ 0.1873, -0.1873, 0.1873],\n", + " [-0.1873, 0.1873, -0.1873]],\n", "\n", - " [[-0.1875, 0.1875, 0.1875],\n", - " [ 0.1875, -0.1875, -0.1875],\n", - " [ 0.1875, -0.1875, 0.1875]],\n", + " [[-0.1873, 0.1873, 0.1873],\n", + " [-0.1873, 0.1873, -0.1873],\n", + " [ 0.1873, -0.1873, -0.1873]],\n", "\n", - " [[-0.1875, 0.1875, -0.1875],\n", - " [-0.1875, 0.1875, 0.1875],\n", - " [-0.1875, 0.1875, 0.1875]]],\n", + " [[-0.1873, -0.1873, -0.1873],\n", + " [ 0.1873, 0.1873, 0.1873],\n", + " [ 0.1873, -0.1873, -0.1873]]],\n", "\n", "\n", - " [[[ 0.1897, 0.1897, 0.1897],\n", - " [-0.1897, 0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, -0.1897]],\n", + " [[[-0.1876, -0.1876, -0.1876],\n", + " [-0.1876, 0.1876, -0.1876],\n", + " [-0.1876, -0.1876, -0.1876]],\n", "\n", - " [[ 0.1897, -0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, 0.1897]],\n", + " [[-0.1876, -0.1876, -0.1876],\n", + " [-0.1876, 0.1876, 0.1876],\n", + " [-0.1876, -0.1876, -0.1876]],\n", "\n", - " [[-0.1897, 0.1897, -0.1897],\n", - " [-0.1897, 0.1897, -0.1897],\n", - " [ 0.1897, 0.1897, 0.1897]]]], grad_fn=), scale=tensor([[[[0.1875]]],\n", + " [[ 0.1876, 0.1876, -0.1876],\n", + " [-0.1876, -0.1876, 0.1876],\n", + " [ 0.1876, 0.1876, 0.1876]]]], grad_fn=), scale=tensor([[[[0.1873]]],\n", "\n", "\n", - " [[[0.1897]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" + " [[[0.1876]]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 36, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } @@ -1374,19 +1308,19 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[-0.0100, -0.0100, 0.0100, -0.0100],\n", - " [-0.0100, -0.0100, -0.0100, 0.0100],\n", - " [-0.0100, 0.0100, 0.0100, 0.0100],\n", - " [-0.0100, 0.0100, 0.0100, 0.0100]], grad_fn=)" + "tensor([[ 0.0100, 0.0100, -0.0100, -0.0100],\n", + " [ 0.0100, -0.0100, 0.0100, -0.0100],\n", + " [ 0.0100, -0.0100, 0.0100, -0.0100],\n", + " [-0.0100, -0.0100, 0.0100, -0.0100]], grad_fn=)" ] }, - "execution_count": 37, + "execution_count": 35, "metadata": {}, "output_type": "execute_result" } @@ -1421,21 +1355,21 @@ "evalue": "'AdvancedActQuantizer' can not resolve attribute 'per_channel_broadcastable_shape'", "output_type": "error", "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mDependencyError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m quant_identity = QuantIdentity(\n\u001b[1;32m----> 4\u001b[1;33m act_quant=AdvancedActQuantizer, is_clamped=True, scaling_per_output_channel=True)\n\u001b[0m", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\quant_activation.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[0;32m 134\u001b[0m \u001b[0mact_quant\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mact_quant\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 135\u001b[0m \u001b[0mreturn_quant_tensor\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mreturn_quant_tensor\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 136\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 137\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\quant_layer.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_impl, passthrough_act, input_quant, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[0;32m 77\u001b[0m \u001b[0mpassthrough_act\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 78\u001b[0m \u001b[0mact_quant\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 79\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 80\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 81\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\mixin\\act.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, act_impl, passthrough_act, act_quant, **kwargs)\u001b[0m\n\u001b[0;32m 157\u001b[0m \u001b[0mproxy_prefix\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'act_'\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 158\u001b[0m \u001b[0mkwargs_prefix\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m''\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 159\u001b[1;33m **kwargs)\n\u001b[0m\u001b[0;32m 160\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 161\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\nn\\mixin\\base.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant, proxy_protocol, none_quant_injector, proxy_prefix, kwargs_prefix, **kwargs)\u001b[0m\n\u001b[0;32m 98\u001b[0m \u001b[0mquant_injector\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[0mquant_injector\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlet\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0mfilter_kwargs\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkwargs_prefix\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 100\u001b[1;33m \u001b[0mquant\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mproxy_class\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 101\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 102\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mproxy_protocol\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\runtime_quant.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[0;32m 108\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 109\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_layer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 110\u001b[1;33m \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mActQuantProxyFromInjector\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_layer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 111\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_passthrough_act\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_is_passthrough_act\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\quant_proxy.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, quant_layer, quant_injector, export_mode, export_handler)\u001b[0m\n\u001b[0;32m 74\u001b[0m \u001b[1;31m# Use a normal list and not a ModuleList since this is a pointer to parent modules\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 75\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtracked_module_list\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 76\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd_tracked_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_layer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 77\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mexport_handler\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 78\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mexport_mode\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\quant_proxy.py\u001b[0m in \u001b[0;36madd_tracked_module\u001b[1;34m(self, module)\u001b[0m\n\u001b[0;32m 130\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtracked_module_list\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 131\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mupdate_tracked_modules\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 132\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minit_tensor_quant\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 133\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 134\u001b[0m \u001b[1;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Trying to add None as a parent module.\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\brevitas_fx\\src\\brevitas\\proxy\\runtime_quant.py\u001b[0m in \u001b[0;36minit_tensor_quant\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 120\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 121\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0minit_tensor_quant\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 122\u001b[1;33m \u001b[0mtensor_quant\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 123\u001b[0m \u001b[0mact_impl\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mquant_injector\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mact_impl\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 124\u001b[0m \u001b[0mis_act_enabled\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_is_act_enabled\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mact_impl\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtensor_quant\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - " \u001b[1;31m[... skipping hidden 1 frame]\u001b[0m\n", - "\u001b[1;32mC:\\ProgramData\\Miniconda3\\envs\\pytorch\\lib\\site-packages\\_dependencies\\this.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, __self__)\u001b[0m\n\u001b[0;32m 49\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mkind\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m\".\"\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 50\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 51\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msymbol\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 52\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mDependencyError\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 53\u001b[0m message = (\n", - " \u001b[1;31m[... skipping hidden 1 frame]\u001b[0m\n", - "\u001b[1;31mDependencyError\u001b[0m: 'AdvancedActQuantizer' can not resolve attribute 'per_channel_broadcastable_shape'" + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mDependencyError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/03_anatomy_of_a_quantizer.ipynb Cell 75\u001b[0m line \u001b[0;36m3\n\u001b[1;32m 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mbrevitas\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mnn\u001b[39;00m \u001b[39mimport\u001b[39;00m QuantIdentity\n\u001b[0;32m----> 3\u001b[0m quant_identity \u001b[39m=\u001b[39m QuantIdentity(\n\u001b[1;32m 4\u001b[0m act_quant\u001b[39m=\u001b[39;49mAdvancedActQuantizer, is_clamped\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, scaling_per_output_channel\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_activation.py:113\u001b[0m, in \u001b[0;36mQuantIdentity.__init__\u001b[0;34m(self, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\n\u001b[1;32m 109\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 110\u001b[0m act_quant: Optional[ActQuantType] \u001b[39m=\u001b[39m Int8ActPerTensorFloat,\n\u001b[1;32m 111\u001b[0m return_quant_tensor: \u001b[39mbool\u001b[39m \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m,\n\u001b[1;32m 112\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[0;32m--> 113\u001b[0m QuantNLAL\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\n\u001b[1;32m 114\u001b[0m \u001b[39mself\u001b[39;49m,\n\u001b[1;32m 115\u001b[0m input_quant\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 116\u001b[0m act_impl\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 117\u001b[0m passthrough_act\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 118\u001b[0m act_quant\u001b[39m=\u001b[39;49mact_quant,\n\u001b[1;32m 119\u001b[0m return_quant_tensor\u001b[39m=\u001b[39;49mreturn_quant_tensor,\n\u001b[1;32m 120\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:37\u001b[0m, in \u001b[0;36mQuantNonLinearActLayer.__init__\u001b[0;34m(self, act_impl, passthrough_act, input_quant, act_quant, return_quant_tensor, **kwargs)\u001b[0m\n\u001b[1;32m 35\u001b[0m QuantLayerMixin\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, return_quant_tensor)\n\u001b[1;32m 36\u001b[0m QuantInputMixin\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, input_quant, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[0;32m---> 37\u001b[0m QuantNonLinearActMixin\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\u001b[39mself\u001b[39;49m, act_impl, passthrough_act, act_quant, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/act.py:118\u001b[0m, in \u001b[0;36mQuantNonLinearActMixin.__init__\u001b[0;34m(self, act_impl, passthrough_act, act_quant, act_proxy_prefix, act_kwargs_prefix, **kwargs)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\n\u001b[1;32m 108\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 109\u001b[0m act_impl: Optional[Type[Module]],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 113\u001b[0m act_kwargs_prefix\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39m'\u001b[39m,\n\u001b[1;32m 114\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m prefixed_kwargs \u001b[39m=\u001b[39m {\n\u001b[1;32m 116\u001b[0m act_kwargs_prefix \u001b[39m+\u001b[39m \u001b[39m'\u001b[39m\u001b[39mact_impl\u001b[39m\u001b[39m'\u001b[39m: act_impl,\n\u001b[1;32m 117\u001b[0m act_kwargs_prefix \u001b[39m+\u001b[39m \u001b[39m'\u001b[39m\u001b[39mpassthrough_act\u001b[39m\u001b[39m'\u001b[39m: passthrough_act}\n\u001b[0;32m--> 118\u001b[0m QuantProxyMixin\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\n\u001b[1;32m 119\u001b[0m \u001b[39mself\u001b[39;49m,\n\u001b[1;32m 120\u001b[0m quant\u001b[39m=\u001b[39;49mact_quant,\n\u001b[1;32m 121\u001b[0m proxy_prefix\u001b[39m=\u001b[39;49mact_proxy_prefix,\n\u001b[1;32m 122\u001b[0m kwargs_prefix\u001b[39m=\u001b[39;49mact_kwargs_prefix,\n\u001b[1;32m 123\u001b[0m proxy_protocol\u001b[39m=\u001b[39;49mActQuantProxyProtocol,\n\u001b[1;32m 124\u001b[0m none_quant_injector\u001b[39m=\u001b[39;49mNoneActQuant,\n\u001b[1;32m 125\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mprefixed_kwargs,\n\u001b[1;32m 126\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:70\u001b[0m, in \u001b[0;36mQuantProxyMixin.__init__\u001b[0;34m(self, quant, proxy_protocol, none_quant_injector, proxy_prefix, kwargs_prefix, **kwargs)\u001b[0m\n\u001b[1;32m 68\u001b[0m quant_injector \u001b[39m=\u001b[39m quant\n\u001b[1;32m 69\u001b[0m quant_injector \u001b[39m=\u001b[39m quant_injector\u001b[39m.\u001b[39mlet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mfilter_kwargs(kwargs_prefix, kwargs))\n\u001b[0;32m---> 70\u001b[0m quant \u001b[39m=\u001b[39m quant_injector\u001b[39m.\u001b[39;49mproxy_class(\u001b[39mself\u001b[39;49m, quant_injector)\n\u001b[1;32m 71\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 72\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(quant, proxy_protocol):\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/runtime_quant.py:89\u001b[0m, in \u001b[0;36mActQuantProxyFromInjector.__init__\u001b[0;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, quant_layer, quant_injector):\n\u001b[0;32m---> 89\u001b[0m QuantProxyFromInjector\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\u001b[39mself\u001b[39;49m, quant_layer, quant_injector)\n\u001b[1;32m 90\u001b[0m ActQuantProxyProtocol\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m)\n\u001b[1;32m 91\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_passthrough_act \u001b[39m=\u001b[39m _is_passthrough_act(quant_injector)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/quant_proxy.py:82\u001b[0m, in \u001b[0;36mQuantProxyFromInjector.__init__\u001b[0;34m(self, quant_layer, quant_injector)\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[39m# Use a normal list and not a ModuleList since this is a pointer to parent modules\u001b[39;00m\n\u001b[1;32m 81\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtracked_module_list \u001b[39m=\u001b[39m []\n\u001b[0;32m---> 82\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49madd_tracked_module(quant_layer)\n\u001b[1;32m 83\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdisable_quant \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/quant_proxy.py:120\u001b[0m, in \u001b[0;36mQuantProxyFromInjector.add_tracked_module\u001b[0;34m(self, module)\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtracked_module_list\u001b[39m.\u001b[39mappend(module)\n\u001b[1;32m 119\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mupdate_tracked_modules()\n\u001b[0;32m--> 120\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49minit_tensor_quant()\n\u001b[1;32m 121\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 122\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mTrying to add None as a parent module.\u001b[39m\u001b[39m\"\u001b[39m)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/proxy/runtime_quant.py:102\u001b[0m, in \u001b[0;36mActQuantProxyFromInjector.init_tensor_quant\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39minit_tensor_quant\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[0;32m--> 102\u001b[0m tensor_quant \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mquant_injector\u001b[39m.\u001b[39;49mtensor_quant\n\u001b[1;32m 103\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m'\u001b[39m\u001b[39mact_impl\u001b[39m\u001b[39m'\u001b[39m \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mquant_injector:\n\u001b[1;32m 104\u001b[0m act_impl \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mquant_injector\u001b[39m.\u001b[39mact_impl\n", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/_dependencies/this.py:51\u001b[0m, in \u001b[0;36m_ThisSpec.__call__\u001b[0;34m(self, __self__)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[39mif\u001b[39;00m kind \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m 50\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m---> 51\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39;49m(result, symbol)\n\u001b[1;32m 52\u001b[0m \u001b[39mexcept\u001b[39;00m DependencyError:\n\u001b[1;32m 53\u001b[0m message \u001b[39m=\u001b[39m (\n\u001b[1;32m 54\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mYou tried to shift this more times than Injector has levels\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 55\u001b[0m )\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/inject/__init__.py:129\u001b[0m, in \u001b[0;36m_ExtendedInjectorType.__getattr__\u001b[0;34m(cls, attrname)\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 127\u001b[0m message \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m{!r}\u001b[39;00m\u001b[39m can not resolve attribute \u001b[39m\u001b[39m{!r}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 128\u001b[0m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, current_attr)\n\u001b[0;32m--> 129\u001b[0m \u001b[39mraise\u001b[39;00m DependencyError(message)\n\u001b[1;32m 131\u001b[0m marker, attribute, args, have_defaults \u001b[39m=\u001b[39m spec\n\u001b[1;32m 133\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mset\u001b[39m(args)\u001b[39m.\u001b[39missubset(cached):\n", + "\u001b[0;31mDependencyError\u001b[0m: 'AdvancedActQuantizer' can not resolve attribute 'per_channel_broadcastable_shape'" ] } ], @@ -1455,22 +1389,22 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "QuantTensor(value=tensor([[-0.0100, 0.0100, -0.0100, -0.0100],\n", - " [-0.0100, 0.0100, -0.0100, -0.0100],\n", + "QuantTensor(value=tensor([[-0.0100, 0.0100, 0.0100, -0.0100],\n", + " [ 0.0100, -0.0100, -0.0100, -0.0100],\n", " [ 0.0100, -0.0100, 0.0100, -0.0100],\n", - " [ 0.0100, -0.0100, -0.0100, -0.0100]], grad_fn=), scale=tensor([[0.0100],\n", + " [ 0.0100, -0.0100, -0.0100, 0.0100]], grad_fn=), scale=tensor([[0.0100],\n", " [0.0100],\n", " [0.0100],\n", " [0.0100]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))" ] }, - "execution_count": 39, + "execution_count": 37, "metadata": {}, "output_type": "execute_result" } @@ -1488,6 +1422,11 @@ "source": [ "We have seen how powerful dependency injection is. In a way, it's even too expressive. For users that are not interesting in building completely custom quantizers, it can be hard to make sense of how the various components available under `brevitas.core` can be assembled together according to best practices." ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] } ], "metadata": { @@ -1506,7 +1445,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/notebooks/Brevitas_TVMCon2021.ipynb b/notebooks/Brevitas_TVMCon2021.ipynb index efd9421f0..e39b7301d 100644 --- a/notebooks/Brevitas_TVMCon2021.ipynb +++ b/notebooks/Brevitas_TVMCon2021.ipynb @@ -45,8 +45,10 @@ " input_quant: Optional[ActQuantType] = None,\n", " output_quant: Optional[ActQuantType] = None,\n", " return_quant_tensor: bool = False,\n", + " device: Optional[torch.device] = None,\n", + " dtype: Optional[torch.dtype] = None,\n", " **kwargs) -> None:\n", - " Linear.__init__(self, in_features, out_features, bias)\n", + " Linear.__init__(self, in_features, out_features, bias, device=device, dtype=dtype)\n", " QuantWBIOL.__init__(\n", " self,\n", " weight_quant=weight_quant,\n", @@ -118,7 +120,7 @@ " QuantTensor(value=tensor([[-0.0046, 0.3803],\n", " [-0.5820, -0.5224],\n", " [-0.2704, 0.1879],\n", - " [-0.0137, 0.5591]], grad_fn=), scale=tensor(0.0046, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)) \n", + " [-0.0137, 0.5591]], grad_fn=), scale=0.004582525696605444, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True) \n", "\n" ] } @@ -161,7 +163,7 @@ " tensor([[ -1, 83],\n", " [-127, -114],\n", " [ -59, 41],\n", - " [ -3, 122]], dtype=torch.int32)\n" + " [ -3, 122]], dtype=torch.int8)\n" ] } ], @@ -194,7 +196,15 @@ "Float output:\n", " tensor([[-0.9036, -0.4586, 0.3096, -0.6472],\n", " [ 1.2058, 0.6525, -0.3723, 0.8677],\n", - " [ 1.3873, 0.2801, -0.9009, 0.9507]], grad_fn=)\n" + " [ 1.3873, 0.2801, -0.9009, 0.9507]], grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/_tensor.py:1362: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/c10/core/TensorImpl.h:1900.)\n", + " return super().rename(names)\n" ] } ], @@ -238,7 +248,7 @@ " QuantTensor(value=tensor([[-0.0078, 0.3828],\n", " [-0.5781, -0.5234],\n", " [-0.2734, 0.1875],\n", - " [-0.0156, 0.5625]], grad_fn=), scale=tensor(0.0078, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))\n", + " [-0.0156, 0.5625]], grad_fn=), scale=0.0078125, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True)\n", "Weight fix point: 7.0\n" ] } @@ -277,7 +287,7 @@ " QuantTensor(value=tensor([[-0.1000, 0.1000],\n", " [-0.1000, -0.1000],\n", " [-0.1000, 0.1000],\n", - " [-0.1000, 0.1000]], grad_fn=), scale=tensor(0.1000), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))\n" + " [-0.1000, 0.1000]], grad_fn=), scale=0.10000000149011612, zero_point=0.0, bit_width=1.0, signed_t=True, training_t=True)\n" ] } ], @@ -372,7 +382,7 @@ "Quant output:\n", " tensor([[-0.9109, -0.4609, 0.3135, -0.6523],\n", " [ 1.2089, 0.6524, -0.3752, 0.8697],\n", - " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=)\n" + " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=)\n" ] } ], @@ -409,7 +419,7 @@ "Quant output:\n", " QuantTensor(value=tensor([[-0.9109, -0.4609, 0.3135, -0.6523],\n", " [ 1.2089, 0.6524, -0.3752, 0.8697],\n", - " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(17.), signed_t=tensor(True), training_t=tensor(True))\n" + " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=0.0, bit_width=17.0, signed_t=True, training_t=True)\n" ] } ], @@ -452,12 +462,12 @@ "Quant input:\n", " QuantTensor(value=tensor([[ 1.5490, -0.2894],\n", " [-2.1788, 0.5617],\n", - " [-1.0894, -1.3958]], grad_fn=), scale=tensor(0.0170, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)) \n", + " [-1.0894, -1.3958]], grad_fn=), scale=0.017021792009472847, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True) \n", "\n", "Quant output:\n", " QuantTensor(value=tensor([[-0.9109, -0.4609, 0.3135, -0.6523],\n", " [ 1.2089, 0.6524, -0.3752, 0.8697],\n", - " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(17.), signed_t=tensor(True), training_t=tensor(True))\n" + " [ 1.3893, 0.2816, -0.9011, 0.9521]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=0.0, bit_width=17.0, signed_t=True, training_t=True)\n" ] } ], @@ -511,7 +521,7 @@ "Quant output:\n", " QuantTensor(value=tensor([[1.5410, 0.0000],\n", " [0.0000, 0.5681],\n", - " [0.0000, 0.0000]], grad_fn=), scale=tensor(0.0060, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))\n" + " [0.0000, 0.0000]], grad_fn=), scale=0.006043121684342623, zero_point=0.0, bit_width=8.0, signed_t=False, training_t=True)\n" ] } ], @@ -555,11 +565,11 @@ "Quant output after QuantIdentity:\n", " QuantTensor(value=tensor([[ 1.5490, -0.2894],\n", " [-2.1788, 0.5617],\n", - " [-1.0894, -1.3958]], grad_fn=), scale=tensor(0.0170, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))\n", + " [-1.0894, -1.3958]], grad_fn=), scale=0.017021792009472847, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True)\n", "Quant output after QuantReLU:\n", " QuantTensor(value=tensor([[1.5490, 0.0000],\n", " [0.0000, 0.5588],\n", - " [0.0000, 0.0000]], grad_fn=), scale=tensor(0.0061, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(False), training_t=tensor(True))\n" + " [0.0000, 0.0000]], grad_fn=), scale=0.006074443459510803, zero_point=0.0, bit_width=8.0, signed_t=False, training_t=True)\n" ] } ], @@ -611,18 +621,17 @@ "outputs": [ { "ename": "RuntimeError", - "evalue": "Input scale required", + "evalue": "QuantLayer is not correctly configured", "output_type": "error", "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32mC:\\Users\\ALESSA~1\\AppData\\Local\\Temp/ipykernel_18920/2660651517.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[0mquant_linear\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mQuantLinear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m4\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbias_quant\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mInt16Bias\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mreturn_quant_tensor\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 7\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 8\u001b[1;33m \u001b[0mquant_output\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mquant_linear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfloat_input\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[1;32m~\\miniconda3\\envs\\pt190\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1049\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1051\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1052\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1053\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\users\\alessandro\\documenti\\brevitas_tvmcon\\src\\brevitas\\nn\\quant_linear.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 96\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 97\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mUnion\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mQuantTensor\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 98\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 99\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 100\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0minner_forward_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_weight\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquant_bias\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\users\\alessandro\\documenti\\brevitas_tvmcon\\src\\brevitas\\nn\\quant_layer.py\u001b[0m in \u001b[0;36mforward_impl\u001b[1;34m(self, inp)\u001b[0m\n\u001b[0;32m 355\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 356\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 357\u001b[1;33m \u001b[0mquant_bias\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias_quant\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moutput_scale\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moutput_bit_width\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 358\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtraining\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcache_inference_quant_bias\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 359\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_cached_bias\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_CachedIO\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mquant_bias\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmetadata_only\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32m~\\miniconda3\\envs\\pt190\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1049\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1051\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1052\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1053\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mc:\\users\\alessandro\\documenti\\brevitas_tvmcon\\src\\brevitas\\proxy\\parameter_quant.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x, input_scale, input_bit_width)\u001b[0m\n\u001b[0;32m 194\u001b[0m \u001b[0mimpl\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_handler\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexport_mode\u001b[0m \u001b[1;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensor_quant\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 195\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrequires_input_scale\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0minput_scale\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 196\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Input scale required\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 197\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrequires_input_bit_width\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0minput_bit_width\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 198\u001b[0m \u001b[1;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Input bit-width required\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;31mRuntimeError\u001b[0m: Input scale required" + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/Brevitas_TVMCon2021.ipynb Cell 35\u001b[0m line \u001b[0;36m8\n\u001b[1;32m 5\u001b[0m float_input \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrandn(\u001b[39m3\u001b[39m, \u001b[39m2\u001b[39m)\n\u001b[1;32m 6\u001b[0m quant_linear \u001b[39m=\u001b[39m QuantLinear(\u001b[39m2\u001b[39m, \u001b[39m4\u001b[39m, bias\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m, bias_quant\u001b[39m=\u001b[39mInt16Bias, return_quant_tensor\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[0;32m----> 8\u001b[0m quant_output \u001b[39m=\u001b[39m quant_linear(float_input)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compiled_call_impl(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs) \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_impl(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1529\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_linear.py:66\u001b[0m, in \u001b[0;36mQuantLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Union[Tensor, QuantTensor]) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Union[Tensor, QuantTensor]:\n\u001b[0;32m---> 66\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mforward_impl(\u001b[39minput\u001b[39;49m)\n", + "File \u001b[0;32m~/Documents/git/brevitas/src/brevitas/nn/quant_layer.py:329\u001b[0m, in \u001b[0;36mQuantWeightBiasInputOutputLayer.forward_impl\u001b[0;34m(self, inp)\u001b[0m\n\u001b[1;32m 325\u001b[0m compute_output_quant_tensor \u001b[39m=\u001b[39m \u001b[39misinstance\u001b[39m(quant_input, QuantTensor) \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(\n\u001b[1;32m 326\u001b[0m quant_weight, QuantTensor)\n\u001b[1;32m 327\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (compute_output_quant_tensor \u001b[39mor\u001b[39;00m\n\u001b[1;32m 328\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_output_quant_enabled) \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreturn_quant_tensor:\n\u001b[0;32m--> 329\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mQuantLayer is not correctly configured\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 331\u001b[0m \u001b[39mif\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreturn_quant_tensor \u001b[39mor\u001b[39;00m\n\u001b[1;32m 332\u001b[0m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_bias_quant_enabled \u001b[39mand\u001b[39;00m\n\u001b[1;32m 333\u001b[0m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias_quant\u001b[39m.\u001b[39mrequires_input_scale \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias_quant\u001b[39m.\u001b[39mrequires_input_bit_width))):\n\u001b[1;32m 334\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(quant_input, QuantTensor) \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(quant_weight, QuantTensor):\n", + "\u001b[0;31mRuntimeError\u001b[0m: QuantLayer is not correctly configured" ] } ], @@ -654,7 +663,7 @@ "text/plain": [ "QuantTensor(value=tensor([[-0.6541, 0.1263, 0.1680, -0.1231],\n", " [ 1.4658, 1.2395, -0.5207, 1.3989],\n", - " [ 1.6461, 0.8687, -1.0466, 1.4813]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(18.), signed_t=tensor(True), training_t=tensor(True))" + " [ 1.6461, 0.8687, -1.0466, 1.4813]], grad_fn=), scale=tensor([[9.0542e-05]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(18.), signed_t=tensor(True), training_t=tensor(True))" ] }, "execution_count": 14, @@ -713,15 +722,15 @@ "Eval mode add quant inputs:\n", " QuantTensor(value=tensor([[ 1.5335, -0.2875],\n", " [-2.0447, 0.5751],\n", - " [-1.0863, -1.4057]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(False)) \n", + " [-1.0863, -1.4057]]), scale=0.015974320471286774, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False) \n", " QuantTensor(value=tensor([[ 0.3994, 0.8307],\n", " [-0.7188, -0.3994],\n", - " [-0.5910, 0.1757]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(False)) \n", + " [-0.5910, 0.1757]]), scale=0.015974320471286774, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False) \n", "\n", "Eval mode add quant output:\n", " QuantTensor(value=tensor([[ 1.9329, 0.5431],\n", " [-2.7636, 0.1757],\n", - " [-1.6773, -1.2300]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(9.), signed_t=tensor(True), training_t=tensor(False))\n" + " [-1.6773, -1.2300]]), scale=0.015974320471286774, zero_point=0.0, bit_width=9.0, signed_t=True, training_t=False)\n" ] } ], @@ -784,7 +793,7 @@ " [ 0.1628, 0.8685, -0.1448, -0.1086]],\n", "\n", " [[ 0.9228, 1.2666, 2.0084, 0.0543],\n", - " [ 0.6152, -0.4162, -0.8323, -2.3160]]], grad_fn=), scale=tensor(0.0181, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)) \n", + " [ 0.6152, -0.4162, -0.8323, -2.3160]]], grad_fn=), scale=0.018094077706336975, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True) \n", "\n", "Quant output:\n", " QuantTensor(value=tensor([[[-1.1218, -0.2533],\n", @@ -794,15 +803,7 @@ " [ 0.8685, -0.1086]],\n", "\n", " [[ 1.2666, 2.0084],\n", - " [ 0.6152, -0.8323]]], grad_fn=), scale=tensor(0.0181, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "C:\\Users\\Alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages\\torch\\nn\\functional.py:652: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at ..\\c10/core/TensorImpl.h:1156.)\n", - " return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)\n" + " [ 0.6152, -0.8323]]], grad_fn=), scale=0.018094077706336975, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True)\n" ] } ], @@ -845,7 +846,7 @@ " [ 0.1628, 0.8685, -0.1448, -0.1086]],\n", "\n", " [[ 0.9228, 1.2666, 2.0084, 0.0543],\n", - " [ 0.6152, -0.4162, -0.8323, -2.3160]]], grad_fn=), scale=tensor(0.0181, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)) \n", + " [ 0.6152, -0.4162, -0.8323, -2.3160]]], grad_fn=), scale=0.018094077706336975, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True) \n", "\n", "Quant output:\n", " tensor([[[-0.8082, -0.8204, -0.2480, -0.4089],\n", @@ -855,7 +856,15 @@ " [ 0.1614, 0.7006, -0.1438, -0.1081]],\n", "\n", " [[ 0.7272, 0.8529, 0.9646, 0.0542],\n", - " [ 0.5478, -0.3937, -0.6817, -0.9807]]], grad_fn=)\n" + " [ 0.5478, -0.3937, -0.6817, -0.9807]]], grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_1735865/661358273.py:7: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", + " quant_output = torch.tanh(quant_input)\n" ] } ], @@ -893,14 +902,24 @@ "Eval mode concat quant inputs:\n", " QuantTensor(value=tensor([[ 1.5335, -0.2875],\n", " [-2.0447, 0.5751],\n", - " [-1.0863, -1.4057]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(False)) QuantTensor(value=tensor([[ 0.3994, 0.8307],\n", + " [-1.0863, -1.4057]]), scale=0.015974320471286774, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False) QuantTensor(value=tensor([[ 0.3994, 0.8307],\n", " [-0.7188, -0.3994],\n", - " [-0.5910, 0.1757]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(False)) \n", + " [-0.5910, 0.1757]]), scale=0.015974320471286774, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False) \n", "\n", "Eval mode concat quant output:\n", " QuantTensor(value=tensor([[ 1.5335, -0.2875, 0.3994, 0.8307],\n", " [-2.0447, 0.5751, -0.7188, -0.3994],\n", - " [-1.0863, -1.4057, -0.5910, 0.1757]]), scale=tensor(0.0160), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(False))\n" + " [-1.0863, -1.4057, -0.5910, 0.1757]]), scale=0.015974320471286774, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_1735865/3932472163.py:8: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", + " train_mode_cat = torch.cat([quant_identity(float_inp1), quant_identity(float_inp2)], dim=1)\n", + "/tmp/ipykernel_1735865/3932472163.py:14: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", + " eval_mode_cat = torch.cat([eval_quant_inp1, eval_quant_inp2], dim=1)\n" ] } ], @@ -957,7 +976,7 @@ " QuantTensor(value=tensor([[-0.0000, 0.3880],\n", " [-0.5820, -0.5044],\n", " [-0.2716, 0.1940],\n", - " [-0.0000, 0.5432]], grad_fn=), scale=tensor(0.0388, grad_fn=), zero_point=tensor(0.), bit_width=tensor(5.), signed_t=tensor(True), training_t=tensor(True))\n" + " [-0.0000, 0.5432]], grad_fn=), scale=0.03879871591925621, zero_point=0.0, bit_width=5.0, signed_t=True, training_t=True)\n" ] } ], @@ -994,7 +1013,7 @@ " [-0.0000, 0.5607]], grad_fn=), scale=tensor([[0.0253],\n", " [0.0388],\n", " [0.0182],\n", - " [0.0374]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(5.), signed_t=tensor(True), training_t=tensor(True))\n" + " [0.0374]], grad_fn=), zero_point=0.0, bit_width=5.0, signed_t=True, training_t=True)\n" ] } ], @@ -1027,7 +1046,7 @@ "QuantTensor:\n", " QuantTensor(value=tensor([[ 1.6341, -0.5447],\n", " [-2.1788, 0.5447],\n", - " [-1.0894, -1.6341]], grad_fn=), scale=tensor(0.5447, grad_fn=), zero_point=tensor(0.), bit_width=tensor(3.), signed_t=tensor(True), training_t=tensor(True))\n" + " [-1.0894, -1.6341]], grad_fn=), scale=0.5446973443031311, zero_point=0.0, bit_width=3.0, signed_t=True, training_t=True)\n" ] } ], @@ -1100,8 +1119,8 @@ "\n", "Per-channel quant output:\n", " QuantTensor(value=tensor([[[ 0.8616, -0.7012, 0.4503],\n", - " [-1.1285, -0.4937, -0.1901]]], grad_fn=), scale=tensor([[[0.0021],\n", - " [0.0013]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(17.), signed_t=tensor(True), training_t=tensor(True))\n" + " [-1.1285, -0.4937, -0.1901]]], grad_fn=), scale=tensor([[[0.0021],\n", + " [0.0013]]], grad_fn=), zero_point=0.0, bit_width=17.0, signed_t=True, training_t=True)\n" ] } ], @@ -1158,8 +1177,8 @@ "\n", "Per-channel quant output:\n", " QuantTensor(value=tensor([[[ 0.8616, -0.7012, 0.4503],\n", - " [-1.1285, -0.4937, -0.1901]]], grad_fn=), scale=tensor([[[0.0021],\n", - " [0.0013]]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(17.), signed_t=tensor(True), training_t=tensor(True))\n" + " [-1.1285, -0.4937, -0.1901]]], grad_fn=), scale=tensor([[[0.0021],\n", + " [0.0013]]], grad_fn=), zero_point=0.0, bit_width=17.0, signed_t=True, training_t=True)\n" ] } ], @@ -1307,7 +1326,7 @@ " [-0.0132, 0.5607]], grad_fn=), scale=tensor([[0.0030],\n", " [0.0046],\n", " [0.0021],\n", - " [0.0044]], grad_fn=), zero_point=tensor(0.), bit_width=tensor(8., grad_fn=), signed_t=tensor(True), training_t=tensor(True))\n" + " [0.0044]], grad_fn=), zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True)\n" ] } ], @@ -1345,8 +1364,8 @@ "text/plain": [ "QuantTensor(value=tensor([[-0.9109, -0.4588, 0.3119, -0.6530],\n", " [ 1.2089, 0.6493, -0.3731, 0.8706],\n", - " [ 1.3893, 0.2823, -0.8979, 0.9543]], grad_fn=), scale=tensor([[9.0542e-05, 3.9068e-05, 5.6866e-05, 6.4251e-05]],\n", - " grad_fn=), zero_point=tensor(0.), bit_width=tensor(17., grad_fn=), signed_t=tensor(True), training_t=tensor(True))" + " [ 1.3893, 0.2823, -0.8979, 0.9543]], grad_fn=), scale=tensor([[9.0542e-05, 3.9068e-05, 5.6866e-05, 6.4251e-05]],\n", + " grad_fn=), zero_point=tensor(0.), bit_width=tensor(17., grad_fn=), signed_t=tensor(True), training_t=tensor(True))" ] }, "execution_count": 28, @@ -1406,11 +1425,11 @@ "evalue": "Error(s) in loading state_dict for QuantLinear:\n\tMissing key(s) in state_dict: \"input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value\", \"input_quant.fused_activation_quant_proxy.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\", \"weight_quant.tensor_quant.scaling_impl.value\", \"weight_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\". ", "output_type": "error", "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32mC:\\Users\\ALESSA~1\\AppData\\Local\\Temp/ipykernel_18920/1653109852.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 10\u001b[0m return_quant_tensor=True, bias=False)\n\u001b[0;32m 11\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 12\u001b[1;33m \u001b[0mquant_linear\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfloat_linear\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[1;32m~\\miniconda3\\envs\\pt190\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36mload_state_dict\u001b[1;34m(self, state_dict, strict)\u001b[0m\n\u001b[0;32m 1405\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0merror_msgs\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m>\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1406\u001b[0m raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n\u001b[1;32m-> 1407\u001b[1;33m self.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n\u001b[0m\u001b[0;32m 1408\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0m_IncompatibleKeys\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmissing_keys\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0munexpected_keys\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1409\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for QuantLinear:\n\tMissing key(s) in state_dict: \"input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value\", \"input_quant.fused_activation_quant_proxy.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\", \"weight_quant.tensor_quant.scaling_impl.value\", \"weight_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\". " + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/Brevitas_TVMCon2021.ipynb Cell 75\u001b[0m line \u001b[0;36m1\n\u001b[1;32m 5\u001b[0m float_linear \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mLinear(\u001b[39m2\u001b[39m, \u001b[39m4\u001b[39m, bias\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m)\n\u001b[1;32m 6\u001b[0m quant_linear \u001b[39m=\u001b[39m QuantLinear(\n\u001b[1;32m 7\u001b[0m \u001b[39m2\u001b[39m, \u001b[39m4\u001b[39m, \n\u001b[1;32m 8\u001b[0m input_quant\u001b[39m=\u001b[39mLearnedIntActPerTensorFloat,\n\u001b[1;32m 9\u001b[0m weight_quant\u001b[39m=\u001b[39mLearnedIntWeightPerChannelFloat, \n\u001b[1;32m 10\u001b[0m return_quant_tensor\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m, bias\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m)\n\u001b[0;32m---> 12\u001b[0m quant_linear\u001b[39m.\u001b[39;49mload_state_dict(float_linear\u001b[39m.\u001b[39;49mstate_dict())\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/nn/modules/module.py:2152\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict, assign)\u001b[0m\n\u001b[1;32m 2147\u001b[0m error_msgs\u001b[39m.\u001b[39minsert(\n\u001b[1;32m 2148\u001b[0m \u001b[39m0\u001b[39m, \u001b[39m'\u001b[39m\u001b[39mMissing key(s) in state_dict: \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m. \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 2149\u001b[0m \u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mjoin(\u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00mk\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m'\u001b[39m \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m missing_keys)))\n\u001b[1;32m 2151\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(error_msgs) \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[0;32m-> 2152\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m'\u001b[39m\u001b[39mError(s) in loading state_dict for \u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 2153\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mjoin(error_msgs)))\n\u001b[1;32m 2154\u001b[0m \u001b[39mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n", + "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for QuantLinear:\n\tMissing key(s) in state_dict: \"input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value\", \"input_quant.fused_activation_quant_proxy.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\", \"weight_quant.tensor_quant.scaling_impl.value\", \"weight_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset\". " ] } ], @@ -1575,10 +1594,12 @@ " (stats): _Stats(\n", " (stats_impl): AbsPercentile()\n", " )\n", - " (restrict_clamp_scaling): _RestrictClampValue(\n", - " (clamp_min_ste): Identity()\n", + " (restrict_scaling): _RestrictValue(\n", " (restrict_value_impl): FloatRestrictValue()\n", " )\n", + " (clamp_scaling): _ClampValue(\n", + " (clamp_min_ste): ScalarClampMinSte()\n", + " )\n", " (restrict_inplace_preprocess): Identity()\n", " (restrict_preprocess): Identity()\n", " )\n", @@ -1852,13 +1873,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "Requirement already satisfied: netron in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (5.3.9)\n", - "Requirement already satisfied: onnx in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (1.10.2)\n", - "Requirement already satisfied: onnxoptimizer in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (0.2.6)\n", - "Requirement already satisfied: numpy>=1.16.6 in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (from onnx) (1.21.2)\n", - "Requirement already satisfied: typing-extensions>=3.6.2.1 in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (from onnx) (3.10.0.2)\n", - "Requirement already satisfied: protobuf in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (from onnx) (3.19.1)\n", - "Requirement already satisfied: six in c:\\users\\alessandro\\miniconda3\\envs\\pt190\\lib\\site-packages (from onnx) (1.16.0)\n" + "\u001b[33mWARNING: Ignoring invalid distribution ~nnx-weekly (/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages)\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: netron in /home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages (7.2.9)\n", + "Requirement already satisfied: onnx in /home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages (1.15.0)\n", + "Requirement already satisfied: onnxoptimizer in /home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages (0.3.13)\n", + "Requirement already satisfied: numpy in /home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages (from onnx) (1.26.0)\n", + "Requirement already satisfied: protobuf>=3.20.2 in /home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages (from onnx) (3.20.3)\n", + "\u001b[33mWARNING: Ignoring invalid distribution ~nnx-weekly (/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m" ] } ], @@ -1894,9 +1916,202 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 39, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/export/onnx/standard/manager.py:26: UserWarning: ONNX opset version set to 13, override with opset_version=\n", + " warnings.warn(f\"ONNX opset version set to {DEFAULT_OPSET}, override with {ka}=\")\n" + ] + }, + { + "data": { + "text/plain": [ + "ir_version: 7\n", + "producer_name: \"pytorch\"\n", + "producer_version: \"2.1.0\"\n", + "graph {\n", + " node {\n", + " output: \"/export_handler/Constant_output_0\"\n", + " name: \"/export_handler/Constant\"\n", + " op_type: \"Constant\"\n", + " attribute {\n", + " name: \"value\"\n", + " t {\n", + " data_type: 1\n", + " raw_data: \"\\000\\000\\000<\"\n", + " }\n", + " type: TENSOR\n", + " }\n", + " }\n", + " node {\n", + " output: \"/export_handler/Constant_1_output_0\"\n", + " name: \"/export_handler/Constant_1\"\n", + " op_type: \"Constant\"\n", + " attribute {\n", + " name: \"value\"\n", + " t {\n", + " data_type: 3\n", + " raw_data: \"\\000\"\n", + " }\n", + " type: TENSOR\n", + " }\n", + " }\n", + " node {\n", + " input: \"inp.1\"\n", + " input: \"/export_handler/Constant_output_0\"\n", + " input: \"/export_handler/Constant_1_output_0\"\n", + " output: \"/export_handler/QuantizeLinear_output_0\"\n", + " name: \"/export_handler/QuantizeLinear\"\n", + " op_type: \"QuantizeLinear\"\n", + " }\n", + " node {\n", + " output: \"/export_handler/Constant_2_output_0\"\n", + " name: \"/export_handler/Constant_2\"\n", + " op_type: \"Constant\"\n", + " attribute {\n", + " name: \"value\"\n", + " t {\n", + " dims: 4\n", + " dims: 2\n", + " dims: 3\n", + " data_type: 3\n", + " raw_data: \"\\003\\006\\376\\006\\377\\001\\007\\371\\373\\376\\375\\006\\373\\375\\373\\371\\374\\006\\003\\004\\000\\374\\001\\371\"\n", + " }\n", + " type: TENSOR\n", + " }\n", + " }\n", + " node {\n", + " output: \"/export_handler/Constant_3_output_0\"\n", + " name: \"/export_handler/Constant_3\"\n", + " op_type: \"Constant\"\n", + " attribute {\n", + " name: \"value\"\n", + " t {\n", + " data_type: 1\n", + " raw_data: \"\\242\\272_=\"\n", + " }\n", + " type: TENSOR\n", + " }\n", + " }\n", + " node {\n", + " output: \"/export_handler/Constant_4_output_0\"\n", + " name: \"/export_handler/Constant_4\"\n", + " op_type: \"Constant\"\n", + " attribute {\n", + " name: \"value\"\n", + " t {\n", + " dims: 4\n", + " data_type: 6\n", + " raw_data: \"M\\375\\377\\377\\023\\376\\377\\377\\\\\\002\\000\\0001\\002\\000\\000\"\n", + " }\n", + " type: TENSOR\n", + " }\n", + " }\n", + " node {\n", + " input: \"/export_handler/QuantizeLinear_output_0\"\n", + " input: \"/export_handler/Constant_output_0\"\n", + " input: \"/export_handler/Constant_1_output_0\"\n", + " input: \"/export_handler/Constant_2_output_0\"\n", + " input: \"/export_handler/Constant_3_output_0\"\n", + " input: \"/export_handler/Constant_1_output_0\"\n", + " input: \"/export_handler/Constant_output_0\"\n", + " input: \"/export_handler/Constant_1_output_0\"\n", + " input: \"/export_handler/Constant_4_output_0\"\n", + " output: \"/export_handler/QLinearConv_output_0\"\n", + " name: \"/export_handler/QLinearConv\"\n", + " op_type: \"QLinearConv\"\n", + " attribute {\n", + " name: \"dilations\"\n", + " ints: 1\n", + " type: INTS\n", + " }\n", + " attribute {\n", + " name: \"group\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " attribute {\n", + " name: \"kernel_shape\"\n", + " ints: 3\n", + " type: INTS\n", + " }\n", + " attribute {\n", + " name: \"pads\"\n", + " ints: 0\n", + " ints: 0\n", + " type: INTS\n", + " }\n", + " attribute {\n", + " name: \"strides\"\n", + " ints: 1\n", + " type: INTS\n", + " }\n", + " }\n", + " node {\n", + " input: \"/export_handler/QLinearConv_output_0\"\n", + " input: \"/export_handler/Constant_output_0\"\n", + " input: \"/export_handler/Constant_1_output_0\"\n", + " output: \"10\"\n", + " name: \"/export_handler/DequantizeLinear\"\n", + " op_type: \"DequantizeLinear\"\n", + " }\n", + " name: \"main_graph\"\n", + " input {\n", + " name: \"inp.1\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 1\n", + " }\n", + " dim {\n", + " dim_value: 2\n", + " }\n", + " dim {\n", + " dim_value: 5\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " output {\n", + " name: \"10\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 1\n", + " }\n", + " dim {\n", + " dim_value: 4\n", + " }\n", + " dim {\n", + " dim_value: 3\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + "}\n", + "opset_import {\n", + " domain: \"\"\n", + " version: 13\n", + "}" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "torch.manual_seed(0)\n", "\n", @@ -1918,7 +2133,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 40, "metadata": { "tags": [ "skip-execution" @@ -1926,33 +2141,22 @@ }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Serving 'qop_onnx_conv_4b8b.onnx' at http://localhost:8082\n" + "ename": "OSError", + "evalue": "[Errno 98] Address already in use", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/Brevitas_TVMCon2021.ipynb Cell 103\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m show_netron(output_path, \u001b[39m8082\u001b[39;49m)\n", + "\u001b[1;32m/home/giuseppe/Documents/git/brevitas/notebooks/Brevitas_TVMCon2021.ipynb Cell 103\u001b[0m line \u001b[0;36m7\n\u001b[1;32m 5\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mshow_netron\u001b[39m(model_path, port):\n\u001b[1;32m 6\u001b[0m time\u001b[39m.\u001b[39msleep(\u001b[39m3.\u001b[39m)\n\u001b[0;32m----> 7\u001b[0m netron\u001b[39m.\u001b[39;49mstart(model_path, address\u001b[39m=\u001b[39;49m(\u001b[39m\"\u001b[39;49m\u001b[39mlocalhost\u001b[39;49m\u001b[39m\"\u001b[39;49m, port), browse\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m)\n\u001b[1;32m 8\u001b[0m \u001b[39mreturn\u001b[39;00m IFrame(src\u001b[39m=\u001b[39m\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mhttp://localhost:\u001b[39m\u001b[39m{\u001b[39;00mport\u001b[39m}\u001b[39;00m\u001b[39m/\u001b[39m\u001b[39m\"\u001b[39m, width\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m100\u001b[39m\u001b[39m%\u001b[39m\u001b[39m\"\u001b[39m, height\u001b[39m=\u001b[39m\u001b[39m400\u001b[39m)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/netron/server.py:321\u001b[0m, in \u001b[0;36mstart\u001b[0;34m(file, address, browse, verbosity)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mstart\u001b[39m(file\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, address\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, browse\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m, verbosity\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m):\n\u001b[1;32m 310\u001b[0m \u001b[39m \u001b[39m\u001b[39m'''Start serving model file at address and open in web browser.\u001b[39;00m\n\u001b[1;32m 311\u001b[0m \n\u001b[1;32m 312\u001b[0m \u001b[39m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 319\u001b[0m \u001b[39m A (host, port) address tuple.\u001b[39;00m\n\u001b[1;32m 320\u001b[0m \u001b[39m '''\u001b[39;00m\n\u001b[0;32m--> 321\u001b[0m \u001b[39mreturn\u001b[39;00m serve(file, \u001b[39mNone\u001b[39;49;00m, browse\u001b[39m=\u001b[39;49mbrowse, address\u001b[39m=\u001b[39;49maddress, verbosity\u001b[39m=\u001b[39;49mverbosity)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/netron/server.py:298\u001b[0m, in \u001b[0;36mserve\u001b[0;34m(file, data, address, browse, verbosity)\u001b[0m\n\u001b[1;32m 295\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 296\u001b[0m address \u001b[39m=\u001b[39m _make_port(address)\n\u001b[0;32m--> 298\u001b[0m thread \u001b[39m=\u001b[39m _HTTPServerThread(content, address, verbosity)\n\u001b[1;32m 299\u001b[0m thread\u001b[39m.\u001b[39mstart()\n\u001b[1;32m 300\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mnot\u001b[39;00m thread\u001b[39m.\u001b[39malive():\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/netron/server.py:129\u001b[0m, in \u001b[0;36m_HTTPServerThread.__init__\u001b[0;34m(self, content, address, verbosity)\u001b[0m\n\u001b[1;32m 127\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39maddress \u001b[39m=\u001b[39m address\n\u001b[1;32m 128\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39murl \u001b[39m=\u001b[39m \u001b[39m'\u001b[39m\u001b[39mhttp://\u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m address[\u001b[39m0\u001b[39m] \u001b[39m+\u001b[39m \u001b[39m'\u001b[39m\u001b[39m:\u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m \u001b[39mstr\u001b[39m(address[\u001b[39m1\u001b[39m])\n\u001b[0;32m--> 129\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mserver \u001b[39m=\u001b[39m _ThreadedHTTPServer(address, _HTTPRequestHandler)\n\u001b[1;32m 130\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mserver\u001b[39m.\u001b[39mtimeout \u001b[39m=\u001b[39m \u001b[39m0.25\u001b[39m\n\u001b[1;32m 131\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mserver\u001b[39m.\u001b[39mblock_on_close \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/socketserver.py:456\u001b[0m, in \u001b[0;36mTCPServer.__init__\u001b[0;34m(self, server_address, RequestHandlerClass, bind_and_activate)\u001b[0m\n\u001b[1;32m 454\u001b[0m \u001b[39mif\u001b[39;00m bind_and_activate:\n\u001b[1;32m 455\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 456\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mserver_bind()\n\u001b[1;32m 457\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mserver_activate()\n\u001b[1;32m 458\u001b[0m \u001b[39mexcept\u001b[39;00m:\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/http/server.py:136\u001b[0m, in \u001b[0;36mHTTPServer.server_bind\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mserver_bind\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[1;32m 135\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Override server_bind to store the server name.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 136\u001b[0m socketserver\u001b[39m.\u001b[39;49mTCPServer\u001b[39m.\u001b[39;49mserver_bind(\u001b[39mself\u001b[39;49m)\n\u001b[1;32m 137\u001b[0m host, port \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mserver_address[:\u001b[39m2\u001b[39m]\n\u001b[1;32m 138\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mserver_name \u001b[39m=\u001b[39m socket\u001b[39m.\u001b[39mgetfqdn(host)\n", + "File \u001b[0;32m~/miniconda3/envs/torch_2.1/lib/python3.11/socketserver.py:472\u001b[0m, in \u001b[0;36mTCPServer.server_bind\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 470\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mallow_reuse_port \u001b[39mand\u001b[39;00m \u001b[39mhasattr\u001b[39m(socket, \u001b[39m\"\u001b[39m\u001b[39mSO_REUSEPORT\u001b[39m\u001b[39m\"\u001b[39m):\n\u001b[1;32m 471\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msocket\u001b[39m.\u001b[39msetsockopt(socket\u001b[39m.\u001b[39mSOL_SOCKET, socket\u001b[39m.\u001b[39mSO_REUSEPORT, \u001b[39m1\u001b[39m)\n\u001b[0;32m--> 472\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49msocket\u001b[39m.\u001b[39;49mbind(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mserver_address)\n\u001b[1;32m 473\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mserver_address \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msocket\u001b[39m.\u001b[39mgetsockname()\n", + "\u001b[0;31mOSError\u001b[0m: [Errno 98] Address already in use" ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 39, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ @@ -1982,7 +2186,317 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "ir_version: 8\n", + "producer_name: \"pytorch\"\n", + "producer_version: \"2.1.0\"\n", + "graph {\n", + " node {\n", + " input: \"x.87\"\n", + " input: \"/input_quant/export_handler/Constant_1_output_0\"\n", + " input: \"/input_quant/export_handler/Constant_2_output_0\"\n", + " input: \"/input_quant/export_handler/Constant_output_0\"\n", + " output: \"/input_quant/export_handler/Quant_output_0\"\n", + " name: \"/input_quant/export_handler/Quant\"\n", + " op_type: \"Quant\"\n", + " attribute {\n", + " name: \"narrow\"\n", + " i: 0\n", + " type: INT\n", + " }\n", + " attribute {\n", + " name: \"rounding_mode\"\n", + " s: \"ROUND\"\n", + " type: STRING\n", + " }\n", + " attribute {\n", + " name: \"signed\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " domain: \"onnx.brevitas\"\n", + " }\n", + " node {\n", + " input: \"/weight_quant/export_handler/Constant_1_output_0\"\n", + " input: \"/weight_quant/export_handler/Constant_2_output_0\"\n", + " input: \"/input_quant/export_handler/Constant_2_output_0\"\n", + " input: \"/weight_quant/export_handler/Constant_output_0\"\n", + " output: \"/weight_quant/export_handler/Quant_output_0\"\n", + " name: \"/weight_quant/export_handler/Quant\"\n", + " op_type: \"Quant\"\n", + " attribute {\n", + " name: \"narrow\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " attribute {\n", + " name: \"rounding_mode\"\n", + " s: \"ROUND\"\n", + " type: STRING\n", + " }\n", + " attribute {\n", + " name: \"signed\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " domain: \"onnx.brevitas\"\n", + " }\n", + " node {\n", + " input: \"bias\"\n", + " input: \"onnx.brevitas::Quant_11\"\n", + " input: \"/input_quant/export_handler/Constant_2_output_0\"\n", + " input: \"/bias_quant/export_handler/Constant_output_0\"\n", + " output: \"/bias_quant/export_handler/Quant_output_0\"\n", + " name: \"/bias_quant/export_handler/Quant\"\n", + " op_type: \"Quant\"\n", + " attribute {\n", + " name: \"narrow\"\n", + " i: 0\n", + " type: INT\n", + " }\n", + " attribute {\n", + " name: \"rounding_mode\"\n", + " s: \"ROUND\"\n", + " type: STRING\n", + " }\n", + " attribute {\n", + " name: \"signed\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " domain: \"onnx.brevitas\"\n", + " }\n", + " node {\n", + " input: \"/input_quant/export_handler/Quant_output_0\"\n", + " input: \"/weight_quant/export_handler/Quant_output_0\"\n", + " input: \"/bias_quant/export_handler/Quant_output_0\"\n", + " output: \"/Conv_output_0\"\n", + " name: \"/Conv\"\n", + " op_type: \"Conv\"\n", + " attribute {\n", + " name: \"dilations\"\n", + " ints: 1\n", + " type: INTS\n", + " }\n", + " attribute {\n", + " name: \"group\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " attribute {\n", + " name: \"kernel_shape\"\n", + " ints: 3\n", + " type: INTS\n", + " }\n", + " attribute {\n", + " name: \"pads\"\n", + " ints: 0\n", + " ints: 0\n", + " type: INTS\n", + " }\n", + " attribute {\n", + " name: \"strides\"\n", + " ints: 1\n", + " type: INTS\n", + " }\n", + " }\n", + " node {\n", + " input: \"/Conv_output_0\"\n", + " input: \"/input_quant/export_handler/Constant_1_output_0\"\n", + " input: \"/input_quant/export_handler/Constant_2_output_0\"\n", + " input: \"/input_quant/export_handler/Constant_output_0\"\n", + " output: \"15\"\n", + " name: \"/output_quant/export_handler/Quant\"\n", + " op_type: \"Quant\"\n", + " attribute {\n", + " name: \"narrow\"\n", + " i: 0\n", + " type: INT\n", + " }\n", + " attribute {\n", + " name: \"rounding_mode\"\n", + " s: \"ROUND\"\n", + " type: STRING\n", + " }\n", + " attribute {\n", + " name: \"signed\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " domain: \"onnx.brevitas\"\n", + " }\n", + " name: \"main_graph\"\n", + " initializer {\n", + " dims: 4\n", + " data_type: 1\n", + " name: \"bias\"\n", + " raw_data: \"w\\010\\227\\276\\360\\203W\\276q\\341\\203>\\002\\034u>\"\n", + " }\n", + " initializer {\n", + " data_type: 1\n", + " name: \"/input_quant/export_handler/Constant_output_0\"\n", + " raw_data: \"\\000\\000\\000A\"\n", + " }\n", + " initializer {\n", + " data_type: 1\n", + " name: \"/input_quant/export_handler/Constant_1_output_0\"\n", + " raw_data: \"\\000\\000\\000<\"\n", + " }\n", + " initializer {\n", + " data_type: 1\n", + " name: \"/input_quant/export_handler/Constant_2_output_0\"\n", + " raw_data: \"\\000\\000\\000\\000\"\n", + " }\n", + " initializer {\n", + " data_type: 1\n", + " name: \"/weight_quant/export_handler/Constant_output_0\"\n", + " raw_data: \"\\000\\000\\200@\"\n", + " }\n", + " initializer {\n", + " dims: 4\n", + " dims: 2\n", + " dims: 3\n", + " data_type: 1\n", + " name: \"/weight_quant/export_handler/Constant_1_output_0\"\n", + " raw_data: \"\\372\\313\\'>\\372\\313\\247>\\242\\272\\337\\275\\372\\313\\247>\\242\\272_\\275\\242\\272_=N\\303\\303>N\\303\\303\\276\\245\\324\\213\\276\\242\\272\\337\\275\\372\\313\\'\\276\\372\\313\\247>\\245\\324\\213\\276\\372\\313\\'\\276\\245\\324\\213\\276N\\303\\303\\276\\242\\272_\\276\\372\\313\\247>\\372\\313\\'>\\242\\272_>\\000\\000\\000\\000\\242\\272_\\276\\242\\272_=N\\303\\303\\276\"\n", + " }\n", + " initializer {\n", + " data_type: 1\n", + " name: \"/weight_quant/export_handler/Constant_2_output_0\"\n", + " raw_data: \"\\242\\272_=\"\n", + " }\n", + " initializer {\n", + " dims: 1\n", + " data_type: 1\n", + " name: \"onnx.brevitas::Quant_11\"\n", + " raw_data: \"\\242\\272\\3379\"\n", + " }\n", + " initializer {\n", + " data_type: 1\n", + " name: \"/bias_quant/export_handler/Constant_output_0\"\n", + " raw_data: \"\\000\\000\\200A\"\n", + " }\n", + " input {\n", + " name: \"x.87\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 1\n", + " }\n", + " dim {\n", + " dim_value: 2\n", + " }\n", + " dim {\n", + " dim_value: 5\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " input {\n", + " name: \"bias\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 4\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " output {\n", + " name: \"15\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 1\n", + " }\n", + " dim {\n", + " dim_value: 4\n", + " }\n", + " dim {\n", + " dim_value: 3\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " value_info {\n", + " name: \"/input_quant/export_handler/Quant_output_0\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 1\n", + " }\n", + " dim {\n", + " dim_value: 2\n", + " }\n", + " dim {\n", + " dim_value: 5\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " value_info {\n", + " name: \"/weight_quant/export_handler/Quant_output_0\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 4\n", + " }\n", + " dim {\n", + " dim_value: 2\n", + " }\n", + " dim {\n", + " dim_value: 3\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " value_info {\n", + " name: \"/bias_quant/export_handler/Quant_output_0\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 4\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + "}\n", + "opset_import {\n", + " domain: \"\"\n", + " version: 17\n", + "}\n", + "opset_import {\n", + " domain: \"onnx.brevitas\"\n", + " version: 1\n", + "}" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "torch.manual_seed(0)\n", "\n", @@ -2003,7 +2517,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": null, "metadata": { "tags": [ "skip-execution" @@ -2055,7 +2569,189 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "ir_version: 8\n", + "producer_name: \"pytorch\"\n", + "producer_version: \"2.1.0\"\n", + "graph {\n", + " node {\n", + " input: \"/weight_quant/export_handler/Constant_1_output_0\"\n", + " input: \"/weight_quant/export_handler/Constant_2_output_0\"\n", + " input: \"/weight_quant/export_handler/Constant_3_output_0\"\n", + " input: \"/weight_quant/export_handler/Constant_output_0\"\n", + " output: \"/weight_quant/export_handler/Quant_output_0\"\n", + " name: \"/weight_quant/export_handler/Quant\"\n", + " op_type: \"Quant\"\n", + " attribute {\n", + " name: \"narrow\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " attribute {\n", + " name: \"rounding_mode\"\n", + " s: \"ROUND\"\n", + " type: STRING\n", + " }\n", + " attribute {\n", + " name: \"signed\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " domain: \"onnx.brevitas\"\n", + " }\n", + " node {\n", + " input: \"x.27\"\n", + " input: \"/weight_quant/export_handler/Quant_output_0\"\n", + " input: \"bias\"\n", + " output: \"8\"\n", + " name: \"/Conv\"\n", + " op_type: \"Conv\"\n", + " attribute {\n", + " name: \"dilations\"\n", + " ints: 1\n", + " type: INTS\n", + " }\n", + " attribute {\n", + " name: \"group\"\n", + " i: 1\n", + " type: INT\n", + " }\n", + " attribute {\n", + " name: \"kernel_shape\"\n", + " ints: 3\n", + " type: INTS\n", + " }\n", + " attribute {\n", + " name: \"pads\"\n", + " ints: 0\n", + " ints: 0\n", + " type: INTS\n", + " }\n", + " attribute {\n", + " name: \"strides\"\n", + " ints: 1\n", + " type: INTS\n", + " }\n", + " }\n", + " name: \"main_graph\"\n", + " initializer {\n", + " dims: 4\n", + " data_type: 1\n", + " name: \"bias\"\n", + " raw_data: \"\\243\\303\\206\\275\\325\\3600=\\366C\\275>\\222\\347\\301\\276\"\n", + " }\n", + " initializer {\n", + " data_type: 1\n", + " name: \"/weight_quant/export_handler/Constant_output_0\"\n", + " raw_data: \"\\000\\000\\200@\"\n", + " }\n", + " initializer {\n", + " dims: 4\n", + " dims: 2\n", + " dims: 3\n", + " data_type: 1\n", + " name: \"/weight_quant/export_handler/Constant_1_output_0\"\n", + " raw_data: \"\\000\\000\\000\\200\\2227d>\\256)\\253\\276\\273\\242\\216\\276\\256)+\\276\\2227\\344=\\000\\000\\000\\200\\256)\\253>\\2227d\\275\\2227\\344=\\2227\\344\\275\\2227d\\275\\240\\260\\307\\276\\273\\242\\216\\276\\256)+\\276\\000\\000\\000\\000\\256)+>\\2227d>\\273\\242\\216\\276\\256)+\\276\\256)+>\\256)\\253>\\2227\\344\\275\\273\\242\\216>\"\n", + " }\n", + " initializer {\n", + " data_type: 1\n", + " name: \"/weight_quant/export_handler/Constant_2_output_0\"\n", + " raw_data: \"\\2227d=\"\n", + " }\n", + " initializer {\n", + " data_type: 1\n", + " name: \"/weight_quant/export_handler/Constant_3_output_0\"\n", + " raw_data: \"\\000\\000\\000\\000\"\n", + " }\n", + " input {\n", + " name: \"x.27\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 1\n", + " }\n", + " dim {\n", + " dim_value: 2\n", + " }\n", + " dim {\n", + " dim_value: 5\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " input {\n", + " name: \"bias\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 4\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " output {\n", + " name: \"8\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 1\n", + " }\n", + " dim {\n", + " dim_value: 4\n", + " }\n", + " dim {\n", + " dim_value: 3\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " value_info {\n", + " name: \"/weight_quant/export_handler/Quant_output_0\"\n", + " type {\n", + " tensor_type {\n", + " elem_type: 1\n", + " shape {\n", + " dim {\n", + " dim_value: 4\n", + " }\n", + " dim {\n", + " dim_value: 2\n", + " }\n", + " dim {\n", + " dim_value: 3\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + "}\n", + "opset_import {\n", + " domain: \"\"\n", + " version: 17\n", + "}\n", + "opset_import {\n", + " domain: \"onnx.brevitas\"\n", + " version: 1\n", + "}" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "torch.manual_seed(0)\n", "\n", @@ -2067,7 +2763,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": null, "metadata": { "tags": [ "skip-execution" @@ -2096,10 +2792,10 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 41, + "execution_count": 43, "metadata": {}, "output_type": "execute_result" } @@ -2123,7 +2819,18 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "RecursiveScriptModule(original_name=_JitTraceExportWrapper)" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from brevitas.quant import ShiftedUint8ActPerTensorFloat\n", "from brevitas.export import export_torch_qop\n", @@ -2142,21 +2849,13 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": null, "metadata": { "tags": [ "skip-execution" ] }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\users\\alessandro\\documenti\\brevitas_tvmcon\\src\\brevitas\\quant_tensor\\__init__.py:74: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", - " training = torch.tensor(training, dtype=torch.bool)\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -2179,10 +2878,10 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 42, + "execution_count": 45, "metadata": {}, "output_type": "execute_result" } @@ -2239,9 +2938,24 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 41, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/overrides.py:110: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'\n", + " torch.has_cuda,\n", + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/overrides.py:111: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'\n", + " torch.has_cudnn,\n", + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/overrides.py:117: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'\n", + " torch.has_mps,\n", + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/overrides.py:118: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'\n", + " torch.has_mkldnn,\n" + ] + } + ], "source": [ "from brevitas.graph.calibrate import bias_correction_mode\n", "from brevitas.graph.calibrate import calibration_mode\n", @@ -2280,7 +2994,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.15 | packaged by conda-forge | (default, Nov 22 2022, 08:42:03) [MSC v.1929 64 bit (AMD64)]" + "version": "3.11.5" }, "vscode": { "interpreter": { diff --git a/notebooks/ONNX_export_tutorial.ipynb b/notebooks/ONNX_export_tutorial.ipynb index 304161fce..e7a6659d2 100644 --- a/notebooks/ONNX_export_tutorial.ipynb +++ b/notebooks/ONNX_export_tutorial.ipynb @@ -22,9 +22,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mWARNING: Ignoring invalid distribution ~nnx-weekly (/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages)\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: netron in /home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages (7.2.9)\n", + "\u001b[33mWARNING: Ignoring invalid distribution ~nnx-weekly (/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages)\u001b[0m\u001b[33m\n", + "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n" + ] + } + ], "source": [ "%pip install netron" ] @@ -95,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 14, "metadata": { "collapsed": false, "pycharm": { @@ -116,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 15, "metadata": { "collapsed": false, "pycharm": { @@ -142,7 +153,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 16, "metadata": { "collapsed": false, "pycharm": { @@ -157,6 +168,7 @@ "name": "stdout", "output_type": "stream", "text": [ + "Stopping http://localhost:8082\n", "Serving 'quant_linear_qcdq.onnx' at http://localhost:8082\n" ] }, @@ -175,10 +187,10 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 3, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -219,7 +231,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 17, "metadata": { "collapsed": false, "pycharm": { @@ -248,7 +260,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 18, "metadata": { "collapsed": false, "pycharm": { @@ -263,6 +275,7 @@ "name": "stdout", "output_type": "stream", "text": [ + "Stopping http://localhost:8083\n", "Serving 'quant_model_qcdq.onnx' at http://localhost:8083\n" ] }, @@ -281,10 +294,10 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 5, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -334,7 +347,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 19, "metadata": { "collapsed": false, "pycharm": { @@ -365,7 +378,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 20, "metadata": { "collapsed": false, "pycharm": { @@ -398,10 +411,10 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 7, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -446,7 +459,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 21, "metadata": { "collapsed": false, "pycharm": { @@ -458,7 +471,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "c:\\users\\alessand\\documents\\brevitas\\src\\brevitas\\export\\onnx\\standard\\manager.py:23: UserWarning: ONNX opset version set to 13, override with opset_version=\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/export/onnx/standard/manager.py:26: UserWarning: ONNX opset version set to 13, override with opset_version=\n", " warnings.warn(f\"ONNX opset version set to {DEFAULT_OPSET}, override with {ka}=\")\n" ] }, @@ -467,7 +480,7 @@ "text/plain": [ "ir_version: 7\n", "producer_name: \"pytorch\"\n", - "producer_version: \"1.13.1\"\n", + "producer_version: \"2.1.0\"\n", "graph {\n", " node {\n", " output: \"/input_quant/export_handler/Constant_output_0\"\n", @@ -496,7 +509,7 @@ " }\n", " }\n", " node {\n", - " input: \"inp.1\"\n", + " input: \"out.1\"\n", " input: \"/input_quant/export_handler/Constant_output_0\"\n", " input: \"/input_quant/export_handler/Constant_1_output_0\"\n", " output: \"/input_quant/export_handler/QuantizeLinear_output_0\"\n", @@ -515,7 +528,7 @@ " dims: 3\n", " dims: 3\n", " data_type: 3\n", - " raw_data: \"\\374\\372\\376\\374\\005\\000\\375\\374\\004\\375\\373\\373\\375\\007\\376\\374\\377\\000\\000\\373\\373\\004\\005\\371\\003\\375\\004\\373\\004\\374\\000\\006\\002\\003\\003\\005\\004\\377\\005\\000\\373\\376\\375\\376\\002\\376\\004\\377\\003\\005\\375\\371\\006\\373\\003\\007\\377\\374\\005\\375\\375\\006\\375\\377\\374\\001\\005\\371\\006\\005\\007\\376\\376\\372\\376\\004\\001\\374\\002\\373\\373\\376\\002\\376\\375\\377\\001\\376\\006\\371\\002\\000\\004\\005\\005\\000\\004\\373\\004\\002\\003\\000\\374\\376\\005\\000\\004\\372\\004\\000\\373\\000\\006\\377\\002\\005\\004\\005\\374\\000\\007\\377\\374\\371\\373\\007\\004\\376\\372\\001\\005\\001\\372\\377\\003\\001\\375\\006\\372\\377\\006\\003\\006\\004\\001\\004\\372\\005\\006\\003\\376\\373\\374\\375\\376\\005\\000\\004\\377\\372\\373\\000\\007\\377\\373\\003\\373\\376\\374\\374\\377\\375\\377\\003\\372\\005\\004\\007\\003\\375\\377\\001\\007\\377\\373\\374\\000\\377\\376\\374\\373\\377\\373\\375\\003\\004\\004\\376\\004\\377\\375\\003\\003\\377\\004\\000\\005\\004\\000\\372\\005\\007\\003\\004\\377\\373\\003\\371\\373\\002\\377\\006\\006\\007\\377\\376\\375\\002\\006\\005\\004\\374\\002\\000\\373\\004\\002\\002\\374\\371\\372\\371\\375\\001\\004\\000\\006\\376\\377\\002\\000\\372\\001\\001\\375\\007\\376\\005\\001\\373\\003\\374\\005\\003\\007\\005\\372\\004\\006\\375\\005\\003\\001\\373\\376\\374\\002\\376\\377\\376\\000\\006\\001\\375\\376\\377\\374\\000\\005\\002\\005\\006\\371\\375\\005\\375\\376\\374\\004\\001\\003\\001\\372\\005\\007\\371\\005\\000\\372\\001\\001\\371\\007\\374\\372\\373\\373\\372\\376\\004\\000\\002\\375\\376\\000\\004\\003\\003\\375\\003\\001\\376\\006\\001\\000\\372\\374\\376\\373\\002\\002\\004\\372\\377\\374\\005\\000\\001\\005\\005\\374\\007\\003\\377\\377\\000\\007\\002\\377\\377\\377\\374\\001\\001\\376\\000\\377\\373\\001\\004\\376\\003\\000\\007\\005\\000\\374\\372\\376\\005\\003\\003\\004\\372\\375\\372\\377\\006\\376\\374\\007\\373\\002\\374\\003\\377\\374\\002\\007\\373\\004\\376\\004\\004\\003\\005\\373\\003\\005\\376\\001\\000\\002\\371\\376\\000\\374\\377\\372\\375\\005\\373\\002\\373\\373\\377\\004\\375\\006\\377\\005\\005\\002\\375\\375\\003\\376\\376\\006\\002\\371\\000\\002\\373\\000\\006\\002\\372\\372\\006\\374\\372\\004\\006\\004\\000\\003\\001\\377\\371\\376\\003\\003\\373\\005\\000\\001\\003\\004\\001\\005\\001\\004\\373\\373\\372\\002\\371\\375\\372\\004\\377\\005\\375\\376\\374\\375\\003\\372\\001\\373\\372\\376\\005\\003\\372\\004\\373\\004\\374\\374\\376\\376\\377\\371\\375\\004\\375\\377\\376\\007\\004\\372\\000\\007\\372\\006\\002\\006\\001\\006\\372\\004\\004\\003\\002\\375\\006\\374\\002\\001\\001\\000\\376\\376\\006\\373\\374\\002\\372\\005\\374\\004\\004\\001\\374\\004\\377\\373\\002\\376\\001\\377\\003\\377\\007\\004\\372\\371\\002\\375\\377\\373\\002\\376\\375\\377\\006\\001\\001\\000\\374\\001\\006\\004\\371\\377\\375\\374\\377\\376\\003\\372\\373\\002\\005\\374\\000\\002\\004\\372\\004\\372\\003\\006\\375\\003\\377\\376\\000\\377\\374\\006\\377\\374\\375\\377\\373\\376\\372\\375\\006\\004\\371\\372\\374\\375\\004\\002\\372\\376\\001\\001\\002\\373\\000\\003\\000\\371\\001\\003\\377\\376\\371\\376\\004\\000\\003\\376\\002\\006\\004\\372\\007\\005\\004\\376\\000\\007\\372\\003\\002\\005\\005\\004\\372\\002\\377\\006\\002\\371\\375\\375\\372\\376\\005\\003\\000\\002\\371\\005\\372\\373\\377\\371\\376\\005\\374\\377\\007\\003\\001\\376\\006\\376\\001\\374\\374\\001\\373\\006\\376\\376\\001\\372\\377\\003\\006\\372\\373\\003\\377\\376\\000\\377\\373\\004\\372\\371\\376\\002\\004\\004\\006\\001\\372\\001\\376\\005\\001\\000\\000\\007\\002\\375\\002\\375\\375\\006\\007\\375\\375\\002\\006\\371\\375\\002\\377\\002\\377\\000\\373\\001\\372\\372\\001\\377\\372\\001\\002\\000\\375\\373\\377\\372\\001\\371\\372\\007\\372\\001\\377\\372\\004\\376\\376\\374\\375\\373\\373\\005\\371\\375\\006\\005\\007\\374\\373\\005\\372\\000\\001\\374\\005\\000\\002\\373\\004\\001\\004\\006\\002\\003\\373\\376\\372\\374\\003\\375\\005\\000\\005\\373\\001\\375\\374\\002\\002\\000\\373\\374\\003\\005\\376\\003\\374\\374\\373\\374\\000\\004\\371\\375\\372\\003\\375\\005\\005\\006\\007\\371\\003\\372\\003\\375\\004\\374\\001\\376\\373\\000\\004\\003\\001\\003\\372\\377\\003\\004\\374\\000\\376\\002\\377\\001\\374\\376\\002\\002\\001\\005\\375\\373\\001\\372\\000\\007\\004\\007\\006\\006\\000\\004\\004\\006\\000\\377\\375\\000\\002\\374\\376\\374\\006\\373\\377\\000\\374\\006\\373\\005\\001\\001\\006\\005\\373\\373\\001\\003\\371\\006\\372\\003\\005\\372\\003\\005\\006\\005\\006\\001\\001\\377\\372\\001\\003\\372\\005\\002\\376\\377\\373\\005\\376\\375\\373\\005\\004\\007\\001\\000\\002\\001\\374\\004\\003\\377\\004\\372\\373\\373\\007\\375\\002\\002\\377\\373\\007\\001\\004\\374\\007\\376\\000\\003\\376\\006\\371\\377\\003\\376\\003\\004\\375\\006\\376\\371\\373\\373\\004\\000\\005\\377\\372\\372\\377\\004\\002\\001\\000\\005\\372\\004\\377\\376\\375\\001\\005\\375\\375\\000\\003\\006\\374\\004\\377\\004\\006\\000\\374\\003\\000\\005\\376\\372\\371\\371\\000\\374\\372\\006\\004\\006\\376\\377\\001\\377\\376\\373\\374\\000\\003\\004\\372\\375\\000\\006\\002\\374\\377\\004\\372\\371\\373\\001\\006\\377\\003\\007\\377\\373\\000\\371\\002\\376\\003\\002\\377\\006\\006\\006\\371\\006\\373\\377\\006\\000\\374\\375\\376\\001\\376\\007\\003\\007\\376\\004\\001\\005\\003\\375\\372\\003\\004\\376\\374\\005\\372\\372\\000\\006\\377\\003\\000\\002\\001\\003\\375\\000\\004\\375\\372\\000\\001\\000\\000\\002\\000\\004\\005\\377\\005\\007\\376\\372\\001\\374\\006\\002\\376\\002\\005\\374\\372\\000\\375\\372\\372\\000\\001\\000\\377\\007\\376\\000\\374\\375\\000\\373\\003\\001\\006\\003\\376\\007\\374\\376\\374\\005\\371\\372\\001\\374\\374\\002\\375\\004\\001\\002\\002\\376\\003\\373\\000\\375\\375\\005\\373\\002\\376\\371\\006\\004\\001\\001\\371\\376\\005\\377\\375\\005\\003\\374\\375\\002\\373\\376\\001\\002\\001\\007\\002\\004\\376\\375\\377\\376\\004\\373\\000\\001\\375\\377\\372\\376\\002\\001\\375\\006\\005\\006\\004\\376\\004\\004\\001\\001\\377\\004\\006\\003\\001\\005\\006\\001\\377\\000\\000\\000\\372\\004\\375\\004\\377\\377\\006\\377\\373\\003\\375\\373\\004\\005\\377\\006\\376\\374\\374\\371\\376\\003\\376\\374\\001\\373\\001\\375\\001\\376\\376\\000\\376\\371\\376\\377\\372\\373\\374\\374\\375\\376\\003\\376\\002\\372\\375\\375\\007\\377\\373\\377\\006\\376\\377\\373\\002\\001\\000\\005\\004\\006\\376\\001\\373\\372\\371\\001\\371\\001\\373\\374\\001\\375\\373\\003\\375\\373\\005\\373\\004\\377\\002\\000\\002\\006\\001\\373\\375\\005\\376\\004\\000\\376\\003\\007\\000\\377\\003\\004\\005\\376\\004\\003\\004\\006\\006\\006\\371\\002\\374\\375\\003\\375\\000\\375\\377\\004\\003\\374\\373\\004\\005\\375\\003\\376\\001\\001\\374\\003\\377\\004\\006\\003\\377\\001\\003\\377\\377\\371\\000\\374\\003\\373\\374\\006\\372\\372\\006\\004\\375\\375\\373\\004\\005\\001\\373\\371\\377\\376\\004\\005\\373\\374\\005\\000\\376\\001\\002\\003\\006\\006\\374\\375\\374\\377\\001\\373\\003\\004\\372\\004\\375\\001\\371\\004\\002\\001\\376\\377\\005\\000\\376\\376\\372\\005\\000\\376\\004\\371\\000\\377\\377\\377\\373\\377\\001\\004\\002\\374\\373\\000\\374\\377\\373\\373\\374\\005\\006\\374\\003\\373\\000\\006\\001\\003\\371\\373\\006\\374\\005\\005\\006\\371\\002\\005\\373\\000\\377\\377\\003\\005\\003\\004\\376\\372\\000\\005\\004\\371\\372\\376\\371\\005\\375\\000\\001\\001\\000\\006\\005\\006\\002\\002\\000\\003\\006\\374\\005\\000\\373\\372\\376\\002\\372\\006\\003\\375\\373\\373\\375\\002\\004\\001\\007\\373\\377\\004\\005\\004\\375\\005\\376\\376\\004\\003\\000\\004\\376\\006\\001\\376\\003\\376\\007\\006\\002\\376\\001\\376\\006\\371\\006\\375\\375\\004\\003\\006\\377\\374\\004\\003\\375\\372\\374\\375\\006\\377\\000\\004\\373\\002\\006\\373\\377\\374\\372\\000\\000\\376\\006\\373\\372\\004\\001\\006\\003\\377\\006\\371\\006\\006\\004\\004\\005\\371\\376\\001\\003\\372\\005\\001\\002\\373\\001\\372\\375\\004\\372\\006\\373\\375\\001\\003\\375\\377\\003\\372\\374\\374\\373\\006\\005\\373\\002\\000\\004\\376\\377\\004\\374\\006\\374\\006\\373\\004\\375\\373\\006\\376\\006\\002\\002\\377\\372\\372\\005\\004\\375\\000\\002\\374\\002\\376\\007\\373\\376\\371\\377\\005\\376\\006\\002\\006\\376\\004\\372\\000\\005\\002\\002\\003\\006\\004\\377\\007\\374\\372\\372\\002\\375\\377\\001\\375\\005\\374\\377\\003\\007\\002\\005\\006\\006\\000\\001\\004\\000\\376\\371\\001\\000\\005\\004\\375\\372\\375\\004\\007\\371\\374\\002\\005\\000\\002\\002\\004\\004\\007\\005\\006\\373\\006\\004\\002\\005\\004\\376\\375\\000\\372\\004\\377\\003\\374\\003\\376\\006\\376\\006\\006\\005\\002\\006\\007\\002\\372\\372\\377\\373\\004\\373\\375\\004\\004\\003\\006\\002\\000\\002\\376\\000\\000\\005\\006\\005\\372\\003\\372\\006\\001\\007\\372\\002\\372\\004\\001\\005\\002\\005\\374\\372\\372\\002\\372\\001\\377\\002\\006\\005\\000\\005\\372\\375\\007\\377\\375\\004\\005\\003\\372\\004\\005\\376\\373\\001\\372\\003\\371\\371\\374\\005\\002\\005\\374\\377\\004\\002\\376\\004\\373\\377\\377\\377\\001\\005\\372\\003\\373\\375\\006\\374\\007\\376\\372\\006\\005\\371\\377\\005\\001\\003\\005\\002\\006\\003\\001\\377\\374\\004\\376\\374\\375\\376\\001\\001\\001\\004\\007\\007\\000\\005\\001\\376\\003\\376\\000\\000\\001\\001\\375\\371\\006\\002\\001\\373\\000\\377\\007\\004\\002\\374\\000\\001\\377\\003\\374\\003\\007\\373\\371\\373\\001\\005\\373\\372\\005\\373\\375\\005\\006\\372\\000\\005\\007\\003\\003\\377\\005\\006\\004\\374\\372\\375\\003\\004\\000\\005\\376\\374\\374\\375\\375\\377\\372\\000\\004\\002\\005\\002\\000\\374\\376\\373\\373\\376\\002\\374\\000\\376\\373\\000\\371\\373\\372\\006\\000\\376\\002\\375\\376\\005\\372\\004\\376\\375\\005\\006\\006\\004\\003\\002\\002\\002\\002\\375\\006\\377\\000\\004\\375\\004\\007\\004\\005\\372\\374\\004\\377\\003\\377\\000\\375\\372\\374\\372\\000\\004\\007\\002\\007\\372\\376\\004\\371\\375\\001\\001\\007\\003\\000\\004\\373\\001\\001\\376\\002\\377\\377\\006\\002\\003\\373\\373\\004\\372\\372\\376\\372\\002\\002\\002\\373\\001\\375\\374\\000\\004\\003\\376\\003\\376\\002\\373\\374\\003\\372\\371\\001\\375\\004\\371\\374\\004\\005\\002\\374\\371\\001\\373\\377\\374\\006\\373\\006\\000\\005\\005\\006\\006\\002\\375\\002\\001\\001\\005\\375\\000\\372\\371\\003\\004\\375\\376\\003\\377\\374\\005\\007\\007\\377\\374\\375\\374\\376\\373\\003\\002\\002\\374\\377\\373\\004\\375\\372\\374\\003\\374\\005\\376\\002\\373\\376\\006\\005\\374\\002\\371\\005\\004\\001\\373\\000\\377\\374\\003\\000\\001\\001\\003\\372\\005\\001\\371\\371\\000\\375\\001\\375\\372\\374\\003\\373\\376\\001\\371\\006\\005\\004\\377\\004\\376\\377\\377\\003\\373\\001\\372\\376\\006\\372\\372\\005\\374\\001\\374\\004\\001\\004\\375\\002\\002\\373\\006\\000\\001\\002\\377\\371\\005\\005\\374\\374\\006\\003\\001\\002\\001\\374\\377\\372\\000\\377\\374\\373\\371\\007\\003\\375\\373\\374\\373\\374\\005\\004\\005\\006\\002\\374\\000\\372\\001\\376\\002\\373\\371\\372\\374\\374\\377\\005\\375\\371\\002\\374\\374\\005\\377\\007\\004\\376\\007\\373\\372\\007\\007\\377\\004\\002\\002\\007\\377\\375\\002\\005\\006\\003\\002\\006\\376\\003\\004\\003\\000\\371\\002\\002\\374\\006\\373\\005\\003\\003\\002\\003\\376\\002\\004\\377\\377\\371\\007\\001\\373\\376\\003\\002\\007\\376\\002\\005\\004\\374\\003\\377\\374\\003\\007\\004\\377\\002\\001\\003\\005\\373\\377\\374\\002\\377\\004\\000\\000\\005\\007\\002\\003\\376\\371\\377\\006\\372\\372\\002\\372\\371\\375\\000\\376\\005\\372\\000\\373\\372\\007\\002\\001\\372\\374\\375\\005\\005\\004\\001\\002\\002\\006\\372\\001\\007\\373\\375\\000\\372\\005\\003\\000\\375\\377\\001\\003\\006\\000\\376\\374\\002\\375\\375\\003\\001\\007\\376\\377\\003\\000\\005\\376\\374\\005\\373\\004\\377\\000\\375\\002\\005\\001\\001\\000\\001\\375\\374\\001\\006\\372\\375\\376\\372\\371\\001\\372\\005\\004\\376\\373\\006\\005\\375\\006\\377\\001\\001\\000\\006\\000\\006\\007\\003\\372\\004\\375\\373\\372\\372\\000\\374\\001\\006\\007\\376\\374\\371\\373\\372\\375\\003\\377\\372\\377\\005\\002\\006\\372\\006\\004\\005\\000\\376\\007\\003\\372\\004\\377\\006\\001\\373\\375\\374\\373\\373\\004\\004\\375\\373\\005\\376\\000\\001\\375\\371\\372\\005\\375\\000\\002\\372\\003\\004\\372\\003\\374\\005\\002\\374\\377\\001\\005\\376\\377\\374\\376\\005\\376\\372\\003\\373\\372\\006\\372\\377\\373\\006\\372\\004\\006\\373\\005\\375\\375\\007\\374\\005\\002\\374\\374\\002\\002\\377\\375\\376\\372\\005\\375\\371\\003\\005\\003\\372\\377\\375\\372\\002\\005\\000\\006\\372\\005\\371\\376\\000\\001\\377\\004\\004\\006\\000\\377\\007\\002\\006\\000\\371\\375\\374\\374\\001\\373\\371\\002\\376\\002\\000\\374\\006\\001\\374\\006\\005\\001\\003\\376\\003\\374\\003\\374\\002\\007\\373\\002\\004\\007\\005\\374\\376\\372\\372\\001\\371\\002\\005\\373\\376\\006\\375\\372\\376\\004\\003\\001\\004\\376\\002\\373\\006\\006\\371\\372\\003\\004\\006\\375\\004\\007\\371\\000\\000\\001\\000\\374\\001\\006\\002\\006\\002\\000\\002\\373\\372\\372\\000\\372\\005\\006\\004\\000\\376\\372\\373\\006\\007\\373\\006\\373\\377\\003\\375\\373\\001\\377\\001\\002\\376\\003\\373\\002\\376\\007\\371\\371\\374\\006\\377\\001\\002\\005\\001\\376\\375\\000\\377\\371\\005\\372\\002\\377\\375\\375\\002\\375\\376\\003\\003\\373\\373\\005\\004\\004\\373\\000\\000\\007\\003\\372\\375\\004\\003\\376\\377\\373\\376\\004\\372\\004\\377\\376\\007\\002\\005\\003\\001\\006\\006\\002\\005\\373\\000\\004\\000\\004\\374\\372\\376\\007\\002\\003\\006\\002\\000\\372\\001\\374\\005\\376\\006\\007\\373\\001\\375\\004\\377\\374\\375\\377\\001\\377\\003\\375\\005\\000\\003\\376\\375\\003\\377\\372\\002\\006\\003\\007\\005\\374\\003\\006\\003\\000\\375\\000\\001\\000\\001\\002\\374\\377\\372\\004\\372\\377\\377\\003\\377\\007\\006\\371\\003\\005\\004\\007\\006\\371\\006\\001\\375\\001\\001\\376\\002\\374\\006\\375\\375\\376\\377\\002\\002\\007\\373\\373\\374\\373\\377\\001\\006\\375\\375\\001\\375\\373\\375\\373\\372\\376\\003\\371\\006\\376\\376\\375\\007\\377\\374\\376\\377\\006\\377\\001\\371\\377\\007\\375\\371\\005\\002\\373\\003\\005\\002\\371\\375\\003\\003\\003\\374\\000\\377\\375\\003\\002\\006\\006\\375\\006\\002\\000\\374\\373\\374\\002\\003\\373\\002\\375\\377\\004\\006\\003\\006\\000\\377\\372\\375\\375\\002\\002\\003\\006\\003\\003\\377\\373\\003\\003\\003\\003\\377\\004\\004\\372\\377\\000\\374\\375\\005\\004\\005\\003\\002\\375\\376\\001\\376\\003\\374\\002\\007\\002\\376\\377\\007\\006\\376\\372\\374\\004\\371\\004\\006\\006\\374\\374\\377\\374\\003\\006\\371\\377\\007\\372\\375\\006\\374\\374\\005\\372\\006\\372\\371\\001\\000\\375\\372\\374\\373\\374\\374\\374\\005\\004\\002\\375\\004\\007\\004\\006\\002\\005\\005\\372\\375\\000\\004\\000\\377\\004\\004\\001\\374\\377\\006\\003\\377\\374\\000\\376\\372\\376\\373\\377\\006\\377\\376\\002\\005\\005\\372\\004\\000\\001\\004\\005\\373\\005\\003\\371\\374\\373\\000\\375\\002\\375\\006\\003\\001\\004\\377\\374\\372\\005\\006\\005\\005\\005\\005\\007\\372\\006\\004\\006\\372\\372\\002\\373\\371\\001\\004\\006\\374\\005\\373\\004\\006\\001\\005\\006\\377\\006\\373\\001\\373\\373\\376\\375\\007\\372\\374\\372\\377\\004\\006\\004\\375\\374\\000\\007\\005\\000\\002\\377\\002\\372\\002\\001\\377\\372\\006\\002\\001\\000\\376\\375\\374\\003\\376\\371\\005\\001\\000\\002\\372\\373\\375\\004\\376\\371\\374\\376\\000\\004\\004\\376\\375\\007\\374\\377\\375\\377\\001\\003\\005\\372\\002\\376\\003\\003\\375\\001\\004\\001\\001\\000\\002\\004\\375\\375\\372\\003\\003\\372\\002\\375\\372\\377\\373\\000\\002\\371\\005\\003\\001\\001\\376\\372\\374\\001\\001\\376\\000\\001\\376\\001\\376\\005\\002\\374\\002\\004\\004\\000\\374\\007\\000\\000\\006\\003\\371\\376\\371\\006\\005\\006\\007\\002\\371\\373\\005\\372\\375\\006\\003\\373\\005\\375\\375\\373\\002\\000\\375\\005\\001\\372\\377\\377\\373\\375\\375\\374\\000\\376\\372\\000\\374\\001\\001\\372\\375\\373\\004\\374\\000\\006\\375\\004\\001\\006\\000\\373\\001\\375\\003\\372\\000\\373\\376\\003\\374\\005\\007\\377\\373\\007\\006\\002\\371\\373\\377\\004\\373\\001\\374\\000\\001\\004\\001\\005\\375\\372\\002\\376\\377\\371\\374\\375\\371\\373\\005\\376\\374\\001\\377\\376\\371\\375\\371\\000\\375\\373\\377\\006\\002\\003\\005\\372\\003\\004\\005\\005\\004\\000\\376\\372\\371\\006\\000\\377\\373\\003\\376\\005\\007\\006\\372\\004\\007\\374\\375\\376\\374\\000\\001\\001\\375\\003\\371\\001\\006\\374\\376\\006\\377\\000\\001\\375\\006\\004\\372\\371\\001\\377\\377\\377\\376\\006\\375\\372\\000\\371\\376\\002\\374\\372\\006\\372\\002\\006\\005\\001\\376\\004\\374\\002\\376\\000\\004\\376\\375\\000\\376\\004\\000\\006\\372\\005\\007\\006\\002\\004\\373\\373\\006\\003\\007\\001\\375\\007\\007\\372\\004\\005\\376\\005\\376\\007\\002\\376\\004\\373\\373\\376\\004\\372\\375\\373\\374\\001\\000\\375\\004\\375\\375\\377\\004\\001\\377\\002\\376\\004\\377\\001\\001\\374\\376\\374\\377\\377\\001\\000\\000\\377\\373\\374\\002\\006\\001\\375\\376\\000\\000\\374\\006\\004\\004\\004\\375\\001\\376\\001\\002\\373\\006\\006\\376\\002\\005\\005\\374\\373\\377\\376\\004\\005\\374\\000\\376\\002\\375\\376\\004\\373\\001\\377\\377\\002\\377\\373\\372\\371\\003\\003\\372\\006\\000\\002\\003\\005\\375\\371\\375\\004\\376\\374\\007\\375\\371\\002\\374\\000\\375\\005\\006\\374\\373\\004\\371\\000\\007\\376\\001\\375\\377\\372\\372\\373\\005\\005\\001\\372\\377\\371\\377\\375\"\n", + " raw_data: \"\\377\\006\\003\\006\\005\\002\\373\\006\\000\\374\\004\\377\\374\\005\\006\\374\\004\\376\\002\\005\\005\\006\\374\\371\\373\\002\\371\\374\\001\\002\\377\\002\\006\\376\\006\\373\\004\\005\\003\\007\\376\\372\\007\\374\\377\\005\\002\\375\\001\\374\\372\\375\\373\\003\\002\\372\\000\\377\\003\\006\\002\\377\\004\\373\\374\\371\\000\\373\\376\\372\\002\\006\\005\\005\\374\\003\\001\\006\\006\\001\\003\\375\\000\\006\\376\\000\\004\\004\\373\\372\\002\\002\\000\\002\\007\\001\\374\\376\\376\\377\\375\\377\\375\\006\\372\\371\\005\\004\\005\\372\\377\\004\\377\\373\\004\\373\\004\\007\\377\\000\\003\\373\\005\\003\\004\\376\\372\\371\\376\\377\\005\\005\\006\\001\\005\\376\\002\\006\\000\\000\\371\\005\\001\\003\\003\\003\\372\\003\\372\\002\\377\\374\\007\\000\\377\\005\\004\\006\\006\\374\\000\\006\\375\\005\\376\\374\\000\\000\\004\\002\\374\\006\\374\\005\\004\\006\\376\\001\\372\\376\\006\\374\\371\\001\\005\\006\\375\\372\\373\\377\\377\\376\\000\\005\\377\\005\\006\\374\\003\\376\\003\\000\\376\\377\\374\\004\\375\\001\\376\\374\\374\\373\\000\\371\\377\\006\\002\\377\\375\\001\\003\\006\\372\\001\\002\\000\\373\\374\\000\\005\\003\\004\\003\\377\\004\\373\\376\\005\\376\\377\\375\\376\\003\\005\\005\\004\\004\\006\\004\\005\\375\\376\\373\\372\\006\\002\\375\\003\\001\\005\\003\\006\\372\\004\\375\\371\\377\\003\\374\\006\\376\\376\\373\\374\\377\\006\\001\\377\\372\\003\\001\\372\\006\\000\\377\\372\\372\\001\\000\\377\\371\\002\\373\\006\\373\\372\\002\\375\\374\\373\\376\\377\\377\\375\\374\\006\\003\\372\\003\\000\\007\\005\\371\\001\\372\\006\\373\\376\\006\\376\\000\\372\\007\\000\\000\\006\\006\\373\\006\\372\\005\\374\\376\\376\\373\\372\\006\\004\\372\\003\\377\\000\\005\\002\\374\\375\\373\\000\\000\\006\\004\\376\\001\\001\\003\\374\\003\\004\\004\\377\\007\\377\\376\\006\\005\\000\\371\\006\\002\\003\\377\\374\\377\\372\\377\\005\\377\\003\\000\\006\\375\\004\\372\\376\\005\\372\\001\\005\\006\\005\\002\\004\\001\\001\\007\\375\\004\\002\\007\\374\\006\\005\\000\\002\\372\\001\\377\\002\\372\\005\\007\\373\\377\\374\\001\\004\\005\\372\\001\\001\\377\\002\\007\\003\\373\\376\\377\\005\\006\\374\\002\\375\\002\\373\\001\\004\\377\\373\\003\\377\\006\\006\\376\\004\\373\\003\\375\\376\\374\\375\\376\\007\\000\\377\\371\\004\\373\\000\\374\\002\\377\\375\\005\\006\\372\\003\\002\\376\\007\\002\\003\\001\\000\\006\\004\\000\\001\\000\\376\\377\\000\\004\\006\\001\\000\\373\\006\\374\\000\\375\\377\\004\\006\\373\\006\\003\\006\\373\\373\\376\\007\\375\\377\\004\\374\\003\\001\\376\\374\\372\\001\\375\\004\\003\\002\\003\\373\\001\\003\\374\\001\\372\\003\\003\\004\\372\\007\\005\\004\\373\\372\\002\\377\\007\\003\\001\\001\\373\\375\\373\\373\\372\\375\\376\\375\\005\\376\\373\\374\\374\\000\\002\\373\\006\\003\\000\\005\\005\\000\\007\\004\\377\\373\\372\\004\\375\\375\\002\\007\\376\\000\\006\\376\\373\\372\\001\\373\\000\\377\\006\\002\\006\\375\\376\\002\\004\\006\\373\\001\\002\\372\\005\\376\\002\\001\\373\\377\\004\\001\\374\\373\\002\\004\\002\\377\\004\\372\\377\\373\\004\\375\\001\\372\\006\\376\\007\\371\\006\\003\\006\\373\\000\\377\\375\\005\\374\\005\\374\\005\\373\\372\\000\\372\\371\\002\\372\\375\\372\\377\\005\\371\\004\\375\\000\\006\\002\\006\\377\\375\\006\\006\\004\\005\\374\\372\\372\\372\\004\\377\\005\\377\\372\\375\\374\\371\\376\\000\\004\\005\\005\\003\\373\\371\\000\\375\\001\\376\\372\\006\\376\\374\\005\\005\\372\\372\\005\\003\\000\\001\\376\\372\\377\\004\\376\\001\\377\\375\\005\\005\\371\\377\\371\\377\\374\\007\\004\\007\\000\\377\\000\\376\\001\\376\\004\\375\\006\\003\\001\\005\\373\\004\\376\\005\\003\\377\\377\\001\\004\\375\\375\\376\\377\\373\\376\\000\\000\\005\\374\\372\\375\\000\\376\\000\\002\\376\\005\\004\\377\\004\\372\\006\\375\\377\\372\\376\\004\\374\\003\\004\\006\\375\\376\\003\\371\\374\\374\\000\\000\\371\\006\\002\\003\\376\\374\\374\\001\\375\\004\\003\\372\\007\\004\\005\\006\\004\\007\\372\\376\\371\\000\\007\\005\\005\\005\\001\\374\\374\\377\\006\\003\\000\\001\\004\\372\\375\\005\\003\\002\\374\\004\\005\\371\\373\\373\\377\\374\\372\\376\\002\\372\\377\\004\\000\\003\\002\\001\\004\\377\\374\\002\\004\\374\\377\\376\\006\\005\\002\\004\\005\\003\\004\\373\\377\\004\\373\\004\\003\\004\\001\\375\\005\\004\\001\\376\\005\\005\\000\\375\\374\\001\\373\\006\\000\\376\\005\\377\\001\\002\\374\\007\\005\\002\\000\\371\\375\\007\\000\\000\\005\\372\\002\\373\\004\\000\\374\\375\\006\\371\\007\\004\\007\\374\\001\\000\\006\\376\\375\\371\\375\\372\\002\\003\\004\\375\\002\\005\\373\\372\\377\\004\\373\\001\\004\\003\\007\\002\\373\\372\\373\\376\\374\\004\\003\\002\\000\\376\\375\\006\\376\\373\\006\\371\\006\\005\\005\\006\\000\\000\\377\\001\\372\\005\\377\\005\\376\\001\\373\\376\\001\\375\\371\\375\\372\\373\\002\\374\\002\\000\\006\\377\\003\\004\\371\\001\\375\\002\\004\\373\\001\\000\\371\\375\\372\\004\\003\\372\\002\\002\\374\\002\\001\\004\\371\\006\\007\\003\\373\\004\\003\\376\\005\\005\\003\\373\\374\\003\\004\\376\\000\\007\\002\\005\\376\\006\\004\\001\\375\\004\\377\\375\\006\\375\\001\\005\\376\\373\\377\\003\\374\\000\\371\\000\\006\\000\\007\\374\\376\\377\\004\\374\\004\\374\\374\\374\\004\\376\\001\\002\\376\\000\\000\\001\\002\\373\\377\\004\\007\\376\\373\\374\\375\\006\\376\\004\\007\\001\\001\\373\\372\\003\\001\\002\\003\\375\\000\\373\\004\\376\\000\\373\\377\\376\\376\\377\\001\\005\\000\\006\\372\\002\\006\\377\\376\\004\\003\\002\\376\\004\\006\\001\\002\\374\\000\\374\\005\\377\\375\\372\\006\\003\\373\\376\\004\\372\\002\\003\\377\\006\\376\\375\\377\\375\\001\\374\\004\\005\\000\\001\\372\\376\\376\\006\\377\\374\\001\\006\\375\\374\\373\\372\\002\\001\\001\\373\\006\\374\\375\\001\\377\\004\\001\\007\\371\\001\\000\\376\\006\\376\\375\\003\\374\\371\\376\\377\\376\\000\\001\\373\\006\\376\\375\\373\\002\\374\\375\\375\\001\\000\\003\\007\\004\\373\\003\\377\\372\\000\\376\\376\\371\\006\\373\\001\\377\\002\\003\\001\\377\\004\\007\\002\\375\\376\\004\\002\\005\\002\\001\\376\\001\\006\\002\\002\\375\\372\\003\\377\\000\\375\\004\\375\\377\\373\\374\\376\\001\\002\\373\\377\\377\\003\\005\\004\\373\\006\\001\\001\\003\\000\\005\\003\\377\\376\\377\\002\\004\\373\\003\\006\\004\\372\\004\\003\\002\\371\\376\\377\\377\\371\\373\\371\\377\\374\\006\\373\\005\\007\\372\\373\\377\\003\\003\\374\\377\\007\\004\\376\\000\\000\\003\\372\\007\\001\\372\\004\\000\\001\\003\\375\\005\\007\\001\\376\\001\\377\\371\\377\\004\\007\\374\\373\\373\\006\\007\\005\\001\\376\\376\\005\\373\\001\\005\\006\\004\\005\\372\\373\\002\\004\\006\\377\\375\\005\\376\\000\\373\\005\\006\\003\\002\\000\\372\\001\\001\\000\\000\\007\\372\\001\\374\\006\\003\\005\\376\\003\\002\\377\\373\\372\\375\\371\\377\\001\\374\\377\\001\\371\\006\\001\\376\\374\\006\\375\\001\\000\\007\\000\\375\\376\\376\\377\\001\\374\\000\\371\\373\\374\\003\\006\\006\\371\\001\\001\\376\\006\\377\\001\\375\\002\\376\\000\\377\\006\\000\\004\\372\\000\\000\\375\\000\\003\\002\\004\\372\\000\\001\\372\\002\\004\\003\\374\\373\\005\\006\\376\\007\\000\\000\\373\\003\\000\\007\\377\\376\\372\\007\\376\\003\\001\\374\\001\\006\\006\\001\\372\\002\\371\\006\\005\\374\\005\\005\\377\\373\\373\\003\\006\\002\\376\\371\\007\\374\\006\\372\\377\\375\\002\\000\\006\\006\\377\\373\\001\\372\\375\\004\\377\\372\\372\\001\\375\\003\\000\\373\\000\\373\\001\\004\\371\\377\\377\\372\\005\\372\\004\\005\\007\\002\\372\\001\\371\\002\\003\\006\\376\\372\\006\\373\\375\\376\\000\\373\\376\\007\\377\\000\\375\\000\\371\\006\\373\\007\\002\\004\\376\\372\\004\\002\\003\\000\\373\\005\\376\\377\\001\\004\\372\\377\\000\\003\\000\\373\\004\\005\\375\\006\\374\\004\\376\\003\\375\\374\\372\\001\\003\\374\\000\\002\\001\\004\\002\\374\\003\\001\\006\\374\\372\\003\\006\\375\\377\\374\\006\\001\\375\\005\\375\\002\\373\\007\\004\\373\\003\\372\\004\\374\\004\\373\\007\\000\\007\\377\\376\\374\\371\\004\\001\\375\\373\\005\\007\\377\\371\\372\\005\\372\\004\\377\\374\\372\\001\\002\\000\\001\\375\\003\\374\\375\\375\\376\\003\\006\\372\\006\\002\\006\\377\\375\\377\\005\\006\\005\\374\\377\\372\\373\\004\\003\\376\\006\\373\\006\\374\\002\\006\\005\\006\\371\\005\\004\\372\\001\\004\\371\\003\\005\\004\\374\\003\\373\\376\\374\\005\\003\\000\\006\\373\\006\\376\\001\\376\\006\\372\\371\\005\\372\\375\\374\\002\\003\\375\\372\\000\\001\\001\\006\\000\\002\\374\\373\\377\\373\\001\\375\\000\\000\\001\\003\\006\\374\\002\\375\\000\\375\\002\\005\\001\\004\\000\\377\\376\\005\\371\\377\\000\\002\\376\\372\\004\\376\\372\\372\\003\\004\\003\\375\\001\\376\\002\\003\\371\\372\\377\\375\\005\\004\\376\\005\\004\\004\\376\\002\\372\\001\\373\\002\\000\\006\\376\\375\\007\\001\\000\\002\\374\\000\\377\\005\\372\\003\\000\\000\\000\\005\\006\\002\\001\\004\\000\\376\\375\\006\\004\\374\\376\\006\\002\\007\\006\\377\\006\\006\\376\\000\\002\\004\\374\\005\\373\\004\\375\\371\\376\\006\\000\\373\\376\\376\\003\\007\\371\\377\\005\\376\\000\\005\\001\\375\\371\\001\\376\\373\\006\\005\\000\\376\\005\\001\\001\\376\\002\\002\\001\\375\\375\\372\\373\\004\\372\\000\\000\\006\\005\\375\\003\\005\\006\\372\\003\\001\\006\\377\\003\\003\\002\\001\\377\\004\\374\\006\\003\\374\\004\\373\\374\\006\\376\\005\\003\\374\\377\\376\\003\\006\\000\\375\\007\\003\\375\\371\\373\\374\\006\\004\\004\\373\\373\\374\\005\\001\\006\\005\\373\\000\\372\\371\\000\\004\\002\\375\\374\\006\\375\\373\\005\\376\\004\\007\\002\\002\\374\\375\\004\\371\\007\\007\\006\\003\\377\\004\\006\\007\\372\\006\\371\\371\\374\\002\\001\\371\\007\\377\\377\\005\\002\\001\\004\\002\\377\\377\\000\\373\\000\\004\\005\\004\\372\\377\\376\\373\\007\\007\\000\\000\\000\\001\\006\\006\\375\\002\\006\\372\\005\\000\\005\\003\\371\\371\\006\\001\\375\\002\\001\\377\\006\\376\\372\\373\\375\\001\\002\\004\\376\\001\\374\\373\\005\\374\\376\\006\\002\\377\\006\\373\\007\\002\\004\\374\\373\\374\\002\\004\\372\\006\\005\\375\\000\\371\\003\\376\\376\\002\\374\\001\\002\\004\\001\\003\\006\\002\\002\\371\\376\\006\\000\\371\\372\\007\\002\\005\\002\\372\\006\\000\\373\\000\\375\\001\\002\\004\\007\\374\\376\\000\\372\\003\\375\\377\\001\\375\\003\\372\\372\\000\\001\\002\\376\\373\\376\\004\\372\\004\\372\\377\\375\\001\\004\\375\\002\\371\\376\\006\\005\\374\\001\\372\\006\\000\\005\\000\\373\\377\\001\\007\\375\\374\\002\\007\\373\\373\\000\\376\\004\\006\\000\\372\\003\\002\\376\\007\\002\\002\\001\\374\\000\\373\\374\\005\\007\\375\\003\\004\\006\\371\\002\\006\\372\\372\\376\\371\\002\\002\\000\\006\\373\\374\\003\\374\\372\\004\\372\\000\\002\\374\\374\\007\\000\\001\\002\\004\\376\\001\\004\\375\\377\\003\\376\\004\\001\\376\\374\\377\\372\\374\\000\\003\\002\\371\\372\\377\\373\\005\\371\\373\\003\\372\\373\\004\\371\\006\\002\\006\\376\\002\\377\\375\\376\\371\\005\\006\\006\\005\\374\\004\\372\\006\\372\\004\\002\\006\\001\\007\\007\\002\\005\\001\\005\\005\\004\\004\\007\\372\\373\\004\\374\\373\\004\\374\\376\\376\\375\\372\\005\\002\\375\\005\\007\\375\\007\\006\\376\\374\\003\\377\\377\\000\\373\\372\\003\\371\\006\\000\\373\\374\\001\\003\\005\\372\\376\\374\\002\\003\\373\\377\\006\\376\\374\\004\\005\\375\\002\\004\\371\\004\\371\\377\\006\\001\\375\\005\\004\\374\\006\\376\\004\\000\\001\\372\\007\\374\\373\\373\\005\\372\\004\\001\\006\\374\\374\\001\\003\\000\\375\\371\\000\\004\\005\\003\\376\\377\\004\\004\\007\\000\\004\\007\\004\\376\\376\\003\\376\\001\\373\\377\\373\\002\\374\\003\\374\\373\\374\\376\\373\\003\\006\\375\\002\\373\\375\\374\\376\\373\\001\\375\\001\\375\\371\\001\\003\\002\\006\\374\\004\\371\\373\\004\\374\\377\\003\\374\\371\\000\\006\\003\\377\\374\\006\\372\\373\\376\\375\\002\\003\\001\\000\\005\\004\\374\\006\\377\\006\\371\\002\\377\\376\\000\\374\\376\\005\\373\\376\\004\\001\\377\\006\\001\\372\\001\\002\\375\\373\\000\\374\\007\\376\\006\\375\\375\\377\\004\\004\\002\\374\\000\\376\\002\\006\\376\\006\\003\\000\\376\\371\\005\\004\\373\\004\\005\\376\\000\\375\\375\\003\\371\\375\\002\\371\\007\\374\\377\\005\\006\\372\\000\\375\\373\\001\\375\\005\\374\\374\\001\\374\\004\\374\\002\\371\\000\\006\\005\\376\\373\\373\\004\\374\\004\\005\\004\\374\\005\\372\\373\\003\\377\\000\\005\\376\\375\\373\\376\\007\\003\\000\\004\\007\\374\\376\\000\\377\\375\\377\\000\\376\\373\\007\\005\\374\\004\\007\\006\\004\\001\\373\\377\\376\\372\\005\\007\\007\\004\\377\\374\\373\\004\\376\\373\\003\\004\\372\\004\\376\\373\\000\\000\\002\\374\\002\\000\\001\\003\\376\\007\\375\\374\\001\\003\\000\\374\\375\\001\\374\\002\\375\\003\\375\\001\\376\\007\\374\\003\\003\\000\\374\\001\\373\\000\\000\\003\\374\\375\\005\\377\\001\\374\\000\\375\\372\\372\\000\\376\\372\\001\\373\\007\\372\\373\\375\\375\\373\\001\\003\\372\\005\\376\\374\\374\\375\\002\\372\\376\\376\\374\\004\\374\\005\\004\\007\\375\\372\\004\\001\\374\\002\\001\\372\\006\\373\\003\\003\\000\\375\\373\\374\\004\\373\\374\\372\\372\\005\\002\\005\\003\\377\\376\\002\\005\\006\\374\\374\\003\\003\\377\\000\\371\\006\\007\\003\\001\\005\\005\\377\\373\\372\\005\\377\\005\\002\\373\\006\\001\\007\\006\\005\\373\\006\\003\\002\\000\\372\\002\\005\\373\\377\\001\\375\\372\\003\\374\\375\\004\\372\\372\\371\\376\\377\\374\\372\\004\\376\\001\\375\\000\\374\\000\\375\\376\\377\\372\\371\\003\\373\\005\\371\\001\\372\\373\\003\\374\\003\\376\\375\\003\\004\\372\\374\\002\\372\\006\\377\\373\\000\\373\\002\\375\\374\\005\\004\\003\\006\\377\\372\\375\\005\\376\\374\\001\\004\\371\\373\\377\\001\\372\\003\\372\\002\\372\\001\\001\\007\\000\\004\\002\\000\\375\\372\\371\\001\\001\\375\\371\\005\\000\\001\\377\\002\\376\\002\\000\\376\\373\\371\\373\\376\\000\\001\\375\\373\\372\\005\\005\\006\\001\\001\\373\\003\\006\\373\\006\\005\\003\\374\\006\\375\\007\\005\\374\\007\\007\\371\\376\\375\\374\\001\\376\\372\\372\\373\\000\\004\\376\\005\\372\\376\\000\\004\\375\\001\\000\\376\\376\\376\\004\\375\\002\\374\\371\\373\\371\\006\\000\\006\\005\\374\\005\\377\\373\\001\\375\\375\\000\\376\\373\\377\\372\\377\\375\\006\\005\\002\\001\\377\\374\\004\\001\\002\\006\\004\\375\\374\\000\\003\\003\\000\\000\\003\\373\\006\\374\\376\\007\\376\\003\\003\\373\\376\\003\\003\\000\\002\\004\\000\\375\\006\\373\\003\\001\\377\\372\\006\\005\\003\\376\\374\\002\\001\\001\\006\\005\\376\\377\\374\\003\\006\\372\\004\\002\\374\\004\\374\\374\\005\\004\\375\\377\\377\\001\\377\\005\\006\\372\\377\\377\\377\\002\\000\\377\\006\\373\\376\\376\\001\\376\\007\\004\\004\\371\\373\\371\\374\\376\\000\\000\\372\\003\\377\\004\\005\\004\\001\\376\\374\\003\\376\\007\\373\\006\\002\\006\\004\\004\\371\\372\\373\\006\\376\\005\\373\\373\\376\\000\\006\\006\\007\\377\\005\\001\\376\\374\\373\\005\\001\\373\\001\\004\\003\\000\\375\\005\\377\\003\\373\\377\\376\\375\\376\\000\\000\\000\\375\\375\\372\\000\\000\\003\\373\\000\\373\\001\\007\\006\\001\\374\\007\\002\\006\\004\\372\\377\\000\\375\\375\\000\\377\\002\\005\\001\\376\\371\\001\\374\\374\\373\\374\\372\\376\\005\\372\\003\\373\\007\\374\\005\\003\\000\\006\\005\\372\\004\\372\\003\\005\\373\\376\\003\\374\\377\\003\\373\\003\\005\\374\\001\\374\\375\\002\\374\\001\\000\\002\\374\\003\\007\\374\\373\\004\\000\\004\\003\\002\\000\\377\\371\\002\\377\\003\\006\\001\\000\\371\\377\\377\\376\\002\\000\\004\\003\\007\\005\\005\\375\\376\\375\\005\\375\\376\\007\\002\\004\\001\\003\\001\\001\\004\\375\\003\\000\\374\\004\\376\\001\\006\\376\\003\\374\\374\\373\\375\\003\\374\\000\\376\\001\\004\\374\\004\\003\\000\\005\\003\\374\\006\\375\\373\\376\\374\\373\\002\\004\\004\\006\\374\\004\\001\\001\\002\\373\\005\\004\\373\\375\\377\\377\\002\\005\\001\\375\\375\\006\\001\\373\\003\\377\\004\\003\\003\\006\\001\\376\\375\\375\\377\\373\\004\\373\\007\\375\\001\\376\\374\\002\\004\\003\\377\\376\\374\\005\\007\\000\\006\\006\\377\\007\\374\\376\\375\\371\\375\\003\\005\\005\\373\\373\\376\\002\\005\\375\\000\\375\\371\\000\\004\\006\\373\\372\\005\\372\\377\\375\\372\\006\\005\\375\\372\\377\\374\\375\\006\\002\\377\\374\\374\\006\\374\\004\\375\\373\\005\\006\\377\\000\\001\\377\\003\\375\\006\\376\\004\\002\\372\\372\\377\\005\\371\\376\\374\\002\\377\\373\\001\\006\\006\\372\\002\\004\\001\\005\\001\\002\\003\\372\\001\\377\\004\\003\\005\\003\\006\\372\\002\\376\\000\\000\\000\\376\\377\\373\\004\\002\\371\\373\\003\\374\\372\\005\\005\\373\\376\\003\\375\\372\\373\\375\\374\\006\\377\\004\\005\\004\\377\\000\\005\\375\\004\\005\\377\\003\\004\\002\\006\\374\\377\\005\\003\\376\\372\\003\\373\\374\\377\\004\\372\\002\\000\\002\\002\\375\\006\\001\\377\\374\\001\\374\\375\\372\\372\\004\\004\\001\\377\\003\\002\\375\\006\\007\\374\\376\\375\\372\\003\\376\\372\\373\\374\\004\\374\\002\\376\\003\\003\\372\\000\\375\\002\\007\\005\\375\\373\\000\\373\\373\\002\\372\\376\\005\\000\\004\\006\\375\\374\\006\\372\\377\\006\\000\\005\\004\\002\\375\\376\\000\\005\\003\\374\\375\\001\\372\\373\\005\\002\\376\\374\\007\\000\\003\\002\\005\\006\\374\\374\\006\\371\\375\\002\\005\\005\\372\\003\\372\\001\\000\\376\\377\\372\\372\\000\\004\\002\\002\\373\\376\\374\\373\\003\\373\\376\\002\\007\\007\\004\\003\\376\\373\\002\\003\\372\\001\\001\\001\\375\\376\\377\\372\\004\\002\\000\\003\\371\\002\\003\\377\\375\\001\\372\\372\\003\\005\\376\\007\\374\\374\\000\\374\\376\\004\\374\\004\\373\\004\\375\\000\\376\\001\\377\\004\\007\\373\\003\\371\\001\\375\\007\\002\\000\\001\\003\\006\\004\"\n", " }\n", " type: TENSOR\n", " }\n", @@ -528,7 +541,7 @@ " name: \"value\"\n", " t {\n", " data_type: 1\n", - " raw_data: \"\\263-\\341<\"\n", + " raw_data: \"\\2556\\341<\"\n", " }\n", " type: TENSOR\n", " }\n", @@ -542,7 +555,7 @@ " t {\n", " dims: 128\n", " data_type: 6\n", - " raw_data: \"\\271\\377\\377\\377\\032\\003\\000\\0009\\001\\000\\000\\302\\002\\000\\000;\\375\\377\\377\\031\\000\\000\\000\\024\\003\\000\\000d\\003\\000\\000\\327\\374\\377\\377\\363\\377\\377\\377u\\003\\000\\000\\374\\000\\000\\000t\\000\\000\\000\\321\\002\\000\\000\\236\\377\\377\\377\\241\\377\\377\\377\\237\\375\\377\\377\\010\\000\\000\\000\\350\\002\\000\\000}\\376\\377\\377\\267\\377\\377\\377\\374\\000\\000\\000\\355\\001\\000\\000N\\375\\377\\377\\\\\\002\\000\\000\\346\\002\\000\\000\\317\\000\\000\\000\\207\\001\\000\\000?\\000\\000\\000\\302\\002\\000\\000Y\\377\\377\\377\\326\\376\\377\\377\\\\\\003\\000\\000\\374\\376\\377\\377\\334\\000\\000\\000\\200\\001\\000\\000\\362\\377\\377\\377+\\000\\000\\000\\304\\375\\377\\377u\\000\\000\\000\\340\\000\\000\\000\\275\\001\\000\\000\\324\\377\\377\\377\\332\\000\\000\\000\\026\\001\\000\\000\\333\\001\\000\\000\\371\\375\\377\\377\\363\\000\\000\\000|\\002\\000\\000\\335\\376\\377\\377\\226\\375\\377\\377\\335\\002\\000\\0002\\001\\000\\000F\\377\\377\\377\\006\\003\\000\\000\\310\\375\\377\\377\\344\\377\\377\\377\\177\\376\\377\\377>\\001\\000\\000\\033\\002\\000\\000I\\003\\000\\000\\006\\376\\377\\377\\315\\375\\377\\377\\033\\003\\000\\000\\236\\000\\000\\000@\\376\\377\\377\\031\\002\\000\\000\\321\\002\\000\\000;\\000\\000\\000\\035\\377\\377\\377\\354\\377\\377\\377Z\\001\\000\\000N\\375\\377\\377I\\001\\000\\000\\030\\001\\000\\000w\\377\\377\\377\\303\\002\\000\\000\\022\\000\\000\\000\\377\\001\\000\\000!\\000\\000\\000\\035\\001\\000\\000\\003\\375\\377\\377^\\377\\377\\377\\336\\374\\377\\377p\\377\\377\\377\\351\\002\\000\\000X\\376\\377\\377\\247\\000\\000\\000H\\376\\377\\377}\\000\\000\\000\\225\\374\\377\\3776\\001\\000\\000\\301\\001\\000\\000\\210\\001\\000\\000\\374\\376\\377\\377\\307\\377\\377\\377\\320\\374\\377\\377\\267\\377\\377\\377F\\375\\377\\377\\352\\377\\377\\377=\\377\\377\\3770\\376\\377\\377#\\000\\000\\000\\313\\376\\377\\377\\334\\000\\000\\000\\261\\001\\000\\000\\363\\001\\000\\000\\037\\001\\000\\000\\220\\377\\377\\377\\202\\000\\000\\000d\\377\\377\\377\\013\\002\\000\\000\\266\\002\\000\\000\\347\\374\\377\\377+\\001\\000\\000\\301\\376\\377\\377\\341\\377\\377\\377O\\003\\000\\000\\037\\375\\377\\377\\244\\375\\377\\377\\352\\000\\000\\000\\302\\001\\000\\000I\\002\\000\\000~\\377\\377\\377*\\376\\377\\377\\333\\000\\000\\000\\214\\000\\000\\000\\014\\002\\000\\000\"\n", + " raw_data: \"\\016\\003\\000\\000\\240\\375\\377\\377\\344\\002\\000\\000\\341\\002\\000\\000\\207\\000\\000\\000C\\377\\377\\377\\255\\375\\377\\377,\\376\\377\\377\\\"\\001\\000\\000\\237\\001\\000\\000\\'\\003\\000\\000\\220\\377\\377\\377{\\003\\000\\000\\252\\002\\000\\000A\\003\\000\\000\\233\\002\\000\\000\\375\\377\\377\\377\\302\\001\\000\\000\\365\\374\\377\\377\\025\\003\\000\\000w\\003\\000\\000\\231\\375\\377\\377\\030\\377\\377\\377K\\000\\000\\000h\\002\\000\\0002\\001\\000\\000X\\003\\000\\000\\241\\001\\000\\000W\\000\\000\\000\\010\\002\\000\\000R\\002\\000\\000v\\003\\000\\000\\353\\001\\000\\000J\\001\\000\\000\\312\\377\\377\\377\\007\\002\\000\\000\\345\\376\\377\\377\\316\\001\\000\\000\\352\\000\\000\\000\\357\\375\\377\\377\\004\\001\\000\\000\\353\\002\\000\\000\\342\\376\\377\\377#\\003\\000\\000\\252\\001\\000\\000\\354\\377\\377\\377Y\\003\\000\\000x\\000\\000\\000\\251\\377\\377\\377f\\000\\000\\000}\\003\\000\\000\\317\\374\\377\\377\\300\\376\\377\\377\\230\\000\\000\\000c\\003\\000\\000\\204\\377\\377\\377n\\376\\377\\377(\\375\\377\\377\\314\\001\\000\\000\\304\\000\\000\\000\\357\\374\\377\\377\\241\\376\\377\\377\\217\\375\\377\\377\\r\\001\\000\\000;\\001\\000\\000\\240\\377\\377\\377Q\\377\\377\\377U\\375\\377\\377\\'\\377\\377\\377h\\002\\000\\000f\\002\\000\\000\\307\\002\\000\\000\\364\\001\\000\\000\\303\\000\\000\\000W\\000\\000\\000\\001\\375\\377\\377!\\002\\000\\000\\210\\000\\000\\000:\\377\\377\\377\\242\\000\\000\\000o\\377\\377\\377\\327\\001\\000\\000\\263\\377\\377\\377X\\003\\000\\000\\303\\000\\000\\0003\\000\\000\\000\\337\\375\\377\\377=\\375\\377\\377{\\000\\000\\000\\336\\375\\377\\377I\\375\\377\\377\\036\\377\\377\\377\\016\\002\\000\\000\\017\\003\\000\\000\\240\\374\\377\\377f\\001\\000\\000\\003\\375\\377\\377\\020\\375\\377\\377\\224\\000\\000\\000S\\375\\377\\377\\266\\001\\000\\000\\337\\002\\000\\000\\356\\375\\377\\377\\027\\000\\000\\000\\340\\000\\000\\000p\\377\\377\\377\\371\\000\\000\\000\\253\\000\\000\\000\\322\\001\\000\\0008\\001\\000\\000\\346\\377\\377\\377\\271\\374\\377\\377\\003\\376\\377\\377\\032\\001\\000\\000\\207\\374\\377\\377\\265\\374\\377\\377;\\003\\000\\000\\353\\374\\377\\377\\275\\002\\000\\000:\\002\\000\\000}\\377\\377\\377\\000\\376\\377\\377W\\376\\377\\377\\235\\000\\000\\000\\333\\376\\377\\377x\\375\\377\\377+\\376\\377\\377\\025\\003\\000\\000\"\n", " }\n", " type: TENSOR\n", " }\n", @@ -600,9 +613,9 @@ " name: \"/linear/export_handler/DequantizeLinear\"\n", " op_type: \"DequantizeLinear\"\n", " }\n", - " name: \"torch_jit\"\n", + " name: \"main_graph\"\n", " input {\n", - " name: \"inp.1\"\n", + " name: \"out.1\"\n", " type {\n", " tensor_type {\n", " elem_type: 1\n", @@ -652,7 +665,7 @@ "}" ] }, - "execution_count": 8, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -691,7 +704,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 22, "metadata": { "collapsed": false, "pycharm": { @@ -724,10 +737,10 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 9, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -792,14 +805,22 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 23, "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-02-02 09:51:08.764823876 [W:onnxruntime:, graph.cc:1283 Graph] Initializer linear.bias appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.\n" + ] + } + ], "source": [ "import onnxruntime as ort\n", "\n", @@ -862,7 +883,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 24, "metadata": { "collapsed": false, "pycharm": { @@ -910,7 +931,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 25, "metadata": { "collapsed": false, "pycharm": { @@ -922,7 +943,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "c:\\users\\alessand\\documents\\brevitas\\src\\brevitas\\export\\onnx\\standard\\manager.py:23: UserWarning: ONNX opset version set to 13, override with opset_version=\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/export/onnx/standard/manager.py:26: UserWarning: ONNX opset version set to 13, override with opset_version=\n", " warnings.warn(f\"ONNX opset version set to {DEFAULT_OPSET}, override with {ka}=\")\n" ] } @@ -997,7 +1018,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.15" + "version": "3.11.5" }, "orig_nbformat": 4, "vscode": { diff --git a/notebooks/quantized_recurrent.ipynb b/notebooks/quantized_recurrent.ipynb index 766e82745..5bb95a465 100644 --- a/notebooks/quantized_recurrent.ipynb +++ b/notebooks/quantized_recurrent.ipynb @@ -38,12 +38,14 @@ " bias: bool = True,\n", " batch_first: bool = False,\n", " bidirectional: bool = False,\n", - " weight_quant = Int8WeightPerTensorFloat,\n", - " bias_quant = Int32Bias,\n", - " io_quant = Int8ActPerTensorFloat,\n", - " gate_acc_quant = Int8ActPerTensorFloat,\n", - " shared_input_hidden_weights = False,\n", + " weight_quant=Int8WeightPerTensorFloat,\n", + " bias_quant=Int32Bias,\n", + " io_quant=Int8ActPerTensorFloat,\n", + " gate_acc_quant=Int8ActPerTensorFloat,\n", + " shared_input_hidden_weights=False,\n", " return_quant_tensor: bool = False,\n", + " dtype: Optional[torch.dtype] = None,\n", + " device: Optional[torch.device] = None,\n", " **kwargs):\n", " super(QuantRNN, self).__init__(\n", " layer_impl=_QuantRNNLayer,\n", @@ -60,6 +62,8 @@ " gate_acc_quant=gate_acc_quant,\n", " shared_input_hidden_weights=shared_input_hidden_weights,\n", " return_quant_tensor=return_quant_tensor,\n", + " dtype=dtype,\n", + " device=device,\n", " **kwargs)\n", "\n", "```" @@ -107,7 +111,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "c:\\users\\alessand\\documents\\brevitas\\src\\brevitas\\nn\\mixin\\base.py:112: UserWarning: Keyword arguments are being passed but they not being used.\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:78: UserWarning: Keyword arguments are being passed but they not being used.\n", " warn('Keyword arguments are being passed but they not being used.')\n" ] } @@ -278,46 +282,46 @@ "Input-hidden weight bit-width: 4.0\n", "Hidden-hidden weight bit-width: 4.0\n", "I/O quant bit-width: 6.0\n", - "Input-hidden weight scale: tensor([[0.0316],\n", - " [0.0317],\n", - " [0.0319],\n", - " [0.0318],\n", - " [0.0314],\n", + "Input-hidden weight scale: tensor([[0.0297],\n", + " [0.0311],\n", " [0.0298],\n", + " [0.0295],\n", + " [0.0316],\n", + " [0.0311],\n", + " [0.0318],\n", + " [0.0309],\n", " [0.0317],\n", - " [0.0285],\n", - " [0.0306],\n", - " [0.0312],\n", + " [0.0309],\n", + " [0.0316],\n", + " [0.0319],\n", + " [0.0319],\n", " [0.0318],\n", " [0.0315],\n", - " [0.0298],\n", - " [0.0314],\n", - " [0.0293],\n", " [0.0310],\n", - " [0.0306],\n", - " [0.0310],\n", - " [0.0309],\n", - " [0.0317]], grad_fn=)\n", - "Hidden-hidden weight scale: tensor([[0.0316],\n", - " [0.0317],\n", + " [0.0319],\n", " [0.0319],\n", " [0.0318],\n", - " [0.0314],\n", + " [0.0312]], grad_fn=)\n", + "Hidden-hidden weight scale: tensor([[0.0297],\n", + " [0.0311],\n", " [0.0298],\n", + " [0.0295],\n", + " [0.0316],\n", + " [0.0311],\n", + " [0.0318],\n", + " [0.0309],\n", " [0.0317],\n", - " [0.0285],\n", - " [0.0306],\n", - " [0.0312],\n", + " [0.0309],\n", + " [0.0316],\n", + " [0.0319],\n", + " [0.0319],\n", " [0.0318],\n", " [0.0315],\n", - " [0.0298],\n", - " [0.0314],\n", - " [0.0293],\n", - " [0.0310],\n", - " [0.0306],\n", " [0.0310],\n", - " [0.0309],\n", - " [0.0317]], grad_fn=)\n" + " [0.0319],\n", + " [0.0319],\n", + " [0.0318],\n", + " [0.0312]], grad_fn=)\n" ] } ], @@ -387,52 +391,54 @@ "name": "stderr", "output_type": "stream", "text": [ - "c:\\users\\alessand\\documents\\brevitas\\src\\brevitas\\nn\\mixin\\base.py:343: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at C:\\cb\\pytorch_1000000000000\\work\\torch\\csrc\\utils\\python_arg_parser.cpp:354.)\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/quant_tensor/__init__.py:84: UserWarning: Empty QuantTensor are deprecated and will be removed in a future version\n", + " warnings.warn(\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:320: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/torch/csrc/utils/python_arg_parser.cpp:368.)\n", " return torch.cat(outputs, dim=seq_dim)\n" ] }, { "data": { "text/plain": [ - "(QuantTensor(value=tensor([[[-0.4458, -0.1651, -0.7045, -0.5889, -0.2532, -0.0330, -0.1651,\n", - " 0.1706, 0.1376, 0.4348, 0.5834, -0.3577, -0.2807, 0.1046,\n", - " 0.2532, 0.2807, 0.2532, -0.4293, 0.1376, -0.1486],\n", - " [-0.1569, 0.3530, -0.6995, -0.0458, -0.5295, -0.3007, -0.7257,\n", - " 0.2877, -0.1308, 0.6603, 0.0196, -0.8237, 0.0065, -0.4380,\n", - " -0.2615, 0.3138, -0.0850, 0.0065, 0.0458, -0.1961],\n", - " [ 0.1929, -0.5981, -0.2508, -0.2251, -0.5917, 0.2251, 0.0257,\n", - " 0.2508, -0.3023, 0.2830, 0.3344, -0.4309, -0.0836, 0.2701,\n", - " 0.3666, -0.1351, 0.1736, -0.0257, 0.1286, -0.6174],\n", - " [ 0.4682, -0.1804, 0.2780, 0.4974, 0.4389, -0.0585, -0.6242,\n", - " -0.0098, 0.2341, 0.3511, -0.2926, -0.4925, 0.1414, -0.4633,\n", - " -0.0683, 0.2633, 0.3804, 0.3024, 0.1951, 0.1707],\n", - " [-0.0852, 0.0965, -0.4656, -0.3180, -0.3464, -0.2782, -0.1931,\n", - " -0.6360, -0.3180, -0.3293, 0.7211, 0.4316, 0.4145, -0.3066,\n", - " -0.5224, -0.3066, -0.5849, -0.7211, 0.3293, 0.1420]],\n", + "(QuantTensor(value=tensor([[[-0.0062, -0.2872, 0.7931, 0.4309, 0.5495, -0.4558, 0.2373,\n", + " 0.6807, 0.4621, 0.6120, -0.1124, 0.3872, 0.3060, 0.7681,\n", + " -0.3684, 0.0437, -0.7369, -0.3247, 0.7743, 0.3372],\n", + " [ 0.5450, 0.2962, -0.3969, 0.3555, -0.5628, 0.2429, -0.4976,\n", + " 0.1777, -0.1244, 0.0296, -0.2607, 0.0948, 0.5036, -0.3673,\n", + " 0.5213, -0.2962, 0.7524, 0.0770, -0.0948, -0.0948],\n", + " [ 0.2691, -0.6624, -0.5434, 0.4968, -0.6624, 0.0983, 0.1345,\n", + " 0.1242, -0.0517, -0.3726, 0.3053, 0.1604, 0.3208, 0.0983,\n", + " 0.3105, 0.4243, 0.2794, 0.1604, 0.1035, -0.0724],\n", + " [ 0.1284, -0.3337, -0.5263, -0.0449, -0.5263, 0.3081, -0.1733,\n", + " 0.5648, 0.4942, -0.1412, 0.1733, 0.3337, 0.6225, 0.3401,\n", + " 0.5070, -0.1412, 0.0642, -0.3722, 0.2888, 0.1155],\n", + " [ 0.0579, -0.0058, -0.4054, -0.1564, -0.5560, -0.3301, 0.3533,\n", + " 0.0058, -0.1622, -0.3765, 0.1216, 0.0695, -0.4054, 0.0927,\n", + " 0.6139, -0.1390, 0.7066, 0.1274, 0.1622, -0.2896]],\n", " \n", - " [[ 0.5669, 0.2367, -0.3027, -0.3137, -0.3632, -0.1651, -0.5999,\n", - " 0.2036, 0.4293, 0.2201, -0.2862, -0.3908, -0.2091, -0.2532,\n", - " -0.2532, -0.5834, -0.2697, 0.0055, 0.2532, 0.1761],\n", - " [ 0.1242, 0.4184, -0.6472, -0.0196, -0.4707, -0.5034, -0.8368,\n", - " 0.3530, 0.1504, 0.0458, -0.0654, -0.7714, -0.1961, -0.4903,\n", - " -0.6015, -0.3596, -0.2484, -0.4380, -0.0458, 0.2942],\n", - " [ 0.3409, 0.8168, -0.7396, 0.2958, 0.2508, -0.1286, -0.1286,\n", - " 0.7782, -0.1994, 0.7846, -0.3087, -0.3666, 0.1029, 0.1479,\n", - " -0.3216, -0.1479, -0.2315, 0.4566, 0.5209, -0.3344],\n", - " [-0.0878, 0.0390, -0.1707, -0.1365, -0.2243, -0.2390, -0.3706,\n", - " 0.1609, -0.5511, -0.4096, 0.5121, -0.5901, 0.2633, -0.3609,\n", - " -0.5511, 0.3755, -0.4925, -0.0293, -0.0780, -0.2829],\n", - " [ 0.0965, -0.1987, 0.0057, 0.1306, 0.3861, 0.2839, -0.3861,\n", - " 0.5962, -0.1987, 0.3180, -0.1647, -0.3066, -0.0227, 0.4372,\n", - " 0.0852, 0.3748, 0.0852, -0.0057, -0.1703, -0.0738]]],\n", - " grad_fn=), scale=tensor(0.0058, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)),\n", - " QuantTensor(value=tensor([[[-0.0852, 0.0965, -0.4656, -0.3180, -0.3464, -0.2782, -0.1931,\n", - " -0.6360, -0.3180, -0.3293, 0.7211, 0.4316, 0.4145, -0.3066,\n", - " -0.5224, -0.3066, -0.5849, -0.7211, 0.3293, 0.1420],\n", - " [ 0.0965, -0.1987, 0.0057, 0.1306, 0.3861, 0.2839, -0.3861,\n", - " 0.5962, -0.1987, 0.3180, -0.1647, -0.3066, -0.0227, 0.4372,\n", - " 0.0852, 0.3748, 0.0852, -0.0057, -0.1703, -0.0738]]],\n", - " grad_fn=), scale=tensor(0.0057, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)))" + " [[ 0.1374, 0.5745, 0.0624, -0.2373, 0.3060, 0.3310, -0.5183,\n", + " 0.1186, 0.1124, 0.2997, 0.0375, 0.6369, -0.5308, 0.6307,\n", + " -0.5683, 0.7556, 0.2997, -0.4933, 0.3934, -0.4871],\n", + " [ 0.1066, -0.1244, -0.1718, 0.4266, 0.5569, 0.0178, 0.1185,\n", + " -0.3910, 0.2133, 0.0178, -0.1066, -0.2903, 0.1837, -0.2547,\n", + " -0.2903, 0.0770, 0.3495, 0.2547, 0.2311, -0.6161],\n", + " [-0.0880, -0.1966, 0.3001, -0.0569, 0.4140, -0.1552, -0.1345,\n", + " 0.4554, 0.5175, 0.1242, -0.2898, 0.1966, -0.0414, 0.3985,\n", + " -0.1708, -0.0621, -0.1708, 0.0828, 0.2225, 0.0517],\n", + " [ 0.2118, 0.5648, -0.2824, -0.0449, 0.5840, 0.3209, -0.5648,\n", + " 0.3530, 0.4043, -0.4942, -0.3786, 0.0257, 0.5327, -0.1990,\n", + " -0.1348, -0.8215, 0.3016, 0.5327, 0.5648, -0.1155],\n", + " [-0.0290, -0.1738, 0.0695, 0.3765, 0.1738, 0.0579, -0.4054,\n", + " -0.2664, 0.4923, 0.2143, -0.4170, 0.4112, 0.5502, 0.7066,\n", + " -0.6024, 0.7356, 0.0348, 0.1043, -0.1911, -0.4518]]],\n", + " grad_fn=), scale=tensor(0.0059, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)),\n", + " QuantTensor(value=tensor([[[ 0.0579, -0.0058, -0.4054, -0.1564, -0.5560, -0.3301, 0.3533,\n", + " 0.0058, -0.1622, -0.3765, 0.1216, 0.0695, -0.4054, 0.0927,\n", + " 0.6139, -0.1390, 0.7066, 0.1274, 0.1622, -0.2896],\n", + " [-0.0290, -0.1738, 0.0695, 0.3765, 0.1738, 0.0579, -0.4054,\n", + " -0.2664, 0.4923, 0.2143, -0.4170, 0.4112, 0.5502, 0.7066,\n", + " -0.6024, 0.7356, 0.0348, 0.1043, -0.1911, -0.4518]]],\n", + " grad_fn=), scale=tensor(0.0058, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)))" ] }, "execution_count": 10, @@ -461,48 +467,56 @@ "execution_count": 11, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/torch/_tensor.py:1362: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392026823/work/c10/core/TensorImpl.h:1900.)\n", + " return super().rename(names)\n" + ] + }, { "data": { "text/plain": [ - "(QuantTensor(value=tensor([[[ 0.1760, 0.2670, -0.1214, -0.3702, 0.3884, 0.4127, 0.0243,\n", - " 0.0425, -0.2246, -0.0910, -0.2670, 0.4734, 0.0971, -0.3824,\n", - " 0.1396, 0.6858, 0.0061, 0.3702, 0.1275, 0.5037],\n", - " [ 0.2831, 0.0566, -0.2831, -0.2661, -0.0793, 0.3511, -0.4926,\n", - " 0.0510, -0.6455, 0.7191, -0.1812, -0.6172, 0.1529, 0.4077,\n", - " -0.7078, -0.0453, -0.0963, 0.4926, -0.4983, -0.4077],\n", - " [ 0.0000, -0.3977, 0.0947, 0.1894, -0.3725, -0.2589, -0.3914,\n", - " 0.3409, -0.0063, 0.2652, -0.5177, -0.4230, -0.0821, -0.0631,\n", - " 0.0505, -0.0189, 0.0253, -0.1578, -0.4988, 0.5556],\n", - " [ 0.4809, 0.8144, -0.6925, 0.4360, 0.0256, -0.4360, -0.5130,\n", - " 0.2501, -0.1347, 0.7631, -0.5386, -0.2437, 0.4296, -0.1988,\n", - " -0.7246, -0.1154, -0.2437, 0.3655, 0.0641, 0.3142],\n", - " [ 0.0706, -0.0192, -0.7185, -0.8211, -0.5709, 0.1155, 0.4683,\n", - " 0.3400, -0.3015, 0.3528, 0.3143, -0.1155, -0.3143, -0.0257,\n", - " 0.1411, -0.2309, 0.5132, 0.3721, 0.5196, -0.5453]],\n", + "(QuantTensor(value=tensor([[[ 0.2111, 0.1267, 0.0060, 0.6153, -0.7721, -0.3740, -0.5188,\n", + " 0.6273, 0.4162, 0.2051, 0.2292, 0.7239, 0.6032, 0.2533,\n", + " 0.5067, 0.6635, 0.1206, -0.5730, 0.0483, 0.3318],\n", + " [ 0.5742, 0.0194, -0.3807, -0.0710, -0.6000, 0.1807, 0.1355,\n", + " 0.4129, 0.3807, 0.3936, -0.0903, 0.1549, 0.1032, 0.0645,\n", + " 0.4775, -0.0645, 0.1161, -0.0065, 0.0194, -0.1097],\n", + " [ 0.0453, -0.4533, 0.1036, -0.0194, -0.2979, 0.3432, 0.0777,\n", + " 0.6346, -0.0842, 0.3302, 0.4727, 0.4856, -0.4144, 0.7382,\n", + " -0.0453, 0.5439, 0.2266, -0.4792, 0.4403, -0.1036],\n", + " [ 0.3198, 0.2741, -0.6395, 0.0971, -0.6052, -0.5196, 0.1770,\n", + " -0.5025, -0.1256, 0.2056, 0.2684, -0.6395, -0.0285, -0.7309,\n", + " 0.7194, -0.7194, 0.1542, -0.3426, -0.6509, 0.0343],\n", + " [ 0.0000, -0.4004, 0.3151, -0.0263, -0.5842, -0.1641, -0.3939,\n", + " 0.0263, -0.2429, 0.6499, -0.5186, 0.1247, -0.2101, 0.8337,\n", + " -0.1444, 0.6762, -0.1641, -0.5317, -0.1707, -0.0197]],\n", " \n", - " [[ 0.4066, -0.7768, 0.6008, 0.0546, 0.0182, 0.1821, 0.0971,\n", - " -0.3763, 0.3520, -0.5037, -0.0061, 0.2246, -0.0486, 0.2124,\n", - " 0.3641, -0.6433, 0.4248, 0.0789, 0.1275, -0.1214],\n", - " [ 0.2321, 0.1982, -0.1302, 0.1529, -0.0736, -0.3567, -0.4360,\n", - " -0.0283, 0.4869, 0.5379, -0.6964, -0.0340, -0.2944, -0.1529,\n", - " -0.2152, -0.4643, 0.3454, 0.3284, -0.3341, 0.5945],\n", - " [-0.2020, 0.0379, -0.8081, -0.7260, -0.0821, 0.0631, 0.4988,\n", - " 0.0694, 0.0253, 0.5430, 0.8018, 0.2273, -0.3472, -0.0505,\n", - " 0.4924, -0.4735, 0.5745, -0.5619, 0.6313, -0.1768],\n", - " [ 0.2501, -0.4360, 0.6541, 0.0385, 0.5835, -0.3078, -0.0449,\n", - " 0.3270, 0.7951, -0.3591, -0.4809, -0.2757, -0.3591, -0.7567,\n", - " 0.5194, 0.2757, 0.7438, 0.7695, 0.5451, 0.4296],\n", - " [ 0.2630, -0.4747, 0.1347, -0.0641, -0.2245, -0.3336, -0.4490,\n", - " -0.4619, -0.1796, -0.5517, 0.3913, 0.0257, -0.2053, -0.2823,\n", - " -0.6992, -0.6607, 0.1989, -0.6928, -0.5581, 0.5966]]],\n", + " [[ 0.2111, -0.2111, -0.3197, -0.0241, -0.5067, -0.0241, -0.2895,\n", + " 0.1749, -0.4283, 0.0000, -0.3680, 0.5308, -0.1267, 0.5248,\n", + " 0.1206, 0.2654, 0.6394, -0.1327, 0.2292, -0.3800],\n", + " [ 0.6775, -0.3355, -0.1807, 0.2774, -0.8259, -0.2000, -0.0065,\n", + " 0.5678, 0.4000, 0.2258, 0.4387, 0.2710, 0.5355, 0.1290,\n", + " 0.6710, -0.0645, -0.2710, -0.3613, 0.6388, 0.5226],\n", + " [-0.0065, -0.0777, -0.6475, -0.1684, -0.3820, 0.3885, 0.0065,\n", + " 0.1943, -0.3238, -0.2525, -0.1230, -0.0453, -0.0777, 0.3432,\n", + " 0.4921, -0.1101, 0.8224, 0.2396, 0.1554, -0.3885],\n", + " [-0.0514, -0.4111, -0.4625, -0.1713, -0.3369, 0.2512, -0.2969,\n", + " -0.4111, -0.2341, 0.3597, -0.1998, 0.0000, 0.2741, 0.7137,\n", + " -0.1256, 0.1370, -0.0742, -0.5938, -0.5424, -0.4168],\n", + " [ 0.3479, 0.5974, -0.3939, 0.1444, -0.6762, 0.1969, -0.6499,\n", + " 0.4136, 0.5383, -0.3085, 0.4070, 0.4070, 0.6630, -0.0263,\n", + " 0.2823, -0.1510, 0.1313, -0.5186, 0.4464, -0.0066]]],\n", " grad_fn=), scale=tensor(0.0062, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)),\n", - " QuantTensor(value=tensor([[[ 0.0706, -0.0192, -0.7185, -0.8211, -0.5709, 0.1155, 0.4683,\n", - " 0.3400, -0.3015, 0.3528, 0.3143, -0.1155, -0.3143, -0.0257,\n", - " 0.1411, -0.2309, 0.5132, 0.3721, 0.5196, -0.5453],\n", - " [ 0.2630, -0.4747, 0.1347, -0.0641, -0.2245, -0.3336, -0.4490,\n", - " -0.4619, -0.1796, -0.5517, 0.3913, 0.0257, -0.2053, -0.2823,\n", - " -0.6992, -0.6607, 0.1989, -0.6928, -0.5581, 0.5966]]],\n", - " grad_fn=), scale=tensor(0.0064, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)))" + " QuantTensor(value=tensor([[[ 0.0000, -0.4004, 0.3151, -0.0263, -0.5842, -0.1641, -0.3939,\n", + " 0.0263, -0.2429, 0.6499, -0.5186, 0.1247, -0.2101, 0.8337,\n", + " -0.1444, 0.6762, -0.1641, -0.5317, -0.1707, -0.0197],\n", + " [ 0.3479, 0.5974, -0.3939, 0.1444, -0.6762, 0.1969, -0.6499,\n", + " 0.4136, 0.5383, -0.3085, 0.4070, 0.4070, 0.6630, -0.0263,\n", + " 0.2823, -0.1510, 0.1313, -0.5186, 0.4464, -0.0066]]],\n", + " grad_fn=), scale=tensor(0.0066, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)))" ] }, "execution_count": 11, @@ -533,45 +547,45 @@ { "data": { "text/plain": [ - "(QuantTensor(value=tensor([[[-0.1984, 0.2499, -0.1102, 0.2499, -0.0955, -0.4630, -0.8672,\n", - " 0.1911, -0.4851, 0.8085, 0.6982, -0.5806, 0.0000, -0.4189,\n", - " -0.7423, -0.4851, -0.9260, -0.0147, 0.0514, -0.1984],\n", - " [-0.2167, 0.5092, -0.3846, 0.0650, 0.6717, -0.2492, -0.0867,\n", - " 0.3142, -0.3900, 0.3521, 0.4767, -0.1137, 0.6879, 0.1733,\n", - " -0.0596, 0.4279, -0.5471, -0.2762, 0.5904, -0.3737],\n", - " [-0.1335, -0.0140, -0.2810, -0.5339, -0.5339, 0.0562, 0.7236,\n", - " -0.1264, -0.0211, -0.3021, -0.1124, 0.4777, 0.3793, 0.2388,\n", - " -0.0702, 0.4847, -0.4988, 0.7236, 0.5901, -0.4847],\n", - " [ 0.3340, -0.5225, -0.1242, 0.1499, 0.3083, -0.1756, -0.1713,\n", - " 0.0000, 0.3512, -0.3041, 0.3126, -0.5482, 0.4882, 0.1028,\n", - " -0.4796, 0.1028, -0.2527, -0.3640, 0.1713, 0.0471],\n", - " [-0.4438, -0.2686, -0.3095, -0.2978, -0.0993, 0.0584, 0.4846,\n", - " -0.0526, 0.3737, -0.4496, 0.1109, 0.7416, -0.0526, 0.3445,\n", - " 0.4963, 0.2803, 0.1927, 0.0000, 0.6131, 0.1109]],\n", + "(QuantTensor(value=tensor([[[-0.3777, -0.2074, 0.7184, 0.9110, 0.0148, -0.1926, -0.7110,\n", + " 0.1926, -0.4222, -0.9480, 0.2592, 0.2222, -0.2370, -0.5407,\n", + " 0.5851, -0.2370, 0.3555, 0.1703, 0.4444, -0.2222],\n", + " [ 0.4814, -0.7355, -0.1605, 0.3878, -0.5282, 0.2073, 0.0000,\n", + " 0.3677, 0.1805, -0.1204, -0.4614, 0.2474, 0.7021, 0.0401,\n", + " 0.4346, 0.4480, -0.3143, 0.0401, 0.6887, 0.6753],\n", + " [ 0.5038, -0.3650, -0.6936, 0.0146, -0.9345, 0.0000, 0.1679,\n", + " -0.3066, 0.1825, 0.4089, 0.0949, -0.2555, 0.3870, -0.2482,\n", + " 0.5914, -0.0803, 0.1314, -0.4235, -0.3797, 0.1168],\n", + " [ 0.1795, 0.1795, 0.0449, 0.0449, 0.2308, 0.0898, -0.1282,\n", + " 0.5579, 0.1731, -0.1795, 0.1603, 0.3142, 0.1090, 0.5835,\n", + " -0.1475, 0.0449, 0.1795, -0.0256, 0.8143, -0.2437],\n", + " [-0.0066, 0.4804, 0.0066, -0.1184, 0.6843, -0.0197, 0.1448,\n", + " 0.1842, 0.6383, -0.1908, -0.0066, -0.1053, -0.1316, 0.0461,\n", + " -0.0066, -0.2764, 0.3751, 0.3619, 0.5001, -0.1316]],\n", " \n", - " [[ 0.1102, -0.8085, 0.5806, -0.0661, 0.3013, 0.2646, 0.2499,\n", - " -0.6321, 0.4557, 0.4777, 0.6321, 0.0294, -0.2646, -0.9407,\n", - " 0.7350, -0.6027, 0.6174, -0.4116, 0.6835, 0.0514],\n", - " [ 0.1787, 0.0271, 0.1354, -0.3033, 0.6229, -0.3250, -0.3846,\n", - " 0.0812, 0.5633, 0.6879, -0.0325, -0.2383, -0.3521, -0.5850,\n", - " 0.3033, -0.3900, 0.6771, 0.3196, 0.5633, 0.2383],\n", - " [-0.1264, 0.5901, -0.3934, 0.3231, 0.0492, -0.5128, -0.8149,\n", - " 0.1124, -0.7517, 0.8711, 0.4004, -0.8992, 0.0702, -0.2178,\n", - " -0.8851, -0.5760, -0.1054, -0.0702, -0.3512, -0.5198],\n", - " [ 0.2612, 0.2570, 0.1542, -0.1071, -0.0300, 0.0257, -0.3854,\n", - " -0.0685, -0.2570, 0.0728, -0.4240, -0.3083, 0.1627, -0.3383,\n", - " -0.0428, 0.0300, -0.1199, 0.3683, 0.3298, -0.3340],\n", - " [ 0.4204, -0.2452, -0.0934, 0.2336, 0.1285, -0.1285, 0.2044,\n", - " -0.0701, 0.0058, 0.3971, 0.0175, -0.3270, 0.2803, 0.1810,\n", - " -0.4963, -0.5547, 0.0467, 0.0175, 0.1927, -0.2452]]],\n", - " grad_fn=), scale=tensor(0.0060, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)),\n", - " QuantTensor(value=tensor([[[-0.4438, -0.2686, -0.3095, -0.2978, -0.0993, 0.0584, 0.4846,\n", - " -0.0526, 0.3737, -0.4496, 0.1109, 0.7416, -0.0526, 0.3445,\n", - " 0.4963, 0.2803, 0.1927, 0.0000, 0.6131, 0.1109],\n", - " [ 0.4204, -0.2452, -0.0934, 0.2336, 0.1285, -0.1285, 0.2044,\n", - " -0.0701, 0.0058, 0.3971, 0.0175, -0.3270, 0.2803, 0.1810,\n", - " -0.4963, -0.5547, 0.0467, 0.0175, 0.1927, -0.2452]]],\n", - " grad_fn=), scale=tensor(0.0058, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)))" + " [[ 0.5110, -0.3555, 0.6443, -0.8221, 0.4888, -0.2074, 0.0444,\n", + " 0.4888, 0.5999, 0.4370, 0.0000, 0.5036, -0.7628, 0.9332,\n", + " -0.6147, 0.7332, 0.3629, 0.9184, 0.7702, -0.8887],\n", + " [ 0.8492, -0.3410, -0.3878, 0.1404, -0.3410, 0.3143, -0.1204,\n", + " 0.5817, 0.4413, 0.5550, 0.6486, -0.1070, 0.6285, -0.4948,\n", + " 0.2006, 0.1605, 0.0535, -0.4079, 0.3811, 0.4948],\n", + " [ 0.6060, 0.7666, -0.8688, -0.6863, -0.5111, -0.0803, -0.6425,\n", + " -0.0146, -0.3577, 0.3431, -0.6571, 0.5622, 0.0000, 0.7374,\n", + " -0.1314, -0.3650, 0.7520, 0.2336, -0.2847, -0.8250],\n", + " [ 0.3014, 0.2950, -0.0898, -0.3142, 0.4040, 0.4681, -0.0705,\n", + " -0.2052, 0.8143, -0.1603, 0.3334, -0.6733, 0.0834, 0.0898,\n", + " -0.4937, 0.1924, 0.0064, 0.4104, 0.6348, -0.3527],\n", + " [-0.6449, 0.5856, -0.0263, -0.0197, 0.8357, -0.5856, 0.0395,\n", + " -0.3422, 0.8028, 0.0855, -0.7238, -0.6317, 0.2764, -0.0461,\n", + " -0.4211, -0.5988, 0.2632, 0.4014, -0.7501, -0.5659]]],\n", + " grad_fn=), scale=tensor(0.0069, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)),\n", + " QuantTensor(value=tensor([[[-0.0066, 0.4804, 0.0066, -0.1184, 0.6843, -0.0197, 0.1448,\n", + " 0.1842, 0.6383, -0.1908, -0.0066, -0.1053, -0.1316, 0.0461,\n", + " -0.0066, -0.2764, 0.3751, 0.3619, 0.5001, -0.1316],\n", + " [-0.6449, 0.5856, -0.0263, -0.0197, 0.8357, -0.5856, 0.0395,\n", + " -0.3422, 0.8028, 0.0855, -0.7238, -0.6317, 0.2764, -0.0461,\n", + " -0.4211, -0.5988, 0.2632, 0.4014, -0.7501, -0.5659]]],\n", + " grad_fn=), scale=tensor(0.0066, grad_fn=), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)))" ] }, "execution_count": 12, @@ -754,20 +768,22 @@ " bias: bool = True,\n", " batch_first: bool = False,\n", " bidirectional: bool = False,\n", - " weight_quant = Int8WeightPerTensorFloat,\n", - " bias_quant = Int32Bias,\n", - " io_quant = Int8ActPerTensorFloat,\n", - " gate_acc_quant = Int8ActPerTensorFloat,\n", - " sigmoid_quant = Uint8ActPerTensorFloat,\n", - " tanh_quant = Int8ActPerTensorFloat,\n", - " cell_state_quant = Int8ActPerTensorFloat,\n", + " weight_quant=Int8WeightPerTensorFloat,\n", + " bias_quant=Int32Bias,\n", + " io_quant=Int8ActPerTensorFloat,\n", + " gate_acc_quant=Int8ActPerTensorFloat,\n", + " sigmoid_quant=Uint8ActPerTensorFloat,\n", + " tanh_quant=Int8ActPerTensorFloat,\n", + " cell_state_quant=Int8ActPerTensorFloat,\n", " coupled_input_forget_gates: bool = False,\n", - " cat_output_cell_states = True,\n", - " shared_input_hidden_weights = False,\n", - " shared_intra_layer_weight_quant = False,\n", - " shared_intra_layer_gate_acc_quant = False,\n", - " shared_cell_state_quant = True,\n", + " cat_output_cell_states=True,\n", + " shared_input_hidden_weights=False,\n", + " shared_intra_layer_weight_quant=False,\n", + " shared_intra_layer_gate_acc_quant=False,\n", + " shared_cell_state_quant=True,\n", " return_quant_tensor: bool = False,\n", + " device: Optional[torch.device] = None,\n", + " dtype: Optional[torch.dtype] = None,\n", " **kwargs):\n", " super(QuantLSTM, self).__init__(\n", " layer_impl=_QuantLSTMLayer,\n", @@ -790,6 +806,8 @@ " shared_intra_layer_gate_acc_quant=shared_intra_layer_gate_acc_quant,\n", " shared_cell_state_quant=shared_cell_state_quant,\n", " return_quant_tensor=return_quant_tensor,\n", + " dtype=dtype,\n", + " device=device,\n", " **kwargs)\n", " if cat_output_cell_states and cell_state_quant is not None and not shared_cell_state_quant:\n", " raise RuntimeError(\"Concatenating cell states requires shared cell quantizers.\")\n", @@ -894,7 +912,16 @@ "cell_type": "code", "execution_count": 18, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/torch_2.1/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "import torch\n", "from brevitas.nn import QuantLSTM\n", @@ -936,7 +963,7 @@ " " ], "text/plain": [ - "" + "" ] }, "execution_count": 19, @@ -958,9 +985,17 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 20, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-02-02 10:22:46.461627098 [W:onnxruntime:, graph.cc:1283 Graph] Initializer onnx::LSTM_93 appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.\n" + ] + } + ], "source": [ "import onnxruntime as ort\n", "import numpy as np\n", @@ -1004,37 +1039,7 @@ "skip-execution" ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Serving 'quant_lstm_weight_only_cifg_4b.onnx' at http://localhost:8082\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "show_netron(export_path, 8082)" ] @@ -1049,9 +1054,17 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 23, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-02-02 10:22:49.697482752 [W:onnxruntime:, graph.cc:1283 Graph] Initializer onnx::LSTM_87 appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.\n" + ] + } + ], "source": [ "import onnxruntime as ort\n", "import numpy as np\n", @@ -1079,7 +1092,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/giuseppe/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:77: UserWarning: Keyword arguments are being passed but they not being used.\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:78: UserWarning: Keyword arguments are being passed but they not being used.\n", " warn('Keyword arguments are being passed but they not being used.')\n" ] } @@ -1104,37 +1117,7 @@ "skip-execution" ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Serving 'quant_lstm_weight_only_bidirectional_2_layers.onnx' at http://localhost:8083\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "show_netron(export_path, 8083)" ] @@ -1155,7 +1138,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/giuseppe/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:77: UserWarning: Keyword arguments are being passed but they not being used.\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:78: UserWarning: Keyword arguments are being passed but they not being used.\n", " warn('Keyword arguments are being passed but they not being used.')\n" ] } @@ -1180,37 +1163,7 @@ "skip-execution" ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Serving 'quant_lstm_weight_only_bidirectional_2_layers_shared_ih.onnx' at http://localhost:8085\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "show_netron(export_path, 8085)" ] @@ -1225,17 +1178,39 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "c:\\users\\alessand\\documents\\brevitas\\src\\brevitas\\nn\\mixin\\base.py:112: UserWarning: Keyword arguments are being passed but they not being used.\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:78: UserWarning: Keyword arguments are being passed but they not being used.\n", " warn('Keyword arguments are being passed but they not being used.')\n" ] - }, + } + ], + "source": [ + "import torch\n", + "from brevitas.nn import QuantLSTM\n", + "from brevitas.export import export_onnx_qcdq\n", + "\n", + "quant_lstm_weight_only_bidirectional_2_layers = QuantLSTM(\n", + " input_size=10, hidden_size=20, bidirectional=True, num_layers=2, weight_bit_width=4, shared_intra_layer_weight_quant=True,\n", + " io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)\n", + "export_path = 'quant_lstm_weight_only_bidirectional_2_layers_shared_q.onnx'\n", + "exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "tags": [ + "skip-execution" + ] + }, + "outputs": [ { "name": "stdout", "output_type": "stream", @@ -1258,35 +1233,14 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 24, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], - "source": [ - "import torch\n", - "from brevitas.nn import QuantLSTM\n", - "from brevitas.export import export_onnx_qcdq\n", - "\n", - "quant_lstm_weight_only_bidirectional_2_layers = QuantLSTM(\n", - " input_size=10, hidden_size=20, bidirectional=True, num_layers=2, weight_bit_width=4, shared_intra_layer_weight_quant=True,\n", - " io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)\n", - "export_path = 'quant_lstm_weight_only_bidirectional_2_layers_shared_q.onnx'\n", - "exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - "skip-execution" - ] - }, - "outputs": [], "source": [ "show_netron(export_path, 8086)" ] @@ -1301,17 +1255,40 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "c:\\users\\alessand\\documents\\brevitas\\src\\brevitas\\nn\\mixin\\base.py:112: UserWarning: Keyword arguments are being passed but they not being used.\n", + "/home/giuseppe/Documents/git/brevitas/src/brevitas/nn/mixin/base.py:78: UserWarning: Keyword arguments are being passed but they not being used.\n", " warn('Keyword arguments are being passed but they not being used.')\n" ] - }, + } + ], + "source": [ + "import torch\n", + "from brevitas.nn import QuantLSTM\n", + "from brevitas.export import export_onnx_qcdq\n", + "\n", + "quant_lstm_weight_only_bidirectional_2_layers = QuantLSTM(\n", + " input_size=10, hidden_size=20, bidirectional=True, num_layers=2, weight_bit_width=4, \n", + " shared_input_hidden_weights=True, shared_intra_layer_weight_quant=True,\n", + " io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)\n", + "export_path = 'quant_lstm_weight_only_bidirectional_2_layers_shared_q_ih.onnx'\n", + "exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "tags": [ + "skip-execution" + ] + }, + "outputs": [ { "name": "stdout", "output_type": "stream", @@ -1334,36 +1311,14 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 25, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], - "source": [ - "import torch\n", - "from brevitas.nn import QuantLSTM\n", - "from brevitas.export import export_onnx_qcdq\n", - "\n", - "quant_lstm_weight_only_bidirectional_2_layers = QuantLSTM(\n", - " input_size=10, hidden_size=20, bidirectional=True, num_layers=2, weight_bit_width=4, \n", - " shared_input_hidden_weights=True, shared_intra_layer_weight_quant=True,\n", - " io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)\n", - "export_path = 'quant_lstm_weight_only_bidirectional_2_layers_shared_q_ih.onnx'\n", - "exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - "skip-execution" - ] - }, - "outputs": [], "source": [ "show_netron(export_path, 8087)" ] @@ -1380,8 +1335,37 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 32, "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[W shape_type_inference.cpp:1974] Warning: The shape inference of onnx.brevitas::QuantLSTMCell type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)\n", + "[W shape_type_inference.cpp:1974] Warning: The shape inference of onnx.brevitas::QuantLSTMCell type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)\n", + "[W shape_type_inference.cpp:1974] Warning: The shape inference of onnx.brevitas::QuantLSTMCell type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (function UpdateReliable)\n" + ] + } + ], + "source": [ + "import torch\n", + "from brevitas.nn import QuantLSTM\n", + "from brevitas.export import export_qonnx\n", + "\n", + "quant_lstm = QuantLSTM(input_size=10, hidden_size=20)\n", + "export_path = 'quant_lstm.onnx'\n", + "exported_model = export_qonnx(quant_lstm, (torch.randn(5, 1, 10)), export_path=export_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "tags": [ + "skip-execution" + ] + }, "outputs": [ { "name": "stdout", @@ -1405,33 +1389,14 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 26, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], - "source": [ - "import torch\n", - "from brevitas.nn import QuantLSTM\n", - "from brevitas.export import export_qonnx\n", - "\n", - "quant_lstm = QuantLSTM(input_size=10, hidden_size=20)\n", - "export_path = 'quant_lstm.onnx'\n", - "exported_model = export_qonnx(quant_lstm, (torch.randn(5, 1, 10)), export_path=export_path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - "skip-execution" - ] - }, - "outputs": [], "source": [ "show_netron(export_path, 8088)" ] @@ -1518,7 +1483,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.5 (default, Oct 25 2019, 15:51:11) \n[GCC 7.3.0]" + "version": "3.11.5" }, "orig_nbformat": 4, "vscode": { diff --git a/src/brevitas/core/quant/binary.py b/src/brevitas/core/quant/binary.py index 392cdeb62..3a4b7346e 100644 --- a/src/brevitas/core/quant/binary.py +++ b/src/brevitas/core/quant/binary.py @@ -48,8 +48,9 @@ class BinaryQuant(brevitas.jit.ScriptModule): Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module. """ - def __init__(self, scaling_impl: Module, quant_delay_steps: int = 0): + def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps: int = 0): super(BinaryQuant, self).__init__() + assert signed, "Unsigned binary quant not supported" self.scaling_impl = scaling_impl self.bit_width = BitWidthConst(1) self.zero_point = StatelessBuffer(torch.tensor(0.0)) diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index bfcfbb58f..fac729326 100644 --- a/src/brevitas/core/stats/stats_op.py +++ b/src/brevitas/core/stats/stats_op.py @@ -12,6 +12,7 @@ from brevitas import config from brevitas.core.utils import StatelessBuffer from brevitas.function.ops import max_int +from brevitas.quant_tensor import _unpack_quant_tensor # Use custom implementation of kthvalue as work around to (b)float16 kernel limitations from brevitas.utils.torch_utils import kthvalue @@ -478,8 +479,7 @@ def evaluate_loss(self, x, candidate): # Set to local_loss_mode before calling the proxy self.set_local_loss_mode(True) quant_value = self.proxy_forward(x) - if isinstance(quant_value, tuple): - quant_value = quant_value[0] + quant_value = _unpack_quant_tensor(quant_value) loss = self.mse_loss_fn(x, quant_value) self.set_local_loss_mode(False) return loss diff --git a/src/brevitas/export/manager.py b/src/brevitas/export/manager.py index f8f1189fd..bc62920f3 100644 --- a/src/brevitas/export/manager.py +++ b/src/brevitas/export/manager.py @@ -207,10 +207,9 @@ def _cache_fn_dispatcher(cls, fn, input, *args, **kwargs): if isinstance(input, QuantTensor): inp_cache = None out_cache = None - if input.is_not_none: - inp_cache = _CachedIO(input, metadata_only=True) + inp_cache = _CachedIO(input, metadata_only=True) output = fn(input, *args, **kwargs) - if isinstance(output, QuantTensor) and output.is_not_none: + if isinstance(output, QuantTensor): out_cache = _CachedIO(output, metadata_only=True) cached_io = (inp_cache, out_cache) if fn in cls._cached_io_handler_map: diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/base.py b/src/brevitas/export/onnx/standard/qoperator/handler/base.py index 2dfcf6037..e614d2ed5 100644 --- a/src/brevitas/export/onnx/standard/qoperator/handler/base.py +++ b/src/brevitas/export/onnx/standard/qoperator/handler/base.py @@ -104,7 +104,7 @@ def input_quant_symbolic_kwargs(cls, module): @classmethod def input_dequant_symbolic_kwargs(cls, module): - if module._cached_inp.scale is not None: + if module._cached_inp is not None: return cls.dequant_symbolic_kwargs_from_cached_io(module._cached_inp) else: return None diff --git a/src/brevitas/export/onnx/standard/qoperator/manager.py b/src/brevitas/export/onnx/standard/qoperator/manager.py index 12f16cba3..464d5941a 100644 --- a/src/brevitas/export/onnx/standard/qoperator/manager.py +++ b/src/brevitas/export/onnx/standard/qoperator/manager.py @@ -1,18 +1,11 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from typing import Optional, Tuple, Union - -from packaging import version -from torch import Tensor from torch.nn import functional as F from torch.nn import Module -from brevitas import torch_version from brevitas.export.manager import _set_layer_export_handler from brevitas.export.manager import _set_layer_export_mode -from brevitas.export.onnx.manager import ONNXBaseManager -from brevitas.quant_tensor import QuantTensor from ..function import DequantizeLinearFn from ..function import IntClipFn diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index bb435b7ef..8f690fc9b 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -48,6 +48,21 @@ BN_LAYERS = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) +def disable_return_quant_tensor(model): + previous_state = {} + for module in model.modules(): + if hasattr(module, 'return_quant_tensor'): + previous_state[module] = module.return_quant_tensor + module.return_quant_tensor = False + return previous_state + + +def restore_return_quant_tensor(model, previous_state): + for module in model.modules(): + if hasattr(module, 'return_quant_tensor'): + module.return_quant_tensor = previous_state[module] + + def extend_collect_stats_steps(module): if hasattr(module, 'collect_stats_steps'): # We extend the collect steps in PTQ to match potentially long calibrations @@ -75,11 +90,13 @@ def __init__(self, model, enabled=True): self.previous_training_state = model.training self.disable_quant_inference = DisableEnableQuantization(call_act_quantizer_impl=True) self.enabled = enabled + self.return_quant_tensor_state = dict() def __enter__(self): if self.enabled: self.model.apply(extend_collect_stats_steps) self.model.apply(set_collect_stats_to_average) + self.return_quant_tensor_state = disable_return_quant_tensor(self.model) self.disable_quant_inference.apply( self.model, is_training=True, quantization_enabled=False) @@ -87,6 +104,7 @@ def __exit__(self, type, value, traceback): self.model.apply(finalize_collect_stats) self.disable_quant_inference.apply( self.model, is_training=self.previous_training_state, quantization_enabled=True) + restore_return_quant_tensor(self.model, self.return_quant_tensor_state) class load_quant_model: @@ -168,7 +186,7 @@ def disable_act_quant_hook(self, module, inp, output): if isinstance(module.tracked_module_list[0], QuantHardTanh): inp = F.hardtanh( inp, min_val=module.quant_injector.min_val, max_val=module.quant_injector.max_val) - return QuantTensor(value=inp, training=module.training) + return inp def disable_act_quantization(self, model, is_training): # If self.call_act_quantizer_impl is set to True, the quantization will be performed but the output diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index 149e8ec03..e9641a5a8 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -234,19 +234,14 @@ def process_input(self, inp): inp_training = self.layer.training # If using quantized activations, inp could be QuantTensor. In - # this case, we overwrite the metadata if it is specified. + # this case, we overwrite the metadata. if isinstance(inp, QuantTensor): if self.layer_requires_input_quant and (self.quant_input is None): - if inp.scale is not None: - inp_scale = inp.scale - if inp.zero_point is not None: - inp_zero_point = inp.zero_point - if inp.bit_width is not None: - inp_bit_width = inp.bit_width - if inp.signed is not None: - inp_signed = inp.signed - if inp.training is not None: - inp_training = inp.training + inp_scale = inp.scale + inp_zero_point = inp.zero_point + inp_bit_width = inp.bit_width + inp_signed = inp.signed + inp_training = inp.training inp = inp.value # if the layer requires an input quant and the quant input cache has diff --git a/src/brevitas/nn/hadamard_classifier.py b/src/brevitas/nn/hadamard_classifier.py index d3f22f679..e78163321 100644 --- a/src/brevitas/nn/hadamard_classifier.py +++ b/src/brevitas/nn/hadamard_classifier.py @@ -49,15 +49,13 @@ def forward(self, inp): out = inp.value / norm out = nn.functional.linear(out, self.proj[:self.out_channels, :self.in_channels]) out = -self.scale * out - if inp.scale is not None: + if isinstance(inp, QuantTensor): output_scale = inp.scale * self.scale / norm - if inp.bit_width is not None: output_bit_width = self.max_output_bit_width(inp.bit_width) - if (self.return_quant_tensor and inp.zero_point is not None and - (inp.zero_point != 0.0).any()): - raise RuntimeError("Computing zero point of output accumulator not supported yet.") - else: - output_zp = inp.zero_point + if (self.return_quant_tensor and inp.zero_point != 0.0).any(): + raise RuntimeError("Computing zero point of output accumulator not supported yet.") + else: + output_zp = inp.zero_point out = QuantTensor( value=out, scale=output_scale, diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 2d4fa97ad..e54ad1ecc 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -18,6 +18,7 @@ from brevitas.inject import ExtendedInjector from brevitas.inject import Injector from brevitas.nn.utils import compute_channel_view_shape +from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import QuantTensor from .utils import filter_kwargs @@ -154,7 +155,7 @@ def quant_output_bit_width(self): else: return None - def unpack_input(self, inp: Union[Tensor, QuantTensor]): + def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: self._set_global_is_quant_layer(True) # Hack to recognize a QuantTensor that has decayed to a tuple # when used as input to tracing (e.g. during ONNX export) @@ -166,25 +167,23 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]): if not self.training and not self._export_mode and self.cache_inference_quant_inp: cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only) self._cached_inp = cached_inp - else: - inp = QuantTensor(inp, training=self.training) - if not self.training and self.cache_inference_quant_inp: - cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only) - self._cached_inp = cached_inp - # Remove any naming metadata to avoid dowmstream errors - # Avoid inplace operations on the input in case of forward hooks if not torch._C._get_tracing_state(): - inp = inp.set(value=inp.value.rename(None)) + if isinstance(inp, QuantTensor): + inp = inp.set(value=inp.value.rename(None)) + else: + inp = inp.rename(None) return inp - def pack_output(self, quant_output: QuantTensor): - if not self.training and self.cache_inference_quant_out: + def pack_output(self, quant_output: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: + if not self.training and self.cache_inference_quant_out and isinstance(quant_output, + QuantTensor): self._cached_out = _CachedIO(quant_output.detach(), self.cache_quant_io_metadata_only) self._set_global_is_quant_layer(False) if self.return_quant_tensor: + assert isinstance(quant_output, QuantTensor) return quant_output else: - return quant_output.value + return _unpack_quant_tensor(quant_output) class QuantRecurrentLayerMixin(ExportMixin): @@ -246,9 +245,9 @@ def gate_params_fwd(gate, quant_input): acc_bit_width = None quant_weight_ih = gate.input_weight() quant_weight_hh = gate.hidden_weight() - if quant_input.bit_width is not None: + if isinstance(quant_input, QuantTensor): acc_bit_width = None # TODO - if quant_input.scale is not None and quant_weight_ih.scale is not None: + if isinstance(quant_input, QuantTensor) and isinstance(quant_weight_ih, QuantTensor): acc_scale_shape = compute_channel_view_shape(quant_input.value, channel_dim=1) acc_scale = quant_weight_ih.scale.view(acc_scale_shape) acc_scale = acc_scale * quant_input.scale.view(acc_scale_shape) @@ -267,8 +266,6 @@ def maybe_quantize_input(self, inp): quant_input = inp if not self.quantize_output_only: quant_input = self.io_quant(quant_input) - elif not isinstance(inp, QuantTensor): - quant_input = QuantTensor(quant_input) return quant_input def maybe_quantize_state(self, inp, state, quant): @@ -276,15 +273,16 @@ def maybe_quantize_state(self, inp, state, quant): batch_size = inp.size(0) if self.cell.batch_first else inp.size(1) quant_state = torch.zeros( int(batch_size), self.hidden_size, dtype=inp.dtype, device=inp.device) - quant_state = QuantTensor(quant_state) else: quant_state = quant(state) return quant_state def pack_quant_outputs(self, quant_outputs): # In export mode, quant_outputs has the shape of the output concatenated value + # Even though we check that return_quant_tensor can be enabled only with io_quant != None, + # inner layers in a deep network overrides it, so we check again. if self.export_mode: - if self.return_quant_tensor: + if self.return_quant_tensor and self.io_quant.is_quant_enabled: return QuantTensor( quant_outputs, self.io_quant.scale(), @@ -295,7 +293,7 @@ def pack_quant_outputs(self, quant_outputs): else: return quant_outputs seq_dim = 1 if self.cell.batch_first else 0 - if self.return_quant_tensor: + if self.return_quant_tensor and self.io_quant.is_quant_enabled: outputs = [ QuantTensor( torch.unsqueeze(quant_output[0], dim=seq_dim), @@ -312,8 +310,10 @@ def pack_quant_outputs(self, quant_outputs): return torch.cat(outputs, dim=seq_dim) def pack_quant_state(self, quant_state, quant): + # Even though we check that return_quant_tensor can be enabled only with quant != None, + # inner layers in a deep network overrides it, so we check again. if self.export_mode: - if self.return_quant_tensor: + if self.return_quant_tensor and quant.is_quant_enabled: quant_state = QuantTensor( torch.unsqueeze(quant_state, dim=0), quant.scale(), @@ -324,7 +324,7 @@ def pack_quant_state(self, quant_state, quant): else: quant_state = torch.unsqueeze(quant_state, dim=0) else: - if self.return_quant_tensor: + if self.return_quant_tensor and quant.is_quant_enabled: quant_state = QuantTensor( torch.unsqueeze(quant_state[0], dim=0), quant_state[1], diff --git a/src/brevitas/nn/mixin/parameter.py b/src/brevitas/nn/mixin/parameter.py index 095c981f1..3acfe7c95 100644 --- a/src/brevitas/nn/mixin/parameter.py +++ b/src/brevitas/nn/mixin/parameter.py @@ -198,7 +198,11 @@ def quant_bias_zero_point(self): if self.bias is None: return None if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width: - return self.bias_quant(self.bias).zero_point + bias_quant = self.bias_quant(self.bias) + if isinstance(bias_quant, QuantTensor): + return bias_quant.zero_point + else: + return None else: if self._cached_bias is None: raise RuntimeError( diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py index d8c83e3f2..5d567d0ca 100644 --- a/src/brevitas/nn/quant_avg_pool.py +++ b/src/brevitas/nn/quant_avg_pool.py @@ -13,6 +13,7 @@ from brevitas.function.ops import max_int from brevitas.function.ops_ste import ceil_ste from brevitas.inject.defaults import RoundTo8bit +from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import QuantTensor from .mixin.acc import AccQuantType @@ -55,16 +56,22 @@ def _avg_scaling(self): def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) + if self.export_mode: - return self.export_handler(x.value) - x = x.set(value=super(TruncAvgPool2d, self).forward(x.value)) - if self.is_trunc_quant_enabled: - assert x.is_not_none # check input quant tensor is filled with values - # remove avg scaling - rescaled_value = x.value * self._avg_scaling - x = x.set(value=rescaled_value) - x = x.set(bit_width=self.max_acc_bit_width(x.bit_width)) - x = self.trunc_quant(x) + return self.export_handler(_unpack_quant_tensor(x)) + + if isinstance(x, QuantTensor): + x = x.set(value=super(TruncAvgPool2d, self).forward(x.value)) + if self.is_trunc_quant_enabled: + # remove avg scaling + rescaled_value = x.value * self._avg_scaling + x = x.set(value=rescaled_value) + x = x.set(bit_width=self.max_acc_bit_width(x.bit_width)) + x = self.trunc_quant(x) + else: + assert not self.is_trunc_quant_enabled + x = super(TruncAvgPool2d, self).forward(x) + return self.pack_output(x) def max_acc_bit_width(self, input_bit_width): @@ -127,23 +134,30 @@ def compute_kernel_size_stride(self, input_shape, output_shape): def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) + # shortcut execution through the export impl during export if self.export_mode: - out = self.export_handler(x.value) + out = self.export_handler(_unpack_quant_tensor(x)) self._set_global_is_quant_layer(False) return out - y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value)) - k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:]) + if self.cache_kernel_size_stride: self._cached_kernel_size = k_size self._cached_kernel_stride = stride - if self.is_trunc_quant_enabled: - assert y.is_not_none # check input quant tensor is filled with values - reduce_size = reduce(mul, k_size, 1) - rescaled_value = y.value * reduce_size # remove avg scaling - y = y.set(value=rescaled_value) - y = y.set(bit_width=self.max_acc_bit_width(y.bit_width, reduce_size)) - y = self.trunc_quant(y) + + if isinstance(x, QuantTensor): + y = x.set(value=super(TruncAdaptiveAvgPool2d, self).forward(x.value)) + k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:]) + if self.is_trunc_quant_enabled: + reduce_size = reduce(mul, k_size, 1) + rescaled_value = y.value * reduce_size # remove avg scaling + y = y.set(value=rescaled_value) + y = y.set(bit_width=self.max_acc_bit_width(y.bit_width, reduce_size)) + y = self.trunc_quant(y) + else: + assert not self.is_trunc_quant_enabled + y = super(TruncAdaptiveAvgPool2d, self).forward(x) + return self.pack_output(y) def max_acc_bit_width(self, input_bit_width, reduce_size): diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 7208aa8e3..f56ddd160 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -10,6 +10,7 @@ from torch.nn import Module from torch.nn import Parameter +from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import QuantTensor from .mixin import * @@ -135,7 +136,7 @@ def forward(self, input: Union[Tensor, QuantTensor]): quant_input = self.input_quant(input) # shortcut execution through the export impl during export if self.export_mode: - out = self.export_handler(quant_input.value) + out = self.export_handler(_unpack_quant_tensor(quant_input)) self._set_global_is_quant_layer(False) return out out = self.act_quant(quant_input) @@ -303,61 +304,73 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe # shortcut execution through the export impl during export if self.export_mode: - out = self.export_handler(inp.value) + out = self.export_handler(_unpack_quant_tensor(inp)) self._set_global_is_quant_layer(False) return out quant_input = self.input_quant(inp) quant_weight = self.quant_weight(quant_input) - if (self.return_quant_tensor or - (self.is_bias_quant_enabled and - (self.bias_quant.requires_input_scale or self.bias_quant.requires_input_bit_width))): - if quant_input.bit_width is not None and quant_weight.bit_width is not None: - output_bit_width = self.max_acc_bit_width( - quant_input.bit_width, quant_weight.bit_width) - if quant_input.scale is not None and quant_weight.scale is not None: - output_scale = self.quant_output_scale_impl( - inp, quant_input.scale, quant_weight.scale) - if quant_input.signed is not None: - output_signed = inp.signed or quant_weight.signed + compute_output_quant_tensor = isinstance(quant_input, QuantTensor) and isinstance( + quant_weight, QuantTensor) + if not (compute_output_quant_tensor or + self.is_output_quant_enabled) and self.return_quant_tensor: + raise RuntimeError("QuantLayer is not correctly configured") + + if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): + output_bit_width = self.max_acc_bit_width(quant_input.bit_width, quant_weight.bit_width) + output_scale = self.quant_output_scale_impl(inp, quant_input.scale, quant_weight.scale) + output_signed = quant_input.signed or quant_weight.signed if self.bias is not None: quant_bias = self.bias_quant(self.bias, output_scale, output_bit_width) - if not self.training and self.cache_inference_quant_bias: + if not self.training and self.cache_inference_quant_bias and isinstance(quant_bias, + QuantTensor): self._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False) - output_tensor = self.inner_forward_impl( - quant_input.value, quant_weight.value, quant_bias.value) - - if (self.return_quant_tensor and output_scale is not None and - (quant_bias.scale is None or - (quant_bias.scale is not None and - quant_bias.scale.data_ptr() != output_scale.data_ptr()))): - output_scale_broadcast_shape = compute_channel_view_shape(inp, channel_dim=1) - output_zero_point = -quant_bias.value.view( - output_scale_broadcast_shape) / output_scale - - if quant_bias.bit_width is not None and output_bit_width is not None: + _unpack_quant_tensor(quant_input), + _unpack_quant_tensor(quant_weight), + _unpack_quant_tensor(quant_bias)) + + if output_scale is not None: + if (isinstance(quant_bias, QuantTensor) and + quant_bias.scale.data_ptr() != output_scale.data_ptr()) or not isinstance( + quant_bias, QuantTensor): + channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1 + output_scale_broadcast_shape = compute_channel_view_shape( + inp, channel_dim=channel_dim) + output_zero_point = -_unpack_quant_tensor(quant_bias).view( + output_scale_broadcast_shape) / output_scale + + if output_bit_width is not None and isinstance(quant_bias, QuantTensor): output_bit_width = torch.where( quant_bias.bit_width > output_bit_width, quant_bias.bit_width, output_bit_width) output_bit_width = output_bit_width + 1 else: - output_tensor = self.inner_forward_impl(quant_input.value, quant_weight.value, None) - - if self.return_quant_tensor and not self.is_output_quant_enabled: - if (quant_input.zero_point is not None and quant_weight.zero_point is not None and - ((quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any())): - raise RuntimeError("Computing zero point of output accumulator not supported yet.") - elif quant_input.zero_point is not None and output_zero_point is None: - output_zero_point = quant_input.zero_point - - quant_output = QuantTensor( - value=output_tensor, - scale=output_scale, - zero_point=output_zero_point, - bit_width=output_bit_width, - signed=output_signed, - training=self.training) + output_tensor = self.inner_forward_impl( + _unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), None) + + if not self.is_output_quant_enabled and self.return_quant_tensor: + if compute_output_quant_tensor: + if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any(): + raise RuntimeError( + "Computing zero point of output accumulator not supported yet.") + elif output_zero_point is None: + output_zero_point = quant_input.zero_point + + elif output_zero_point is None: + output_zero_point = torch.zeros(1).type_as(output_tensor) + + if compute_output_quant_tensor: + quant_output = QuantTensor( + output_tensor, + scale=output_scale, + zero_point=output_zero_point, + bit_width=output_bit_width, + signed=output_signed, + training=self.training) + else: + quant_output = output_tensor + quant_output = self.output_quant(quant_output) return self.pack_output(quant_output) diff --git a/src/brevitas/nn/quant_rnn.py b/src/brevitas/nn/quant_rnn.py index 642c1b2d1..c27bf199b 100644 --- a/src/brevitas/nn/quant_rnn.py +++ b/src/brevitas/nn/quant_rnn.py @@ -23,6 +23,7 @@ from brevitas.quant import Int8WeightPerTensorFloat from brevitas.quant import Int32Bias from brevitas.quant import Uint8ActPerTensorFloat +from brevitas.quant_tensor import _unpack_quant_tensor QuantTupleShortEnabled = List[Tuple[Tensor, Tensor, Tensor, Tensor]] QuantTupleShortDisabled = List[Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]] @@ -416,11 +417,12 @@ def forward(self, inp, state): quant_input = self.maybe_quantize_input(inp) quant_weight_ih, quant_weight_hh, quant_bias = self.gate_params_fwd( self.gate_params, quant_input) - if quant_bias.value is None: - quant_bias = torch.tensor(0., device=quant_input.value.device) + quant_input_value = _unpack_quant_tensor(quant_input) + if quant_bias is None: + quant_bias = torch.tensor(0., device=quant_input_value.device) else: - quant_bias = quant_bias.value - quant_state = self.maybe_quantize_state(quant_input.value, state, self.cell.output_quant) + quant_bias = _unpack_quant_tensor(quant_bias) + quant_state = self.maybe_quantize_state(quant_input_value, state, self.cell.output_quant) if self.export_mode: cell = self.export_handler elif self.fast_mode: @@ -428,10 +430,10 @@ def forward(self, inp, state): else: cell = self.cell quant_outputs = cell( - quant_input.value, - quant_state.value, - quant_weight_ih.value, - quant_weight_hh.value, + _unpack_quant_tensor(quant_input), + _unpack_quant_tensor(quant_state), + _unpack_quant_tensor(quant_weight_ih), + _unpack_quant_tensor(quant_weight_hh), quant_bias) quant_output = self.pack_quant_outputs(quant_outputs) quant_state = self.pack_quant_state(quant_outputs[-1], self.cell.output_quant) @@ -666,6 +668,7 @@ def fast_cell(self): def forward(self, inp, hidden_state, cell_state): quant_input = self.maybe_quantize_input(inp) + quant_input_value = _unpack_quant_tensor(quant_input) quant_weight_ii, quant_weight_hi, quant_bias_input = self.gate_params_fwd( self.input_gate_params, quant_input) quant_weight_ic, quant_weight_hc, quant_bias_cell = self.gate_params_fwd( @@ -680,26 +683,26 @@ def forward(self, inp, hidden_state, cell_state): quant_weight_if, quant_weight_hf, quant_bias_forget = self.gate_params_fwd( self.forget_gate_params, quant_input) # Handle None bias by setting it 0. - if quant_bias_input.value is None: - quant_bias_input = torch.tensor(0., device=quant_input.value.device) + if quant_bias_input is None: + quant_bias_input = torch.tensor(0., device=quant_input_value.device) else: - quant_bias_input = quant_bias_input.value - if quant_bias_forget.value is None: - quant_bias_forget = torch.tensor(0., device=quant_input.value.device) + quant_bias_input = _unpack_quant_tensor(quant_bias_input) + if quant_bias_forget is None: + quant_bias_forget = torch.tensor(0., device=quant_input_value.device) else: - quant_bias_forget = quant_bias_forget.value - if quant_bias_cell.value is None: - quant_bias_cell = torch.tensor(0., device=quant_input.value.device) + quant_bias_forget = _unpack_quant_tensor(quant_bias_forget) + if quant_bias_cell is None: + quant_bias_cell = torch.tensor(0., device=quant_input_value.device) else: - quant_bias_cell = quant_bias_cell.value - if quant_bias_output.value is None: - quant_bias_output = torch.tensor(0., device=quant_input.value.device) + quant_bias_cell = _unpack_quant_tensor(quant_bias_cell) + if quant_bias_output is None: + quant_bias_output = torch.tensor(0., device=quant_input_value.device) else: - quant_bias_output = quant_bias_output.value + quant_bias_output = _unpack_quant_tensor(quant_bias_output) quant_hidden_state = self.maybe_quantize_state( - quant_input.value, hidden_state, self.cell.output_quant) + quant_input_value, hidden_state, self.cell.output_quant) quant_cell_state = self.maybe_quantize_state( - quant_input.value, cell_state, self.cell.cell_state_quant) + quant_input_value, cell_state, self.cell.cell_state_quant) # Pick cell impl if self.export_mode: cell = self.export_handler @@ -708,17 +711,17 @@ def forward(self, inp, hidden_state, cell_state): else: cell = self.cell quant_outputs, quant_hidden_state, quant_cell_state = cell( - quant_input.value, - quant_hidden_state.value, - quant_cell_state.value, - quant_weight_ii=quant_weight_ii.value, - quant_weight_if=quant_weight_if.value, - quant_weight_ic=quant_weight_ic.value, - quant_weight_io=quant_weight_io.value, - quant_weight_hi=quant_weight_hi.value, - quant_weight_hf=quant_weight_hf.value, - quant_weight_hc=quant_weight_hc.value, - quant_weight_ho=quant_weight_ho.value, + quant_input_value, + _unpack_quant_tensor(quant_hidden_state), + _unpack_quant_tensor(quant_cell_state), + quant_weight_ii=_unpack_quant_tensor(quant_weight_ii), + quant_weight_if=_unpack_quant_tensor(quant_weight_if), + quant_weight_ic=_unpack_quant_tensor(quant_weight_ic), + quant_weight_io=_unpack_quant_tensor(quant_weight_io), + quant_weight_hi=_unpack_quant_tensor(quant_weight_hi), + quant_weight_hf=_unpack_quant_tensor(quant_weight_hf), + quant_weight_hc=_unpack_quant_tensor(quant_weight_hc), + quant_weight_ho=_unpack_quant_tensor(quant_weight_ho), quant_bias_input=quant_bias_input, quant_bias_forget=quant_bias_forget, quant_bias_cell=quant_bias_cell, @@ -967,6 +970,8 @@ def __init__( **kwargs) if cat_output_cell_states and cell_state_quant is not None and not shared_cell_state_quant: raise RuntimeError("Concatenating cell states requires shared cell quantizers.") + if return_quant_tensor and cell_state_quant is None: + raise RuntimeError("return_quant_tensor=True requires cell_state_quant != None.") self.cat_output_cell_states = cat_output_cell_states def forward(self, inp, hx=None, cx=None): diff --git a/src/brevitas/nn/quant_upsample.py b/src/brevitas/nn/quant_upsample.py index 10727cec5..f2735abf5 100644 --- a/src/brevitas/nn/quant_upsample.py +++ b/src/brevitas/nn/quant_upsample.py @@ -45,7 +45,7 @@ def forward(self, input: Union[Tensor, QuantTensor]): y_value = interpolate(x.value, self.size, self.scale_factor, self.mode, self.align_corners) if self.mode != 'nearest': # round interpolated values to scale - assert x.scale is not None, 'Input scale factor required to interpolate correctly' + assert isinstance(x, QuantTensor), 'Input scale factor required to interpolate correctly' y_value = round_ste(y_value / x.scale) * x.scale y = x.set(value=y_value) return self.pack_output(y) @@ -73,7 +73,7 @@ def forward(self, input: Union[Tensor, QuantTensor]): return out y_value = interpolate(x.value, self.size, self.scale_factor, self.mode, self.align_corners) # round interpolated values to scale - assert x.scale is not None, 'Input scale factor required to interpolate correctly' + assert isinstance(x, QuantTensor), 'Input scale factor required to interpolate correctly' y_value = round_ste(y_value / x.scale) * x.scale y = x.set(value=y_value) return self.pack_output(y) diff --git a/src/brevitas/nn/target/flexml.py b/src/brevitas/nn/target/flexml.py index e98438f75..66daa94cf 100644 --- a/src/brevitas/nn/target/flexml.py +++ b/src/brevitas/nn/target/flexml.py @@ -97,11 +97,13 @@ def _avg_scaling(self): def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) - x = x.set(value=super(FlexMLQuantAvgPool2d, self).forward(x.value) * self.rescaling_const) - if x.scale is not None: + if isinstance(x, QuantTensor): + x = x.set( + value=super(FlexMLQuantAvgPool2d, self).forward(x.value) * self.rescaling_const) x = x.set(scale=x.scale * self.quantized_div_scale) - if x.bit_width is not None: x = x.set(bit_width=self.max_acc_bit_width(x.bit_width)) + else: + x = super(FlexMLQuantAvgPool2d, self).forward(x) * self.rescaling_const return self.pack_output(x) def max_acc_bit_width(self, input_bit_width): diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 5a4b2ed55..1f6adf549 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -3,7 +3,7 @@ from abc import ABCMeta from abc import abstractmethod -from typing import List, Optional, Tuple +from typing import Optional, Union import torch from torch import Tensor @@ -94,13 +94,13 @@ def bit_width(self): bit_width_ = self.__call__(self.tracked_parameter_list[0]).bit_width return bit_width_ - def forward(self, x: torch.Tensor) -> QuantTensor: + def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: if self.is_quant_enabled: impl = self.export_handler if self.export_mode else self.tensor_quant out, scale, zero_point, bit_width = impl(x) return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) else: # quantization disabled - return QuantTensor(x, training=self.training) + return x class DecoupledWeightQuantProxyFromInjector(WeightQuantProxyFromInjector): @@ -115,13 +115,13 @@ def pre_zero_point(self): out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple return pre_zero_point - def forward(self, x: torch.Tensor) -> QuantTensor: + def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: if self.is_quant_enabled: impl = self.export_handler if self.export_mode else self.tensor_quant out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x) return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) else: # quantization disabled - return QuantTensor(x, training=self.training) + return x class DecoupledWeightQuantWithInputProxyFromInjector(DecoupledWeightQuantProxyFromInjector): @@ -145,9 +145,8 @@ def pre_scale(self): def pre_zero_point(self): raise NotImplementedError - def forward( - self, x: torch.Tensor, input_bit_width: torch.Tensor, - input_is_signed: bool) -> QuantTensor: + def forward(self, x: torch.Tensor, input_bit_width: torch.Tensor, + input_is_signed: bool) -> Union[Tensor, QuantTensor]: if self.is_quant_enabled: impl = self.export_handler if self.export_mode else self.tensor_quant out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed) @@ -199,7 +198,7 @@ def forward( self, x: Tensor, input_scale: Optional[Tensor] = None, - input_bit_width: Optional[Tensor] = None) -> QuantTensor: + input_bit_width: Optional[Tensor] = None) -> Union[Tensor, QuantTensor]: if self.is_quant_enabled: impl = self.export_handler if self.export_mode else self.tensor_quant if self.requires_input_scale and input_scale is None: @@ -218,4 +217,4 @@ def forward( raise RuntimeError("Internally defined bit-width required") return QuantTensor(out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) else: - return QuantTensor(x, training=self.training) + return x diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 9d15f3bba..0324465c1 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -137,7 +137,7 @@ def bit_width(self): scale = self.__call__(self._zero_hw_sentinel()).bit_width return scale - def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor: + def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: if self.fused_activation_quant_proxy is not None: y = x if isinstance(y, QuantTensor): @@ -151,22 +151,24 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor: else: y = self.fused_activation_quant_proxy(y) # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, - # otherwise return an empty QuantTensor + # otherwise return a simple Tensor if isinstance(y, tuple) and not any(map(lambda f: f is None, y)): return QuantTensor(*y, signed=self.is_signed, training=self.training) elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant if isinstance(y, tuple): y = y[0] - return QuantTensor(y, x.scale, x.zero_point, x.bit_width, x.signed, self.training) + if isinstance(x, QuantTensor): + return QuantTensor( + y, x.scale, x.zero_point, x.bit_width, x.signed, self.training) + else: + return y else: if isinstance(y, tuple): y = y[0] - return QuantTensor(y, training=self.training) + return y else: - if isinstance(x, QuantTensor): # passthrough - return x - else: - return QuantTensor(x, training=self.training) + # If fused activation quant proxy is not enabled, return the input + return x class DynamicActQuantProxyFromInjector(ActQuantProxyFromInjector): @@ -184,7 +186,7 @@ def bit_width(self): class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol): - def forward(self, x: QuantTensor): + def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]: if self.is_quant_enabled: out_tuple = self.tensor_quant(x.value, x.scale, x.bit_width) out_value, out_scale, out_zp, out_bit_width = out_tuple @@ -197,11 +199,12 @@ class TruncQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol) def bit_width(self): zhs = self._zero_hw_sentinel() - empty_imp = QuantTensor(zhs, zhs, zhs, zhs) + # Signed might or might not be defined. We just care about retrieving the bitwidth + empty_imp = QuantTensor(zhs, zhs, zhs, zhs, signed=True, training=self.training) bit_width = self.__call__(empty_imp).bit_width return bit_width - def forward(self, x: QuantTensor): + def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]: if self.is_quant_enabled: if self.export_mode: out_tuple = self.export_handler( diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index bd1da8edd..c66690c50 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -3,10 +3,12 @@ from abc import ABC from typing import NamedTuple, Optional +import warnings import torch from torch import Tensor +import brevitas.config as config from brevitas.function.ops import max_int from brevitas.function.ops import min_int from brevitas.function.ops_ste import ceil_ste @@ -29,7 +31,7 @@ class QuantTensorBase(NamedTuple): def _unpack_quant_tensor(input_data): if isinstance(input_data, QuantTensor): - return input_data.tensor + return input_data.value elif isinstance(input_data, tuple): return tuple([_unpack_quant_tensor(v) for v in input_data]) elif isinstance(input_data, list): @@ -40,33 +42,22 @@ def _unpack_quant_tensor(input_data): return input_data -def _is_all_nested_not_none(input_data): - if isinstance(input_data, QuantTensor): - return input_data.is_not_none - elif isinstance(input_data, (tuple, list)): - return all([_is_all_nested_not_none(v) for v in input_data]) - elif isinstance(input_data, dict): - return all([_is_all_nested_not_none(v) for v in input_data.values()]) - else: - return True - - class QuantTensor(QuantTensorBase): - def __new__( - cls, value, scale=None, zero_point=None, bit_width=None, signed=None, training=None): + def __new__(cls, value, scale, zero_point, bit_width, signed, training): - if scale is not None and not isinstance(scale, torch.Tensor): + if not isinstance(scale, torch.Tensor): scale = torch.tensor(scale, dtype=torch.float) - if zero_point is not None and not isinstance(zero_point, torch.Tensor): + if not isinstance(zero_point, torch.Tensor): zero_point = torch.tensor(zero_point, dtype=torch.float) - if bit_width is not None and not isinstance(bit_width, torch.Tensor): + if not isinstance(bit_width, torch.Tensor): bit_width = torch.tensor(bit_width, dtype=torch.float) - if signed is not None and not isinstance(signed, torch.Tensor): + if not isinstance(signed, torch.Tensor): signed = torch.tensor(signed, dtype=torch.bool) - if training is not None and not isinstance(training, torch.Tensor): + if not isinstance(training, torch.Tensor): training = torch.tensor(training, dtype=torch.bool) - return super().__new__(cls, value, scale, zero_point, bit_width, signed, training) + quant_tensor = super().__new__(cls, value, scale, zero_point, bit_width, signed, training) + return quant_tensor @property def signed(self): @@ -86,8 +77,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} if (func not in QUANT_TENSOR_FN_HANDLER or - not all(issubclass(t, QuantTensor) for t in types) or - not (_is_all_nested_not_none(args) and _is_all_nested_not_none(kwargs))): + not all(issubclass(t, QuantTensor) for t in types)): args = _unpack_quant_tensor(args) kwargs = _unpack_quant_tensor(kwargs) return func(*args, **kwargs) @@ -97,12 +87,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None): def tensor(self): return self.value - @property - def is_not_none(self): - return ( - self.value is not None and self.scale is not None and self.zero_point is not None and - self.bit_width is not None and self.signed is not None) - @property def _pre_round_int_value(self): value = self.value @@ -118,30 +102,27 @@ def _pre_round_int_value(self): @property def is_valid(self): - if self.is_not_none: - with torch.no_grad(): - pre_round_int_value = self._pre_round_int_value - rounded_int_value = torch.round(pre_round_int_value) - max_abs_diff = torch.max(torch.abs(pre_round_int_value - rounded_int_value)) - atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL - is_int = max_abs_diff < atol - if self.bit_width >= 2: - if self.signed: - is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all() - is_lower_b = (-2.0 ** (self.bit_width - 1) <= rounded_int_value).all() - else: - is_upper_b = (2.0 ** self.bit_width - 1 >= rounded_int_value).all() - is_lower_b = (0. <= rounded_int_value).all() - return (is_int & is_upper_b & is_lower_b).item() - else: # binary case - unique_vals = rounded_int_value.unique( - sorted=False, return_counts=False, return_inverse=False) - is_binary = unique_vals.view(-1).size()[0] == 2 - is_signed = (unique_vals < 0.).any().item() - sign_match = is_signed == self.signed - return is_int.item() and is_binary and sign_match - else: - return False + with torch.no_grad(): + pre_round_int_value = self._pre_round_int_value + rounded_int_value = torch.round(pre_round_int_value) + max_abs_diff = torch.max(torch.abs(pre_round_int_value - rounded_int_value)) + atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL + is_int = max_abs_diff < atol + if self.bit_width >= 2: + if self.signed: + is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all() + is_lower_b = (-2.0 ** (self.bit_width - 1) <= rounded_int_value).all() + else: + is_upper_b = (2.0 ** self.bit_width - 1 >= rounded_int_value).all() + is_lower_b = (0. <= rounded_int_value).all() + return (is_int & is_upper_b & is_lower_b).item() + else: # binary case + unique_vals = rounded_int_value.unique( + sorted=False, return_counts=False, return_inverse=False) + is_binary = unique_vals.view(-1).size()[0] == 2 + is_signed = (unique_vals < 0.).any().item() + sign_match = is_signed == self.signed + return is_int.item() and is_binary and sign_match @property def device(self): @@ -166,18 +147,18 @@ def detach_(self): def detach(self): return QuantTensor( self.value.detach(), - self.scale.detach() if self.scale is not None else None, - self.zero_point.detach() if self.zero_point is not None else None, - self.bit_width.detach() if self.bit_width is not None else None, + self.scale.detach(), + self.zero_point.detach(), + self.bit_width.detach(), self.signed, self.training) def contiguous(self): return QuantTensor( self.value.contiguous(), - self.scale.contiguous() if self.scale is not None else None, - self.zero_point.contiguous() if self.zero_point is not None else None, - self.bit_width.contiguous() if self.bit_width is not None else None, + self.scale.contiguous(), + self.zero_point.contiguous(), + self.bit_width.contiguous(), self.signed, self.training) @@ -209,10 +190,7 @@ def check_input_type(tensor): @staticmethod def is_zero_zero_point(tensor): QuantTensor.check_input_type(tensor) - if tensor.zero_point is not None: - return (tensor.zero_point == 0.).all() - else: - return None + return (tensor.zero_point == 0.).all() def check_scaling_factors_same(self, other): if self.training is not None and self.training: @@ -282,7 +260,7 @@ def cat(tensors, dim, out=None): return tensors[0] else: first_qt = tensors[0] - if all([isinstance(qt, QuantTensor) and qt.is_not_none for qt in tensors]): + if all([isinstance(qt, QuantTensor) for qt in tensors]): for qt in tensors[1:]: first_qt.check_scaling_factors_same(qt) first_qt.check_zero_points_same(qt) @@ -337,32 +315,32 @@ def __neg__(self): def to(self, *args, **kwargs): return QuantTensor( self.value.to(*args, **kwargs), - self.scale.to(*args, **kwargs) if self.scale is not None else None, - self.zero_point.to(*args, **kwargs) if self.zero_point is not None else None, - self.bit_width.to(*args, **kwargs) if self.bit_width is not None else None, + self.scale.to(*args, **kwargs), + self.zero_point.to(*args, **kwargs), + self.bit_width.to(*args, **kwargs), self.signed, self.training) def cuda(self, *args, **kwargs): return QuantTensor( self.value.cuda(*args, **kwargs), - self.scale.cuda(*args, **kwargs) if self.scale is not None else None, - self.zero_point.cuda(*args, **kwargs) if self.zero_point is not None else None, - self.bit_width.cuda(*args, **kwargs) if self.bit_width is not None else None, + self.scale.cuda(*args, **kwargs), + self.zero_point.cuda(*args, **kwargs), + self.bit_width.cuda(*args, **kwargs), self.signed, self.training) def cpu(self, *args, **kwargs): return QuantTensor( self.value.cpu(*args, **kwargs), - self.scale.cpu(*args, **kwargs) if self.scale is not None else None, - self.zero_point.cpu(*args, **kwargs) if self.zero_point is not None else None, - self.bit_width.cpu(*args, **kwargs) if self.bit_width is not None else None, + self.scale.cpu(*args, **kwargs), + self.zero_point.cpu(*args, **kwargs), + self.bit_width.cpu(*args, **kwargs), self.signed, self.training) def __add__(self, other): - if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none: + if isinstance(other, QuantTensor): self.check_scaling_factors_same(other) output_value = self.value + other.value output_scale = (self.scale + other.scale) / 2 @@ -394,7 +372,7 @@ def __rmul__(self, other): return self.__mul__(other) def __mul__(self, other): - if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none: + if isinstance(other, QuantTensor): output_value = self.value * other.value output_scale = self.scale * other.scale output_bit_width = self.bit_width + other.bit_width @@ -420,8 +398,11 @@ def __mul__(self, other): def __sub__(self, other): return self.__add__(-other) + def __str__(self): + return f"QuantTensor(value={self.value}, scale={self.scale}, zero_point={self.zero_point}, bit_width={self.bit_width}, signed_t={self.signed_t}, training_t={self.training_t})" + def __truediv__(self, other): - if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none: + if isinstance(other, QuantTensor): output_tensor = self.value / other.value # Note, output tensor not guaranteed to pass self.is_valid() max_int_denominator = 2 ** (other.bit_width - int(other.signed)) output_scale = self.scale / (other.scale * max_int_denominator) diff --git a/src/brevitas/quant_tensor/torch_handler.py b/src/brevitas/quant_tensor/torch_handler.py index 3b64bca89..1b6e43a37 100644 --- a/src/brevitas/quant_tensor/torch_handler.py +++ b/src/brevitas/quant_tensor/torch_handler.py @@ -23,10 +23,7 @@ def decorator(func): def quant_invariant_handler(fn, inp, *args, **kwargs): out_value = fn(inp.value, *args, **kwargs) - if inp.is_not_none: - return inp.set(value=out_value) - else: - return out_value + return inp.set(value=out_value) @implements(torch.flatten) diff --git a/src/brevitas_examples/text_to_speech/melgan/res_stack_brevitas.py b/src/brevitas_examples/text_to_speech/melgan/res_stack_brevitas.py index 07dcfc8f4..6356cfea3 100644 --- a/src/brevitas_examples/text_to_speech/melgan/res_stack_brevitas.py +++ b/src/brevitas_examples/text_to_speech/melgan/res_stack_brevitas.py @@ -30,6 +30,7 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.""" +from brevitas.quant_tensor import _unpack_quant_tensor from brevitas.quant_tensor import QuantTensor from .common import * @@ -68,13 +69,13 @@ def forward(self, x): for layer in self.layers: x = self.scale_norm(x) if isinstance(x, QuantTensor): - x_unp, _, _ = x + x_unp = _unpack_quant_tensor(x) else: x_unp = x x_layer = self.scale_norm(layer(x_unp)) if isinstance(x_layer, QuantTensor): - x_layer_unp, _, _ = x_layer + x_layer_unp = _unpack_quant_tensor(x_layer) else: x_layer_unp = x_layer @@ -84,7 +85,7 @@ def forward(self, x): x = x + x_layer if isinstance(x, QuantTensor): - x, _, _ = x + x = _unpack_quant_tensor(x) return x diff --git a/tests/brevitas/fx/test_tracer.py b/tests/brevitas/fx/test_tracer.py index be5698d2c..d7ef1d3ac 100644 --- a/tests/brevitas/fx/test_tracer.py +++ b/tests/brevitas/fx/test_tracer.py @@ -232,8 +232,8 @@ def test_module(module): @pytest.mark.parametrize('module', QUANT_TENSOR_MODULES) def test_quant_module(module): mod = module() - x = QuantTensor(torch.randn(INPUT_SIZE)) - x_trace = QuantTensor(torch.randn(INPUT_SIZE)) + x = torch.randn(INPUT_SIZE) + x_trace = torch.randn(INPUT_SIZE) with torch.no_grad(): out = mod(x) graph_model = value_trace(mod, value_args={'x': x_trace}) diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index 538e836e8..98a4a0d7b 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -384,6 +384,8 @@ def case_quant_lstm_full( if return_quant_tensor and io_quantizer is None: pytest.skip("return_quant_tensor cannot be True if no io_quantizer is specified") + if return_quant_tensor and signed_act_quantizer is None: + pytest.skip("return_quant_tensor cannot be True if no cell_state_quant is specified") class Model(nn.Module): diff --git a/tests/brevitas/nn/test_linear.py b/tests/brevitas/nn/test_linear.py index 457b66e20..b9690ce63 100644 --- a/tests/brevitas/nn/test_linear.py +++ b/tests/brevitas/nn/test_linear.py @@ -56,5 +56,7 @@ def test_forward_bias_int(self): torch.rand(size=(3, INPUT_FEATURES)), torch.tensor(1.0), torch.tensor(0.0), - torch.tensor(3)) + torch.tensor(3), + signed=True, + training=False) assert mod(x) is not None diff --git a/tests/brevitas/nn/test_nn_quantizers.py b/tests/brevitas/nn/test_nn_quantizers.py index 55296ff35..bbee8daca 100644 --- a/tests/brevitas/nn/test_nn_quantizers.py +++ b/tests/brevitas/nn/test_nn_quantizers.py @@ -41,8 +41,13 @@ def test_quant_wbiol(model_input, current_cases): is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized'] - if (not is_input_quanttensor or - kwargs['weight_quant'] is None) and kwargs['bias_quant'] == 'quant_external': + if (not (is_input_quanttensor and kwargs['weight_quant'] is not None) and + kwargs['io_quant'] is None) and kwargs['return_quant_tensor']: + with pytest.raises(RuntimeError, match='QuantLayer is not correctly configured'): + output = model(input) + return + elif (not is_input_quanttensor or + kwargs['weight_quant'] is None) and kwargs['bias_quant'] == 'quant_external': with pytest.raises(RuntimeError, match='Input scale required'): output = model(input) return @@ -57,14 +62,6 @@ def test_quant_wbiol(model_input, current_cases): if kwargs['return_quant_tensor']: assert isinstance(output, QuantTensor) - # Empty QuantTensor - if ( not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and \ - kwargs['io_quant'] is None: - assert output.scale is None - assert output.bit_width is None - else: # "Full" QuantTensor - assert output.scale is not None - assert output.bit_width is not None else: assert isinstance(output, torch.Tensor) @@ -72,7 +69,6 @@ def test_quant_wbiol(model_input, current_cases): @pytest_cases.parametrize_with_cases( 'model_input', cases=[case_quant_lstm_full, case_quant_rnn_full]) def test_quant_lstm_rnn_full(model_input, current_cases): - model, input = model_input cases_generator_func = current_cases['model_input'][1] case_id = get_case_id(cases_generator_func) @@ -80,7 +76,9 @@ def test_quant_lstm_rnn_full(model_input, current_cases): kwargs = parse_args(args) is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized'] + return_quant_tensor = kwargs['return_quant_tensor'] + model, input = model_input if (kwargs['bias_quant'] == 'quant_external') and ( \ (not is_input_quanttensor or kwargs['weight_quant'] is None) or \ (kwargs['num_layers']> 1 and (kwargs['weight_quant'] is None or kwargs['io_quant'] is None))): @@ -98,41 +96,23 @@ def test_quant_lstm_rnn_full(model_input, current_cases): else: output, h = output c = None - return_quant_tensor = kwargs['return_quant_tensor'] if return_quant_tensor: assert isinstance(output, QuantTensor) - # Empty QuantTensor - if ( not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and \ - kwargs['io_quant'] is None: - assert output.scale is None - assert output.bit_width is None - else: # "Full" QuantTensor - assert output.scale is not None - assert output.bit_width is not None else: assert isinstance(output, torch.Tensor) if h is not None: - if return_quant_tensor and kwargs['io_quant'] is not None: + if return_quant_tensor: assert isinstance(h, QuantTensor) else: assert isinstance(h, torch.Tensor) if c is not None: - if kwargs['signed_act'] is None or not kwargs['return_quant_tensor']: - if not kwargs['bidirectional']: - if not kwargs['return_quant_tensor']: - assert isinstance(c, torch.Tensor) - elif kwargs['return_quant_tensor'] and kwargs['signed_act'] is None and kwargs[ - 'num_layers'] == 2: - assert isinstance(c, torch.Tensor) - else: - assert isinstance(c, QuantTensor) - else: - assert isinstance(c, torch.Tensor) - else: + if return_quant_tensor: assert isinstance(c, QuantTensor) + else: + assert isinstance(c, torch.Tensor) @pytest_cases.parametrize_with_cases('model_input', cases=[case_quant_lstm, case_quant_rnn]) @@ -163,29 +143,21 @@ def test_quant_lstm_rnn(model_input, current_cases): else: output, h = output c = None - return_quant_tensor = kwargs['return_quant_tensor'] and kwargs['io_quant'] is not None + return_quant_tensor = kwargs['return_quant_tensor'] if return_quant_tensor: assert isinstance(output, QuantTensor) - # Empty QuantTensor - if ( not kwargs['input_quantized'] or kwargs['weight_quant'] is None) and \ - kwargs['io_quant'] is None: - assert output.scale is None - assert output.bit_width is None - else: # "Full" QuantTensor - assert output.scale is not None - assert output.bit_width is not None else: assert isinstance(output, torch.Tensor) if h is not None: - if return_quant_tensor and kwargs['io_quant'] is not None: + if return_quant_tensor: assert isinstance(h, QuantTensor) else: assert isinstance(h, torch.Tensor) if c is not None: - if return_quant_tensor and kwargs['io_quant'] is not None: + if return_quant_tensor: assert isinstance(c, QuantTensor) else: assert isinstance(c, torch.Tensor) @@ -199,23 +171,30 @@ def test_quant_mha(model_input, current_cases): case_id = get_case_id(cases_generator_func) args = case_id.split('-')[1:] # Exclude first argument kwargs = parse_args(args) - - if (kwargs['io_quant'] is None or + is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized'] + if (not is_input_quanttensor or kwargs['weight_quant'] is None) and kwargs['bias_quant'] == 'quant_external': with pytest.raises(RuntimeError, match='Input scale required'): output, _ = model(inp, inp, inp) return - + elif kwargs['io_quant'] is None and kwargs['return_quant_tensor']: + with pytest.raises(RuntimeError, match='QuantLayer is not correctly configured'): + output, _ = model(inp, inp, inp) + return + elif kwargs['io_quant'] is None and kwargs['bias_quant'] == 'quant_external': + with pytest.raises(RuntimeError, match='Input scale required'): + output, _ = model(inp, inp, inp) + return + elif kwargs['weight_quant'] is not None and kwargs['io_quant'] is None: + if kwargs['weight_quant'] == 'quant_asym' and kwargs['return_quant_tensor']: + with pytest.raises( + RuntimeError, + match='Computing zero point of output accumulator not supported yet.'): + output, _ = model(inp, inp, inp) + return output, _ = model(inp, inp, inp) if kwargs['return_quant_tensor']: assert isinstance(output, QuantTensor) - # Empty QuantTensor - if kwargs['io_quant'] is None: - assert output.scale is None - assert output.bit_width is None - else: # "Full" QuantTensor - assert output.scale is not None - assert output.bit_width is not None else: assert isinstance(output, torch.Tensor) diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index a7f87cbef..fa324f0be 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -27,6 +27,7 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat +from brevitas.quant_tensor import QuantTensor from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat SEED = 123456 @@ -116,9 +117,15 @@ def is_brevitas_ort_close( input_t = torch.from_numpy(np_input) with torch.no_grad(): brevitas_output = model(input_t) + if isinstance(brevitas_output, QuantTensor): + computed_out = brevitas_output.value + scale = brevitas_output.scale + else: + computed_out = brevitas_output + scale = 1. if tolerance is not None and export_type == 'qcdq': - tolerance = tolerance * brevitas_output.scale # Float Output, tolerance is +/- output scale + tolerance = tolerance * scale # Float Output, tolerance is +/- output scale if export_type == 'qonnx': exported_model = export_qonnx(model, input_t, export_path=export_name) @@ -130,7 +137,7 @@ def is_brevitas_ort_close( else: if export_type == 'qop': export_onnx_qop(model, input_t, export_path=export_name) - brevitas_output = brevitas_output.int(float_datatype=False) + computed_out = brevitas_output.int(float_datatype=False) elif export_type == 'qcdq': export_onnx_qcdq(model, input_t, export_path=export_name) elif export_type == 'qcdq_opset14': @@ -145,13 +152,13 @@ def is_brevitas_ort_close( if first_output_only: if isinstance(ort_output, (tuple, list)): ort_output = ort_output[0] - if isinstance(brevitas_output, tuple): - brevitas_output = brevitas_output[0] + if isinstance(computed_out, tuple): + computed_out = computed_out[0] # make sure we are not comparing 0s - if (ort_output == 0).all() and (brevitas_output == 0).all(): + if (ort_output == 0).all() and (computed_out == 0).all(): pytest.skip("Skip testing against all 0s.") - return recursive_allclose(ort_output, brevitas_output, tolerance) + return recursive_allclose(ort_output, computed_out, tolerance) def gen_linspaced_data(num_samples, min_val=-1.0, max_val=1.0):