diff --git a/notebooks/02_quant_activation_overview.ipynb b/notebooks/02_quant_activation_overview.ipynb index 4f58acc8d..2c41cdd94 100644 --- a/notebooks/02_quant_activation_overview.ipynb +++ b/notebooks/02_quant_activation_overview.ipynb @@ -648,7 +648,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ @@ -664,20 +664,9 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 31, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "out1_train = quant_hard_tanh(inp1)\n", "quant_hard_tanh.eval()\n", @@ -694,7 +683,7 @@ }, { "cell_type": "code", - "execution_count": 161, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -703,7 +692,7 @@ "tensor(2.9998, grad_fn=)" ] }, - "execution_count": 161, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -722,29 +711,136 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can see that the per-tensor scale parameter has calibrated itself to provide a full quantization range of 3, matching that of the most extreme channel.\n", + "We can see that the per-tensor scale parameter has calibrated itself to provide a full quantization range of 3, matching that of the most extreme channel. \n", "\n", - "Next, we initialise a new `QuantRelU` instance, but this time we specify that we desire per-channel quantization i.e. `scaling_per_output_channel=True`. To accomplish this, we also need to give it some extra information on the dimensions of the inputted tensor, so that it knows which dimensions to interpret as the output channels. This is done via the `per_channel_broadcastable_shape` and `scaling_stats_permute_dims` attributes. \n", + "We can take a look at the `QuantReLU` object, and in particular look at what the ` scaling_impl` object is composed of. It is responsible for gathering statistics for determining the quantization parameters, and we can see that its ` stats_input_view_shape_impl` attribute is set to be an instance of `OverTensorView`. This is defined [here](https://github.com/Xilinx/brevitas/blob/200456825f3b4b8db414f2b25b64311f82d3991a/src/brevitas/core/function_wrapper/shape.py#L78), and serves to flatten out the observed tensor into a 1D tensor and, in this case, use the `AbsPercentile` observer to calculate the quantization parameters during the gathering statistics stage of QAT." + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "QuantReLU(\n", + " (input_quant): ActQuantProxyFromInjector(\n", + " (_zero_hw_sentinel): StatelessBuffer()\n", + " )\n", + " (act_quant): ActQuantProxyFromInjector(\n", + " (_zero_hw_sentinel): StatelessBuffer()\n", + " (fused_activation_quant_proxy): FusedActivationQuantProxy(\n", + " (activation_impl): ReLU()\n", + " (tensor_quant): RescalingIntQuant(\n", + " (int_quant): IntQuant(\n", + " (float_to_int_impl): RoundSte()\n", + " (tensor_clamp_impl): TensorClamp()\n", + " (delay_wrapper): DelayWrapper(\n", + " (delay_impl): _NoDelay()\n", + " )\n", + " )\n", + " (scaling_impl): ParameterFromRuntimeStatsScaling(\n", + " (stats_input_view_shape_impl): OverTensorView()\n", + " (stats): _Stats(\n", + " (stats_impl): AbsPercentile()\n", + " )\n", + " (restrict_scaling): _RestrictValue(\n", + " (restrict_value_impl): FloatRestrictValue()\n", + " )\n", + " (clamp_scaling): _ClampValue(\n", + " (clamp_min_ste): ScalarClampMinSte()\n", + " )\n", + " (restrict_inplace_preprocess): Identity()\n", + " (restrict_preprocess): Identity()\n", + " )\n", + " (int_scaling_impl): IntScaling()\n", + " (zero_point_impl): ZeroZeroPoint(\n", + " (zero_point): StatelessBuffer()\n", + " )\n", + " (msb_clamp_bit_width_impl): BitWidthConst(\n", + " (bit_width): StatelessBuffer()\n", + " )\n", + " )\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "per_tensor_quant_relu" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we initialise a new `QuantRelU` instance, but this time we specify that we desire per-channel quantization i.e. `scaling_per_output_channel=True`. This will implictly call `scaling_stats_input_view_shape_impl`, defined [here](https://github.com/Xilinx/brevitas/blob/200456825f3b4b8db414f2b25b64311f82d3991a/src/brevitas/quant/solver/common.py#L184), and will change the `QuantReLU` from using a per-tensor view when gathering stats to a per output channel view ([`OverOutputChannelView`](https://github.com/Xilinx/brevitas/blob/200456825f3b4b8db414f2b25b64311f82d3991a/src/brevitas/core/function_wrapper/shape.py#L52)). This simply permutes the tensor into a 2D tensor, with dim 0 equal to the number of output channels.\n", "\n", - "`per_channel_broadcastable_shape` represents what the dimensions of the quantization parameters will be, and should be laid out to match those of the output channels of the outputted tensor. We also need to specify the permutation dimensions via `scaling_stats_permute_dims` so as to shape the tensor into a standard format of output channels first. This is so that during the statistics gathering stage of QAT the correct stats will be gathered." + "To accomplish this, we also need to give it some extra information: `scaling_stats_permute_dims` and `per_channel_broadcastable_shape`. `scaling_stats_permute_dims` is responsible for defining how we do the permutation. `per_channel_broadcastable_shape` simply represents what the dimensions of the quantization parameters will be, i.e. there should be one parameter per output channel.\n", + "\n", + "Below, we can see that in the per-channel ` QuantReLU` instance, the `stats_input_view_shape_impl` is now ` OverOutputChannelView`, and uses a `PermuteDims` [instance](https://github.com/Xilinx/brevitas/blob/200456825f3b4b8db414f2b25b64311f82d3991a/src/brevitas/core/function_wrapper/shape.py#L21) to do the permutation of the tensor to, in this case, a 2D tensor. " ] }, { "cell_type": "code", - "execution_count": 160, + "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[[2.9999]],\n", - "\n", - " [[1.0000]],\n", - "\n", - " [[1.0000]]]], grad_fn=)" + "QuantReLU(\n", + " (input_quant): ActQuantProxyFromInjector(\n", + " (_zero_hw_sentinel): StatelessBuffer()\n", + " )\n", + " (act_quant): ActQuantProxyFromInjector(\n", + " (_zero_hw_sentinel): StatelessBuffer()\n", + " (fused_activation_quant_proxy): FusedActivationQuantProxy(\n", + " (activation_impl): ReLU()\n", + " (tensor_quant): RescalingIntQuant(\n", + " (int_quant): IntQuant(\n", + " (float_to_int_impl): RoundSte()\n", + " (tensor_clamp_impl): TensorClamp()\n", + " (delay_wrapper): DelayWrapper(\n", + " (delay_impl): _NoDelay()\n", + " )\n", + " )\n", + " (scaling_impl): ParameterFromRuntimeStatsScaling(\n", + " (stats_input_view_shape_impl): OverOutputChannelView(\n", + " (permute_impl): PermuteDims()\n", + " )\n", + " (stats): _Stats(\n", + " (stats_impl): AbsPercentile()\n", + " )\n", + " (restrict_scaling): _RestrictValue(\n", + " (restrict_value_impl): FloatRestrictValue()\n", + " )\n", + " (clamp_scaling): _ClampValue(\n", + " (clamp_min_ste): ScalarClampMinSte()\n", + " )\n", + " (restrict_inplace_preprocess): Identity()\n", + " (restrict_preprocess): Identity()\n", + " )\n", + " (int_scaling_impl): IntScaling()\n", + " (zero_point_impl): ZeroZeroPoint(\n", + " (zero_point): StatelessBuffer()\n", + " )\n", + " (msb_clamp_bit_width_impl): BitWidthConst(\n", + " (bit_width): StatelessBuffer()\n", + " )\n", + " )\n", + " )\n", + " )\n", + ")" ] }, - "execution_count": 160, + "execution_count": 35, "metadata": {}, "output_type": "execute_result" } @@ -755,6 +851,37 @@ " per_channel_broadcastable_shape=(1, out_channels, 1 , 1),\n", " scaling_stats_permute_dims=(1, 0, 2, 3),\n", " )\n", + "per_chan_quant_relu" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also observe the effect on the quantization parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[[2.9999]],\n", + "\n", + " [[1.0000]],\n", + "\n", + " [[1.0000]]]], grad_fn=)" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ "out_channel = per_chan_quant_relu(inp3)\n", "out_channel.scale * ((2**8) -1)" ] @@ -763,7 +890,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Above, we can see that the number of elements in the quantization scale of the outputted tensor is now 3, matching those of the 3-channel tensor! Furthermore, we see that each channel has an 8-bit quantization range that matches its data distribution, which is much more ideal in terms of reducing quantization mismatch. However, it's important to note that some hardware providers don't efficiently support per-channel quantization in production, so it's best to check if your targetted hardware will allow per-channel quantization." + "We can see that the number of elements in the quantization scale of the outputted tensor is now 3, matching those of the 3-channel tensor! Furthermore, we see that each channel has an 8-bit quantization range that matches its data distribution, which is much more ideal in terms of reducing quantization mismatch. However, it's important to note that some hardware providers don't efficiently support per-channel quantization in production, so it's best to check if your targetted hardware will allow per-channel quantization." ] }, {