
Added some preliminary unit tests to the CNNs 'quantize_model' #927

Merged: 9 commits into Xilinx:dev on Jul 29, 2024

Conversation

@OscarSavolainenDR (Contributor)

This PR is a work in progress, and I expect to add more tests.

As of commit 66e029b, we test some aspects of layerwise and FX quantization, as well as some invalid inputs, e.g. invalid strings and zero- or negative-valued bit widths.
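For context, this is roughly the shape of the invalid-input tests being described (a minimal sketch: the stand-in model, the choice of backend, and the expected exception type are assumptions, and the keyword arguments mirror the quantize_model call shown later in this thread):

    import pytest
    import torch.nn as nn

    from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model


    @pytest.mark.parametrize("bit_width", [0, -1])
    def test_invalid_bit_widths_are_rejected(bit_width):
        # Hypothetical stand-in model; the PR uses its own fixtures.
        model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())
        with pytest.raises(Exception):  # concrete exception type is an assumption
            quantize_model(
                model=model,
                backend="layerwise",
                weight_bit_width=bit_width,
                act_bit_width=8,
                bias_bit_width=32,
                weight_quant_granularity="per_tensor",
                act_quant_percentile=99.9,
                act_quant_type="sym",
                scale_factor_type="float_scale",
                quant_format="int")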

@OscarSavolainenDR changed the title from "Added some prelininary unit tests to the CNNs 'quantize_model'" to "Added some preliminary unit tests to the CNNs 'quantize_model'" on Mar 30, 2024
tests/brevitas_examples/test_quantize_model.py: 3 review threads (outdated, resolved)
@Giuseppe5 (Collaborator)

#934

I missed this comment on the other PR. The bias bitwidth can be None, which means that the bias is not quantized. Leaving this here if it could be useful for some tests.
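A hedged sketch of how that could be folded into a test (the minimal_model fixture name comes from this PR's test file, the backend choice and the other keyword arguments follow the quantize_model call shown later in this thread, and the assertion is deliberately weak):

    import pytest


    @pytest.mark.parametrize("bias_bit_width", [32, 16, None])
    def test_bias_bit_width_accepts_none(minimal_model, bias_bit_width):
        # Assumption: bias_bit_width=None should be accepted and mean
        # "leave the bias unquantized", per the comment above.
        quant_model = quantize_model(
            model=minimal_model,
            backend="layerwise",
            weight_bit_width=8,
            act_bit_width=8,
            bias_bit_width=bias_bit_width,
            weight_quant_granularity="per_tensor",
            act_quant_percentile=99.9,
            act_quant_type="sym",
            scale_factor_type="float_scale",
            quant_format="int")
        assert quant_model is not None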

@OscarSavolainenDR (Contributor Author)

> #934
>
> I missed this comment on the other PR. The bias bitwidth can be None, which means that the bias is not quantized. Leaving this here if it could be useful for some tests.

I'll incorporate it!

from brevitas.nn import QuantReLU
from brevitas.quant_tensor import QuantTensor
from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model

Contributor Author

Remaining ToDos, but we can add to it!

Collaborator

Save for the missing minifloat tests (see below), I think we can add the other two todos and review/merge this.

Great work!

Contributor Author

Sure, shall do!

)


def test_layerwise_valid_minifloat_bit_widths(minimal_model):
Contributor Author

This PR is getting there, but still need to work on this. I am testing if my explicit implementation of minifloat quantization matches the under-the-hood Brevitas one, but could use some guidance on whether this is correctly implemented. I'll be hacking at it either way!
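For reference, this is the general shape of an explicit minifloat quantizer one might cross-check against (a rough sketch with simplified subnormal and saturation handling, assuming the all-ones exponent encodes normal values; parameter names and the default bias are assumptions, not Brevitas's implementation):

    import torch


    def minifloat_quantize_reference(x, exponent_bit_width, mantissa_bit_width, exponent_bias=None):
        # Rough reference quantizer: no NaN/inf handling, and rounding near the
        # top of the range is not re-saturated.
        if exponent_bias is None:
            exponent_bias = 2 ** (exponent_bit_width - 1) - 1
        max_exponent = 2 ** exponent_bit_width - 1 - exponent_bias
        max_mantissa = 2.0 - 2.0 ** (-mantissa_bit_width)
        max_val = max_mantissa * 2.0 ** max_exponent
        x = torch.clamp(x, -max_val, max_val)
        # Per-element exponent, clamped at the subnormal threshold.
        exponent = torch.floor(torch.log2(torch.abs(x).clamp_min(1e-38)))
        exponent = torch.clamp(exponent, min=float(1 - exponent_bias))
        # Round the mantissa to the bits available at that exponent.
        scale = 2.0 ** (exponent - mantissa_bit_width)
        return torch.round(x / scale) * scale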

Collaborator

If it can be helpful, there are two PRs (#922 and #919) where we are expanding support for minifloat, to match the level of support we have for integer quantization. This means that a minifloat QuantTensor will have the correct metadata to properly characterize it, and I think that could be helpful when writing the tests.

If you agree, I'm happy to leave the minifloat tests to another PR after those two have been merged, so as not to block this one.

Contributor Author

That sounds good!

@@ -557,5 +557,8 @@ def check_positive_int(*args):
We check that every inputted value is positive, and an integer.
"""
for arg in args:
if not arg:
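One plausible completion of that check, for context (a guess at the intent: None passes through, following the bias-bit-width discussion above, while zero, negative, and non-integer values fail; the actual code here and in #934 may differ):

    def check_positive_int(*args):
        """
        We check that every inputted value is positive, and an integer.
        """
        for arg in args:
            if arg is None:
                # Assumed: None is allowed, e.g. bias_bit_width=None means the
                # bias is left unquantized.
                continue
            assert isinstance(arg, int) and arg > 0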
Contributor Author
Redundant with PR #934; included here to make the tests pass.

Collaborator

The PR has been merged so you can just rebase now

@OscarSavolainenDR (Contributor Author), Apr 23, 2024

Cool, have rebased!

@OscarSavolainenDR (Contributor Author)

I think all of the ToDos (for this PR) are done, subject to whatever changes are desired!

The two tests I added in the latest commit:

  • check that the percentiles in stats calibration work as expected, and that the quantization range becomes what we expect it to be;
  • check that the MSE calibration method minimizes MSE (I perturbed the qparams slightly, and none of the perturbed values registered a smaller MSE than the original ones); the idea is sketched below.
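The second check amounts to something like the following (a sketch of the perturbation logic only; the helper name and the quant/dequant formula are illustrative assumptions, not the PR's actual test code):

    import torch


    def assert_scale_is_local_mse_minimum(weight, scale, zero_point, bit_width=8):
        def quant_dequant(x, s):
            n = 2 ** (bit_width - 1)
            q = torch.clamp(torch.round(x / s + zero_point), -n, n - 1)
            return (q - zero_point) * s

        base_mse = torch.mean((weight - quant_dequant(weight, scale)) ** 2)
        # Slightly perturb the calibrated scale; none of the perturbed scales
        # should achieve a lower MSE than the one MSE calibration found.
        for factor in (0.9, 0.95, 1.05, 1.1):
            perturbed_mse = torch.mean((weight - quant_dequant(weight, scale * factor)) ** 2)
            assert perturbed_mse >= base_mse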

@OscarSavolainenDR (Contributor Author)

Some nox checks are failing because some Python/PyTorch versions throw an error, "Input scale required", when I try to feed data through the FX-quantized model.

Some debugging:
If I use those versions (e.g. Python 3.8, PyTorch 1.9.1), it fails specifically at this line:

_0 = getattr(self, "0")(input_1);  input_1 = None  

inside the Graph Mode forward call.

getattr(self, "0") returns:

QuantConv2d(
  10, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
  (input_quant): ActQuantProxyFromInjector(
    (_zero_hw_sentinel): StatelessBuffer()
  )
  (output_quant): ActQuantProxyFromInjector(
    (_zero_hw_sentinel): StatelessBuffer()
  )
  (weight_quant): WeightQuantProxyFromInjector(
    (_zero_hw_sentinel): StatelessBuffer()
    (tensor_quant): RescalingIntQuant(
      (int_quant): IntQuant(
        (float_to_int_impl): RoundSte()
        (tensor_clamp_impl): TensorClampSte()
        (delay_wrapper): DelayWrapper(
          (delay_impl): _NoDelay()
        )
      )
      (scaling_impl): ParameterFromStatsFromParameterScaling(
        (parameter_list_stats): _ParameterListStats(
          (first_tracked_param): _ViewParameterWrapper(
            (view_shape_impl): OverTensorView()
          )
          (stats): _Stats(
            (stats_impl): AbsMax()
          )
        )
        (stats_scaling_impl): _StatsScaling(
          (affine_rescaling): Identity()
          (restrict_clamp_scaling): _RestrictClampValue(
            (clamp_min_ste): ScalarClampMinSte()
            (restrict_value_impl): FloatRestrictValue()
          )
          (restrict_scaling_pre): Identity()
        )
        (restrict_inplace_preprocess): Identity()
      )
      (int_scaling_impl): IntScaling()
      (zero_point_impl): ZeroZeroPoint(
        (zero_point): StatelessBuffer()
      )
      (msb_clamp_bit_width_impl): BitWidthConst(
        (bit_width): StatelessBuffer()
      )
    )
  )
  (bias_quant): BiasQuantProxyFromInjector(
    (_zero_hw_sentinel): StatelessBuffer()
    (tensor_quant): PrescaledRestrictIntQuant(
      (int_quant): IntQuant(
        (float_to_int_impl): RoundSte()
        (tensor_clamp_impl): TensorClamp()
        (delay_wrapper): DelayWrapper(
          (delay_impl): _NoDelay()
        )
      )
      (msb_clamp_bit_width_impl): BitWidthConst(
        (bit_width): StatelessBuffer()
      )
      (zero_point): StatelessBuffer()
    )
  )
)

I'm AFK for the next week, but will pick this up when I get back! I'll look into why the input scale is missing, and into solutions for it.

@OscarSavolainenDR (Contributor Author) commented May 16, 2024

Ok, I've narrowed in somewhat on what the issue is with the Python/Torch versioning. Below I refer to the old version that isn't working as 1.9.1, after Torch 1.9.1 (it could also be the Python versioning, but I assume not).

In 1.9.1, we get an error when we try to feed data through the quantized model:

RuntimeError: Input scale required

where the quantized model is given by e.g.:

    quant_model = quantize_model(
        model=fx_model,
        backend="fx",
        weight_bit_width=weight_bit_width,
        act_bit_width=act_bit_width,
        bias_bit_width=bias_bit_width if bias_bit_width > 0 else None,
        weight_quant_granularity="per_tensor",
        act_quant_percentile=99.9,
        act_quant_type="sym",
        scale_factor_type="float_scale",
        quant_format="int",
        layerwise_first_last_bit_width=5,
    )

The issue is in compute_bias_scale for one of the first layers of the graph: in 1.9.1 it returns None for the scale, because the input tensor is not of type QuantTensor.
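For reference, the relevant behavior is roughly this shape (a simplified sketch, not the actual Brevitas code): the bias scale is derived from the input scale combined with the weight scale, which is only possible when the layer input arrives as a QuantTensor carrying a scale.

    from brevitas.quant_tensor import QuantTensor


    def compute_bias_scale_sketch(inp, weight_scale):
        # If the layer input is a plain torch.Tensor, as in the broken graph
        # below, there is no input scale to combine with the weight scale, so no
        # bias scale can be computed and the forward pass later fails with
        # "Input scale required".
        if not isinstance(inp, QuantTensor) or inp.scale is None:
            return None
        return inp.scale * weight_scale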

Ultimately, this is because the graph of the model is not quantizing correctly.

ipdb> quant_model.graph.print_tabular()
opcode       name             target           args        kwargs
-----------  ---------------  ---------------  ----------  --------
placeholder  input_1          input            ()          {}
call_module  input_1_quant    input_1_quant    (input_1,)  {}
call_module  _0               0                (input_1,)  {}
call_module  _1               1                (_0,)       {}
call_module  _2               2                (_1,)       {}
call_module  _3               3                (_2,)       {}
call_module  _4               4                (_3,)       {}
call_module  _5               5                (_4,)       {}
call_module  _6_input_quant   _6_input_quant   (_5,)       {}
call_module  _6               6                (_5,)       {}
call_module  _6_output_quant  _6_output_quant  (_6,)       {}
output       output           output           (_6,)       {}

_0 is not taking in the quantized tensor.

Whereas if I print out the graph in a later version of PyTorch (not 1.9.1), it uses the quantized tensor correctly:

ipdb> quant_model.graph.print_tabular()
opcode       name             target           args                kwargs
-----------  ---------------  ---------------  ------------------  --------
placeholder  input_1          input            ()                  {}
call_module  input_1_quant    input_1_quant    (input_1,)          {}
call_module  _0               0                (input_1_quant,)    {}
call_module  _1               1                (_0,)               {}
call_module  _2               2                (_1,)               {}
call_module  _3               3                (_2,)               {}
call_module  _4               4                (_3,)               {}
call_module  _5               5                (_4,)               {}
call_module  _6_input_quant   _6_input_quant   (_5,)               {}
call_module  _6               6                (_6_input_quant,)   {}
call_module  _6_output_quant  _6_output_quant  (_6,)               {}
output       output           output           (_6_output_quant,)  {}
ipdb>

The issue particularly happens in src/brevitas/graph/quantize.py, in quantize > inp_placeholder_handler, when we try to rewrite the model.

There is a red herring: unlike later versions, 1.9.1 throws this warning (from inside InsertModuleCallAfter):

 ✘ Brevitas-3.8  oscar   tests-quantize-model -  python temp.py
> /home/oscar/Coding/OpenSource/Brevitas/src/brevitas/graph/quantize_impl.py(74)inp_placeholder_handler()
     73         ipdb.set_trace()
---> 74         model = rewriter.apply(model)
     75     return model

ipdb>         model = rewriter.apply(model)

/home/oscar/miniconda3/envs/Brevitas-3.8/lib/python3.8/site-packages/torch/fx/graph.py:606: UserWarning: Attempted to insert a call_module Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule
  warnings.warn("Attempted to insert a call_module Node with "
ipdb>

However, that warning is because of a bug in PyTorch, and they've since fixed it. I.e. in 1.9.1 they had:

        if (self.owning_module and
                self.owning_module.get_submodule(module_name) is not None):
            warnings.warn("Attempted to insert a call_module Node with "
                          "no underlying reference in the owning "
                          "GraphModule! Call "
                          "GraphModule.add_submodule to add the "
                          "necessary submodule")

Instead of:

        if (self.owning_module and
                self.owning_module.get_submodule(module_name) is None):
            warnings.warn("Attempted to insert a call_module Node with "
                          "no underlying reference in the owning "
                          "GraphModule! Call "
                          "GraphModule.add_submodule to add the "
                          "necessary submodule")

I.e. is not None vs is None. However, this doesn't seem relevant for the issue at hand.

The difference in the graph actually manifests here (in src/brevitas/graph/quantize.py, quantize > inp_placeholder_handler > InsertModuleCallAfter > replace_all_uses_except):

        replace_all_uses_except(
            self.node,
            quant_identity_node,
            [quant_identity_node] + list(self.node_to_exclude),
        )

For some reason, in 1.9.1 the graph doesn't change, but it does in late Torch versions. The issue may be in replace_all_uses_except, or upstream in the inputs to the function, I haven't figured it out yet.
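For reference, the effect that call is expected to have is roughly the following (a sketch in terms of current torch.fx primitives, not Brevitas's actual implementation): in the graphs above, it should reroute _0 from input_1 to input_1_quant.

    import torch.fx as fx


    def replace_all_uses_except_sketch(node: fx.Node, new_node: fx.Node, exclude):
        # Reroute every user of `node` to consume `new_node` instead, skipping
        # the users listed in `exclude` (e.g. the freshly inserted quant
        # identity node itself).
        for user in list(node.users):
            if user in exclude:
                continue
            user.replace_input_with(node, new_node)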

Sorry for the piecemeal update; I'm working on this in the background, but I'm leaving an update so that, in case I end up busy with something else, there'll be some record. Hopefully I can figure this out tomorrow!

@Giuseppe5 (Collaborator)

Thanks for the update.
I tried running this branch locally but I didn't manage to do that.

I will try again using these insights. Let me know if you manage to find out more, and I'll do the same!

@OscarSavolainenDR (Contributor Author)

I haven't really been able to spend a lot of time on this. The issue seems to be rooted in a specific Torch version, and I'm wondering whether it isn't simply easier for Brevitas to state that that specific Torch version isn't supported: it might legitimately be a bug in an old version of PyTorch, from when FX Graph mode was new.

If that's not an option I can try and debug this again.

@Giuseppe5 Giuseppe5 requested review from Giuseppe5 and removed request for Giuseppe5 July 22, 2024 09:10
@Giuseppe5 (Collaborator)

I think I pushed something to dev that broke some of your tests.
I will use this PR to push some commits to your branch to fix what I broke and then merge it. I hope that's ok with you.

Thanks for implementing all these tests, it's amazing work!

@OscarSavolainenDR (Contributor Author)

> I think I pushed something to dev that broke some of your tests.
>
> I will use this PR to push some commits to your branch to fix what I broke and then merge it. I hope that's ok with you.
>
> Thanks for implementing all these tests, it's amazing work!

Sure, no problem!

Thanks, I appreciate it! 🙂

@Giuseppe5 Giuseppe5 requested review from Giuseppe5 and removed request for Giuseppe5 July 23, 2024 13:02
@Giuseppe5 Giuseppe5 merged commit 55fd0ea into Xilinx:dev Jul 29, 2024
337 checks passed