
RuntimeError: Module <class 'brevitas.proxy.float_runtime_quant.ActFloatQuantProxyFromInjector'> not supported for export #1091

Open
1 of 3 tasks
jcollyer-turing opened this issue Nov 13, 2024 · 12 comments

@jcollyer-turing

Describe the bug

I am attempting to save PTQ TorchVision models produced by the ptq_benchmark_torchvision.py script, after amending the script to export the quantized model with export_torch_qcdq as a final step.

The traceback is below:

Traceback (most recent call last):
  File "/Users/jcollyer/Documents/git/brevitas-datagen/ptq_benchmark_torchvision.py", line 343, in ptq_torchvision_models
    export_torch_qcdq(quant_model.to('cpu'), torch.randn(1, 3, 244, 244).to('cpu'), export_path = f"{folder}/{uuid}")
  File "/Users/jcollyer/Documents/git/brevitas-datagen/.venv/lib/python3.10/site-packages/brevitas/export/__init__.py", line 29, in export_torch_qcdq
    return TorchQCDQManager.export(*args, **kwargs)
  File "/Users/jcollyer/Documents/git/brevitas-datagen/.venv/lib/python3.10/site-packages/brevitas/export/torch/qcdq/manager.py", line 56, in export
    traced_module = cls.jit_inference_trace(module, args, export_path)
  File "/Users/jcollyer/Documents/git/brevitas-datagen/.venv/lib/python3.10/site-packages/brevitas/export/manager.py", line 209, in jit_inference_trace
    module.apply(cls.set_export_handler)
  File "/Users/jcollyer/Documents/git/brevitas-datagen/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 895, in apply
    module.apply(fn)
  File "/Users/jcollyer/Documents/git/brevitas-datagen/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 895, in apply
    module.apply(fn)
  File "/Users/jcollyer/Documents/git/brevitas-datagen/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 896, in apply
    fn(self)
  File "/Users/jcollyer/Documents/git/brevitas-datagen/.venv/lib/python3.10/site-packages/brevitas/export/torch/qcdq/manager.py", line 39, in set_export_handler
    _set_proxy_export_handler(cls, module)
  File "/Users/jcollyer/Documents/git/brevitas-datagen/.venv/lib/python3.10/site-packages/brevitas/export/manager.py", line 127, in _set_proxy_export_handler
    _set_export_handler(manager_cls, module, QuantProxyProtocol, no_inheritance=True)
  File "/Users/jcollyer/Documents/git/brevitas-datagen/.venv/lib/python3.10/site-packages/brevitas/export/manager.py", line 115, in _set_export_handler
    raise RuntimeError(f"Module {module.__class__} not supported for export.")
RuntimeError: Module <class 'brevitas.proxy.float_runtime_quant.ActFloatQuantProxyFromInjector'> not supported for export.

Reproducibility

  • [x] Can be reproduced consistently.
  • [ ] Difficult to reproduce.
  • [ ] Unable to reproduce.

To Reproduce

Steps to reproduce the behavior. For example:

  1. Add export_torch_qcdq() and its associated arguments to the end of ptq_benchmark_torchvision.py (a sketch of the amendment is shown after the command below).
  2. Execute the following command (with calibration and validation imagenet sets):
python ptq_benchmark_torchvision.py 0 --calibration-dir <path-to-calib> --validation-dir <path-to-val> \
--quant_format float \
--scale_factor_type float_scale \
--weight_bit_width 2 3 4 5 6 7 8 \
--act_bit_width 2 3 4 5 6 7 8 \
--weight_mantissa_bit_width 1 2 3 4 5 6 \
--weight_exponent_bit_width 1 2 3 4 5 6 \
--act_mantissa_bit_width 1 2 3 4 5 6 \
--act_exponent_bit_width 1 2 3 4 5 6 \
--bias_bit_width None \
--weight_quant_granularity per_channel per_tensor \
--act_quant_type sym \
--weight_param_method stats \
--act_param_method mse \
--bias_corr True \
--graph_eq_iterations 20 \
--graph_eq_merge_bias True \
--act_equalization layerwise \
--learned_round False \
--gptq False \
--gpxq_act_order False \
--gpfq False \
--gpfq_p None \
--gpfa2q False \
--accumulator_bit_width None \
--uint_sym_act_for_unsigned_values False \
--act_quant_percentile None
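
For reference, the amendment from step 1 boils down to the call below, reconstructed from the traceback (folder and uuid are placeholders from my script, and quant_model is the quantized model produced by the benchmark):

import torch
from brevitas.export import export_torch_qcdq

# Added as a final step inside ptq_torchvision_models() in ptq_benchmark_torchvision.py;
# `folder` and `uuid` are placeholders from my script.
export_torch_qcdq(
    quant_model.to('cpu'),
    torch.randn(1, 3, 244, 244).to('cpu'),
    export_path=f"{folder}/{uuid}")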

Expected behavior
The model should be saved.

Please complete the following information:

If known:

  • Brevitas version: 0.11.0
  • PyTorch version: 2.4.1
  • Operating System / platform: macOS, Apple M2 (running on CPU, not MPS)

Additional context
I have tried torch.save() natively and it doesn't work either.

jcollyer-turing added the bug label on Nov 13, 2024
@Giuseppe5 (Collaborator)

Hello,
Thanks for pointing this out.
At the moment we only support ONNX for FP8 export.

Unfortunately, the torch quant/dequant ops that we normally use to map quantization (see https://pytorch.org/docs/stable/generated/torch.quantize_per_tensor.html) don't support minifloat/fp8 quantization.
We could work around this with custom ops, but doing that properly requires torch 2.0+, so we will probably move in that direction in future versions of Brevitas once we deprecate older PyTorch versions.
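
As a minimal sketch of that limitation: the builtin op only accepts integer quantized dtypes, so there is nothing for a minifloat/fp8 quantizer to map onto.

import torch

x = torch.randn(4)
# Integer quantization works through the builtin quantize/dequantize ops...
q = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.qint8)
print(q.dequantize())
# ...but there is no minifloat/fp8 dtype that can be passed here.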

With respect to ONNX, we can only export FP8 and not lower bit widths, for similar reasons: ONNX supports a few FP8 types but doesn't allow defining a custom minifloat data type with arbitrary mantissa and exponent bit widths (or other configurations).
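
For FP8 specifically, the currently supported path would look roughly like the sketch below (assuming quant_model was quantized with Brevitas' FP8 e4m3/e5m2 quantizers; treat the exact call details as an assumption to verify against the docs):

import torch
from brevitas.export import export_onnx_qcdq

# Hedged sketch: ONNX QCDQ export is the supported route for FP8 today.
export_onnx_qcdq(
    quant_model.to('cpu'),
    torch.randn(1, 3, 224, 224),
    export_path='quant_model_fp8.onnx')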

If you share what your goal is with this export flow, maybe we can guide you towards a custom solution while we build a generic export flow.

@jcollyer-turing (Author)

Hey @Giuseppe5 - Thank you for the speedy reply.

That makes a lot of sense, and it sounds like something that is tricky to do generically given the complexity of supporting the different ops correctly.

I am running a series of experiments to look at the effect quantisation has on adversarial robustness. I want to quantise the models on an HPC-style environment, save them somewhere sensible and then evaluate their robustness downstream. It doesn't necessarily have to be a valid torch/ONNX model that loads into another library. I tried to use dill, which I have had success with in the past, but it seems to have its own issues (it struggles with the multi-inheritance/mixins).

I'd appreciate any ideas you may have to save the models to be loaded again later.

@Giuseppe5 (Collaborator)

We recently merged into dev the ability to export minifloat to QONNX (QONNX ref and QONNX Minifloat ref).

This representation lets us explicitly describe all the minifloat configurations that Brevitas can simulate.
QONNX generally provides reference implementations of its kernels, similar to ONNX, which allows you to execute the graph and check numerical correctness. However, I think that for the moment the QONNX FloatQuant op is only an interface and its reference implementation is still being worked on, which means you can't consider this a "valid" ONNX graph yet. I will let @maltanar comment on this in case I'm saying something that is not correct or precise.
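
On dev, the export call should look something like the sketch below (export_qonnx is the entry point I'd expect; double-check the name and arguments against the dev branch):

import torch
from brevitas.export import export_qonnx

# Hedged sketch: exports the Brevitas minifloat model to QONNX, whose custom
# FloatQuant op can represent arbitrary mantissa/exponent bit widths.
export_qonnx(
    quant_model.to('cpu'),
    torch.randn(1, 3, 224, 224),
    export_path='quant_model_minifloat.onnx')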

Would this work for you? Do you need a torch-based export for this task or do you just need to represent (even with custom ops) the computational graph?

@jcollyer-turing (Author)

This sounds like something that could be very useful for my use case, but it does not cover everything. Ideally, the export would be torch-based for easy integration into torch-based adversarial attack libraries (especially the attacks that require gradient information).

That being said, the features of QONNX and the ability to calculate inference statistics are actually another interesting set of data to collect. If this is the easiest way for me to save the model and conduct downstream inference, I can definitely work with this!

@Giuseppe5 (Collaborator)

Another option is to just save the state dict with model.state_dict() (all quantization parameters are saved as well), re-generate and quantize the model downstream, and then re-load the checkpoint. This would require you to carry Brevitas as a dependency downstream as well.
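
In code, that flow is roughly the sketch below (build_and_quantize_model is a placeholder for whatever pipeline produced the quantized model in the first place):

import torch

# On the HPC side: save only the state dict (quantization parameters included).
torch.save(quant_model.state_dict(), 'quant_model_state.pt')

# Downstream: rebuild the same architecture, re-apply the same Brevitas
# quantization config, then load the saved parameters back in.
fresh_model = build_and_quantize_model()  # placeholder for your own pipeline
fresh_model.load_state_dict(torch.load('quant_model_state.pt', map_location='cpu'))
fresh_model.eval()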

Talking with @nickfraser, we might try a few ideas on how to make Brevitas compatible with serialization. It should be a quick experiment, and I'll keep this issue updated if it works out so that you can replicate it while we go through the full PR process.

@jcollyer-turing (Author)

That sounds wonderful. Thank you!

Re: carrying Brevitas as a dependency, that is definitely an option and something I could do, but it introduces a different challenge (sorry!): Brevitas pins numpy<=1.26.4, which clashes with other requirements elsewhere in my project.

@Giuseppe5 (Collaborator) commented Nov 15, 2024

To be fair, that pin might be outdated.

When numpy 2.0 was released, a lot of things broke, and our decision was to wait until it stabilized a bit before we tried to unpin it. We don't use any specific numpy functions that are no longer available in 2.0, or anything like that. I will open a PR with the unpinned version to see how many things break.

Fingers crossed: #1093

@jcollyer-turing (Author)

Sounds great! Thank you 👍

@Giuseppe5 (Collaborator)

I think there is reasonable optimism that Brevitas no longer has any hard requirement on numpy.

Having said that, I notice that torch keeps installing numpy 1.26 even if I don't specify any particular version.

I tried manually upgrading it (with torch 2.4) and everything seems to work fine. With an older version of torch there seem to be conflicts.

Let me know what you see from your side. I will still look into serialization but this could be the fastest way to get there IMHO.

@jcollyer-turing (Author)

What is the best branch to install/set up to test this? And any luck with the serialisation experimentation?

@Giuseppe5 (Collaborator) commented Nov 18, 2024

What is the best branch to install/setup to test this?

After installing Brevitas, just update numpy to whatever version you need and everything should work. Be sure to update onnxruntime to the latest version as well.

and any luck with the serialisation experimentation?

I opened a draft PR with an example: #1096

It seems to work locally but there could be side effects when using the pickled model downstream.
We always assume that the quant_injector is present, so we use it even though we could store the value at init time and then discard the quant_injector completely.

I believe all the issues can be fixed relatively easily.
Basically, what could happen downstream is that you try to access an attribute in the quant_injector and it's not there. If that's the case, the solution is to modify the proxy to store that attribute at init time before you generate the pickle, and then you're good to go (see the illustrative sketch below). If you face any such issues, feel free to report them here and I can add fixes to the draft PR.
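
Purely as an illustration of that failure mode and the fix (not Brevitas code; the class names are made up):

from types import SimpleNamespace

class LazyProxy:
    # Resolves the attribute through the injector on every access;
    # breaks downstream if the injector is unavailable after unpickling.
    def __init__(self, quant_injector):
        self.quant_injector = quant_injector

    @property
    def bit_width(self):
        return self.quant_injector.bit_width


class EagerProxy:
    # Stores the value at init time, so the pickled object no longer
    # depends on the injector being present.
    def __init__(self, quant_injector):
        self.bit_width = quant_injector.bit_width


inj = SimpleNamespace(bit_width=8)
print(EagerProxy(inj).bit_width)  # 8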

Wouldn't you also need to carry Brevitas downstream with pickle?

@jcollyer-turing (Author)

I would need to carry Brevitas as a dependency if it were pickle, but I would be able to save the model in a torch-compatible way, giving me access to gradient information, which I don't think would be possible if I went the QONNX route.

I will pull the draft PR version and have an experiment this week and report back! (Probably close of play Thursday!)
