Add tutorial examples of per-channel quantization #867
Conversation
The explanation is correct. I would just expand a bit on the inner workings behind it.
Force-pushed from 80f5cec to 4837b17
@Giuseppe5 I have added some details, let me know if I should add more! A lot of the permutation code relies on boilerplate, and I wasn't sure how detailed a trace of it I should include in the tutorial. I have all my notes, so I can get quite granular if we want to.
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We can see that the per-tensor scale parameter has calibrated itself to provide a full quantization range of 3, matching that of the most extreme channel. \n", |
Most extreme -> highest? If I'm understanding the meaning correctly
Yep, I'll make it more clear!
I changed it to indicate the channel with the largest dynamic range, instead of the most "extreme" one.
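For readers following along, here is a minimal sketch of the per-tensor behaviour being discussed, assuming `QuantReLU` exposes the usual `quant_act_scale()` helper; the input values and bit width are invented for illustration:

```python
import torch
from brevitas.nn import QuantReLU

# Hypothetical input with two channels of very different dynamic range:
# channel 0 spans [0, 3], channel 1 only spans [0, 0.5].
x = torch.stack([torch.linspace(0.0, 3.0, 16),
                 torch.linspace(0.0, 0.5, 16)]).view(1, 2, 4, 4)

per_tensor_relu = QuantReLU(bit_width=8, return_quant_tensor=True)
per_tensor_relu.train()   # scale statistics are collected in training mode
_ = per_tensor_relu(x)
per_tensor_relu.eval()

# A single scale for the whole tensor, driven by the widest-range channel,
# so channel 1 only ever occupies a small slice of the 8-bit grid.
print(per_tensor_relu.quant_act_scale())
```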
"source": [ | ||
"Next, we initialise a new `QuantRelU` instance, but this time we specify that we desire per-channel quantization i.e. `scaling_per_output_channel=True`. This will implictly call `scaling_stats_input_view_shape_impl`, defined [here](https://github.com/Xilinx/brevitas/blob/200456825f3b4b8db414f2b25b64311f82d3991a/src/brevitas/quant/solver/common.py#L184), and will change the `QuantReLU` from using a per-tensor view when gathering stats to a per output channel view ([`OverOutputChannelView`](https://github.com/Xilinx/brevitas/blob/200456825f3b4b8db414f2b25b64311f82d3991a/src/brevitas/core/function_wrapper/shape.py#L52)). This simply permutes the tensor into a 2D tensor, with dim 0 equal to the number of output channels.\n", | ||
"\n", | ||
"To accomplish this, we also need to give it some extra information: `scaling_stats_permute_dims` and `per_channel_broadcastable_shape`. `scaling_stats_permute_dims` is responsible for defining how we do the permutation. `per_channel_broadcastable_shape` simply represents what the dimensions of the quantization parameters will be, i.e. there should be one parameter per output channel.\n", |
`per_channel_broadcastable_shape` is necessary to understand along which dimensions the scale factor has to be broadcast, so that the scale factor values are applied along the channel dimension of the input.
By default, PyTorch aligns tensor shapes starting from the rightmost dimension when broadcasting. To make sure we apply the scale factor along our desired dimension, we need to tell PyTorch how to broadcast the scale factors correctly: the scale factor has to have as many dimensions as the input tensor, with all sizes equal to 1 apart from the channel dimension.
Awesome, will add more context!
I took your description, as it was a very good explanation haha, and added it to the PR!
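To make the broadcasting point concrete, here is a small PyTorch-only sketch; the shapes are invented for illustration:

```python
import torch

x = torch.randn(8, 4, 32, 32)   # (batch, channels, height, width)
scale = torch.rand(4)           # one scale value per channel

# PyTorch aligns shapes from the right, so a (4,)-shaped scale would be
# matched against the width dimension, not the channel dimension (here it
# would simply error since 32 != 4; with width == 4 it would silently
# broadcast along the wrong axis). Giving the scale as many dimensions as
# the input, with size 1 everywhere except the channel dimension, forces
# the intended broadcast:
scale = scale.view(1, 4, 1, 1)  # cf. per_channel_broadcastable_shape
y = x / scale                   # each channel i is divided by scale[i]

# Similarly, scaling_stats_permute_dims=(1, 0, 2, 3) moves the channel
# dimension to the front before the stats view flattens the rest into 2D:
x_stats = x.permute(1, 0, 2, 3).reshape(4, -1)  # one row of stats per channel
```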
Could you please rebase this over dev? Thanks!
I'm unusually busy until Wednesday, but will rebase and adjust the PR along the lines of the comments then!
Force-pushed (… per-channel quantization) from 9208318 to 44b0bc1
Purpose of this PR
To expand the QuantAct notebook with scaling-per-output-channel examples, as referenced in issue #862.
Changes made in this PR
We add the example of a `QuantReLU`, and show how per-channel quantization can enable one to better match the dynamic range of each individual channel.

Note: I am not sure my explanations of `per_channel_broadcastable_shape` and `scaling_stats_permute_dims` are correct.
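For context, the kind of per-channel example described above looks roughly like the following sketch; the argument values are illustrative and not taken verbatim from the notebook:

```python
import torch
from brevitas.nn import QuantReLU

OUT_CH = 4

per_channel_relu = QuantReLU(
    bit_width=8,
    scaling_per_output_channel=True,                    # one scale per channel
    per_channel_broadcastable_shape=(1, OUT_CH, 1, 1),  # scale broadcast shape
    scaling_stats_permute_dims=(1, 0, 2, 3),            # channel dim first for stats
    return_quant_tensor=True)

x = torch.randn(8, OUT_CH, 32, 32)
per_channel_relu.train()
out = per_channel_relu(x)
# Expected: one scale per output channel, e.g. torch.Size([1, 4, 1, 1])
print(out.scale.shape)
```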