-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat: Support for Groupwise (MX) quantization #971
Conversation
6c7ce63
to
ee3fbdd
Compare
3e8e2ca
to
69e5e98
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe worth double-checking all of the <Datatype>QuantTensor
instantiations where the order of the arguments really matter. I found a few errors that are likely due to some original issue + copy/paste, so they may also exist outside your specific changes.
I think it's worth checking double-checking as many as you can while you're fixing the ones I found.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
8f26c84
to
8ed14c6
Compare
This implements:
Missing:
Features
Dynamic expansion of contracted groupwise tensors
Compared to Int/Float QuantTensor, the main difference of their groupwise equivalent is that
value
,scale
, andzero_point
are not direct attributes anymore but properties. The new attributes arevalue_
,scale_
, andzero_point_
.The reason for this is shaping. When quantizing a tensor with shapes [O, I], where O is output channel and I is input channel, with groupsize k, groupwise quantization is normally represented as follow:
The alternative to this representation is to have all three tensors with shapes [O,I], with a massive increase in memory utilization, especially with QAT + gradients.
The underscored attributes will have the compressed shapes, while the properties (non-underscored naming) will dynamically compute the expanded version of the property. This means:
Internally, the quantization will happen with groupwise shaping (e.g., contracted). For this reason, there is a preliminary view applied to the tensors before everything goes in
tensor_quant
.Deprecation of scaling_per_output_channel
Another important change of this PR is the
deprecation
(i.e., still usable but not recommended anymore) of the flag scaling_per_output_channel, in favor of a ternary flagscaling_per_output
that can be TENSOR/CHANNEL/GROUP.Lots of work has gone into maintaining retro-compatibility with the existing binary flag, so that everything will still work as intended.
One thing that is still up for discussion is how to handle shared quantizers for groupwise quantization.
Automatic OCP definition
When instantiating a OCP FP quantizer, NaN/INF encoding are automatically defined through dep inj based on the bitwidths. This is to avoid to have manually to define all possible minifloat quantizers needed for MX (e.g. MX FP8, FP6, etc.)
Groupwise Default quantizers
Lots of possible quantizers could be defined by default. For integer (not much of a problem) we have groupwise with float scaling (defined in examples if I'm not mistaken) and MX (groupwise with Po2 scale) defined in the brevitas source.
For float it gets a bit more complicated with inf/nan. The solution adopted is the following:
I only created quantizers with bitwidth equal to 8, e4m3. Overriding the bit_width, mantissa_bit_width and exponent_bit_width will produce OCP compliant MX quantizers (thanks to the change in Automatic OCP definition)
Example changes
No longer separated flags for ocp/fnuz standard. Now the user should pass float_e4m3 for general fp8 no standard, float_ocp_e4m3 for OCP, float_fnuz_e4m3 for FNUZ.