diff --git a/docs/tutorials/onnx_export.ipynb b/docs/tutorials/onnx_export.ipynb index 5d9a34228..e0debd571 100644 --- a/docs/tutorials/onnx_export.ipynb +++ b/docs/tutorials/onnx_export.ipynb @@ -424,318 +424,6 @@ "Similarly, the output of the QuantReLU is clipped between 0 and 15, since in this case we are doing an unsigned 4 bit quantization." ] }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "## QOps Export\n", - "Another supported style for exporting quantized operation in ONNX is represented by QOps.\n", - "\n", - "Compared to QCDQ, where it is possible to re-use standard floating point nodes (e.g., GEMM or Conv2d) preceeded by QCDQ nodes, with QOps the entire layer is replaced by its quantized counterpart.\n", - "\n", - "Opposite to what happens with QCDQ, all elements of the computation in this case have to be quantized: Input, Weight, Bias (if present), and Output tensors.\n", - "\n", - "This introduces some contraints on how we define our quantized layers through Brevitas.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\users\\alessand\\documents\\brevitas\\src\\brevitas\\export\\onnx\\standard\\manager.py:23: UserWarning: ONNX opset version set to 13, override with opset_version=\n", - " warnings.warn(f\"ONNX opset version set to {DEFAULT_OPSET}, override with {ka}=\")\n" - ] - }, - { - "data": { - "text/plain": [ - "ir_version: 7\n", - "producer_name: \"pytorch\"\n", - "producer_version: \"1.13.1\"\n", - "graph {\n", - " node {\n", - " output: \"/input_quant/export_handler/Constant_output_0\"\n", - " name: \"/input_quant/export_handler/Constant\"\n", - " op_type: \"Constant\"\n", - " attribute {\n", - " name: \"value\"\n", - " t {\n", - " data_type: 1\n", - " raw_data: \"\\000\\000\\000<\"\n", - " }\n", - " type: TENSOR\n", - " }\n", - " }\n", - " node {\n", - " output: \"/input_quant/export_handler/Constant_1_output_0\"\n", - " name: \"/input_quant/export_handler/Constant_1\"\n", - " op_type: \"Constant\"\n", - " attribute {\n", - " name: \"value\"\n", - " t {\n", - " data_type: 3\n", - " raw_data: \"\\000\"\n", - " }\n", - " type: TENSOR\n", - " }\n", - " }\n", - " node {\n", - " input: \"inp.1\"\n", - " input: \"/input_quant/export_handler/Constant_output_0\"\n", - " input: \"/input_quant/export_handler/Constant_1_output_0\"\n", - " output: \"/input_quant/export_handler/QuantizeLinear_output_0\"\n", - " name: \"/input_quant/export_handler/QuantizeLinear\"\n", - " op_type: \"QuantizeLinear\"\n", - " }\n", - " node {\n", - " output: \"/linear/export_handler/Constant_output_0\"\n", - " name: \"/linear/export_handler/Constant\"\n", - " op_type: \"Constant\"\n", - " attribute {\n", - " name: \"value\"\n", - " t {\n", - " dims: 128\n", - " dims: 3\n", - " dims: 3\n", - " dims: 3\n", - " data_type: 3\n", - " raw_data: 
\"\\374\\372\\376\\374\\005\\000\\375\\374\\004\\375\\373\\373\\375\\007\\376\\374\\377\\000\\000\\373\\373\\004\\005\\371\\003\\375\\004\\373\\004\\374\\000\\006\\002\\003\\003\\005\\004\\377\\005\\000\\373\\376\\375\\376\\002\\376\\004\\377\\003\\005\\375\\371\\006\\373\\003\\007\\377\\374\\005\\375\\375\\006\\375\\377\\374\\001\\005\\371\\006\\005\\007\\376\\376\\372\\376\\004\\001\\374\\002\\373\\373\\376\\002\\376\\375\\377\\001\\376\\006\\371\\002\\000\\004\\005\\005\\000\\004\\373\\004\\002\\003\\000\\374\\376\\005\\000\\004\\372\\004\\000\\373\\000\\006\\377\\002\\005\\004\\005\\374\\000\\007\\377\\374\\371\\373\\007\\004\\376\\372\\001\\005\\001\\372\\377\\003\\001\\375\\006\\372\\377\\006\\003\\006\\004\\001\\004\\372\\005\\006\\003\\376\\373\\374\\375\\376\\005\\000\\004\\377\\372\\373\\000\\007\\377\\373\\003\\373\\376\\374\\374\\377\\375\\377\\003\\372\\005\\004\\007\\003\\375\\377\\001\\007\\377\\373\\374\\000\\377\\376\\374\\373\\377\\373\\375\\003\\004\\004\\376\\004\\377\\375\\003\\003\\377\\004\\000\\005\\004\\000\\372\\005\\007\\003\\004\\377\\373\\003\\371\\373\\002\\377\\006\\006\\007\\377\\376\\375\\002\\006\\005\\004\\374\\002\\000\\373\\004\\002\\002\\374\\371\\372\\371\\375\\001\\004\\000\\006\\376\\377\\002\\000\\372\\001\\001\\375\\007\\376\\005\\001\\373\\003\\374\\005\\003\\007\\005\\372\\004\\006\\375\\005\\003\\001\\373\\376\\374\\002\\376\\377\\376\\000\\006\\001\\375\\376\\377\\374\\000\\005\\002\\005\\006\\371\\375\\005\\375\\376\\374\\004\\001\\003\\001\\372\\005\\007\\371\\005\\000\\372\\001\\001\\371\\007\\374\\372\\373\\373\\372\\376\\004\\000\\002\\375\\376\\000\\004\\003\\003\\375\\003\\001\\376\\006\\001\\000\\372\\374\\376\\373\\002\\002\\004\\372\\377\\374\\005\\000\\001\\005\\005\\374\\007\\003\\377\\377\\000\\007\\002\\377\\377\\377\\374\\001\\001\\376\\000\\377\\373\\001\\004\\376\\003\\000\\007\\005\\000\\374\\372\\376\\005\\003\\003\\004\\372\\375\\372\\377\\006\\376\\374\\007\\373\\002\\374\\003\\377\\374\\002\\007\\373\\004\\376\\004\\004\\003\\005\\373\\003\\005\\376\\001\\000\\002\\371\\376\\000\\374\\377\\372\\375\\005\\373\\002\\373\\373\\377\\004\\375\\006\\377\\005\\005\\002\\375\\375\\003\\376\\376\\006\\002\\371\\000\\002\\373\\000\\006\\002\\372\\372\\006\\374\\372\\004\\006\\004\\000\\003\\001\\377\\371\\376\\003\\003\\373\\005\\000\\001\\003\\004\\001\\005\\001\\004\\373\\373\\372\\002\\371\\375\\372\\004\\377\\005\\375\\376\\374\\375\\003\\372\\001\\373\\372\\376\\005\\003\\372\\004\\373\\004\\374\\374\\376\\376\\377\\371\\375\\004\\375\\377\\376\\007\\004\\372\\000\\007\\372\\006\\002\\006\\001\\006\\372\\004\\004\\003\\002\\375\\006\\374\\002\\001\\001\\000\\376\\376\\006\\373\\374\\002\\372\\005\\374\\004\\004\\001\\374\\004\\377\\373\\002\\376\\001\\377\\003\\377\\007\\004\\372\\371\\002\\375\\377\\373\\002\\376\\375\\377\\006\\001\\001\\000\\374\\001\\006\\004\\371\\377\\375\\374\\377\\376\\003\\372\\373\\002\\005\\374\\000\\002\\004\\372\\004\\372\\003\\006\\375\\003\\377\\376\\000\\377\\374\\006\\377\\374\\375\\377\\373\\376\\372\\375\\006\\004\\371\\372\\374\\375\\004\\002\\372\\376\\001\\001\\002\\373\\000\\003\\000\\371\\001\\003\\377\\376\\371\\376\\004\\000\\003\\376\\002\\006\\004\\372\\007\\005\\004\\376\\000\\007\\372\\003\\002\\005\\005\\004\\372\\002\\377\\006\\002\\371\\375\\375\\372\\376\\005\\003\\000\\002\\371\\005\\372\\373\\377\\371\\376\\005\\374\\377\\007\\003\\001\\376\\006\\376\\001\\374\\374\\001\\373\\006\\376\\376\\001\\372\\377\\003\\006\\372\\373\\003\\377\\376\\000\\377\\373\\004\\372\\371\\376\\0
02\\004\\004\\006\\001\\372\\001\\376\\005\\001\\000\\000\\007\\002\\375\\002\\375\\375\\006\\007\\375\\375\\002\\006\\371\\375\\002\\377\\002\\377\\000\\373\\001\\372\\372\\001\\377\\372\\001\\002\\000\\375\\373\\377\\372\\001\\371\\372\\007\\372\\001\\377\\372\\004\\376\\376\\374\\375\\373\\373\\005\\371\\375\\006\\005\\007\\374\\373\\005\\372\\000\\001\\374\\005\\000\\002\\373\\004\\001\\004\\006\\002\\003\\373\\376\\372\\374\\003\\375\\005\\000\\005\\373\\001\\375\\374\\002\\002\\000\\373\\374\\003\\005\\376\\003\\374\\374\\373\\374\\000\\004\\371\\375\\372\\003\\375\\005\\005\\006\\007\\371\\003\\372\\003\\375\\004\\374\\001\\376\\373\\000\\004\\003\\001\\003\\372\\377\\003\\004\\374\\000\\376\\002\\377\\001\\374\\376\\002\\002\\001\\005\\375\\373\\001\\372\\000\\007\\004\\007\\006\\006\\000\\004\\004\\006\\000\\377\\375\\000\\002\\374\\376\\374\\006\\373\\377\\000\\374\\006\\373\\005\\001\\001\\006\\005\\373\\373\\001\\003\\371\\006\\372\\003\\005\\372\\003\\005\\006\\005\\006\\001\\001\\377\\372\\001\\003\\372\\005\\002\\376\\377\\373\\005\\376\\375\\373\\005\\004\\007\\001\\000\\002\\001\\374\\004\\003\\377\\004\\372\\373\\373\\007\\375\\002\\002\\377\\373\\007\\001\\004\\374\\007\\376\\000\\003\\376\\006\\371\\377\\003\\376\\003\\004\\375\\006\\376\\371\\373\\373\\004\\000\\005\\377\\372\\372\\377\\004\\002\\001\\000\\005\\372\\004\\377\\376\\375\\001\\005\\375\\375\\000\\003\\006\\374\\004\\377\\004\\006\\000\\374\\003\\000\\005\\376\\372\\371\\371\\000\\374\\372\\006\\004\\006\\376\\377\\001\\377\\376\\373\\374\\000\\003\\004\\372\\375\\000\\006\\002\\374\\377\\004\\372\\371\\373\\001\\006\\377\\003\\007\\377\\373\\000\\371\\002\\376\\003\\002\\377\\006\\006\\006\\371\\006\\373\\377\\006\\000\\374\\375\\376\\001\\376\\007\\003\\007\\376\\004\\001\\005\\003\\375\\372\\003\\004\\376\\374\\005\\372\\372\\000\\006\\377\\003\\000\\002\\001\\003\\375\\000\\004\\375\\372\\000\\001\\000\\000\\002\\000\\004\\005\\377\\005\\007\\376\\372\\001\\374\\006\\002\\376\\002\\005\\374\\372\\000\\375\\372\\372\\000\\001\\000\\377\\007\\376\\000\\374\\375\\000\\373\\003\\001\\006\\003\\376\\007\\374\\376\\374\\005\\371\\372\\001\\374\\374\\002\\375\\004\\001\\002\\002\\376\\003\\373\\000\\375\\375\\005\\373\\002\\376\\371\\006\\004\\001\\001\\371\\376\\005\\377\\375\\005\\003\\374\\375\\002\\373\\376\\001\\002\\001\\007\\002\\004\\376\\375\\377\\376\\004\\373\\000\\001\\375\\377\\372\\376\\002\\001\\375\\006\\005\\006\\004\\376\\004\\004\\001\\001\\377\\004\\006\\003\\001\\005\\006\\001\\377\\000\\000\\000\\372\\004\\375\\004\\377\\377\\006\\377\\373\\003\\375\\373\\004\\005\\377\\006\\376\\374\\374\\371\\376\\003\\376\\374\\001\\373\\001\\375\\001\\376\\376\\000\\376\\371\\376\\377\\372\\373\\374\\374\\375\\376\\003\\376\\002\\372\\375\\375\\007\\377\\373\\377\\006\\376\\377\\373\\002\\001\\000\\005\\004\\006\\376\\001\\373\\372\\371\\001\\371\\001\\373\\374\\001\\375\\373\\003\\375\\373\\005\\373\\004\\377\\002\\000\\002\\006\\001\\373\\375\\005\\376\\004\\000\\376\\003\\007\\000\\377\\003\\004\\005\\376\\004\\003\\004\\006\\006\\006\\371\\002\\374\\375\\003\\375\\000\\375\\377\\004\\003\\374\\373\\004\\005\\375\\003\\376\\001\\001\\374\\003\\377\\004\\006\\003\\377\\001\\003\\377\\377\\371\\000\\374\\003\\373\\374\\006\\372\\372\\006\\004\\375\\375\\373\\004\\005\\001\\373\\371\\377\\376\\004\\005\\373\\374\\005\\000\\376\\001\\002\\003\\006\\006\\374\\375\\374\\377\\001\\373\\003\\004\\372\\004\\375\\001\\371\\004\\002\\001\\376\\377\\005\\000\\376\\376\\372\\005\\000\\376\\004\\371\\000\\377\\3
77\\377\\373\\377\\001\\004\\002\\374\\373\\000\\374\\377\\373\\373\\374\\005\\006\\374\\003\\373\\000\\006\\001\\003\\371\\373\\006\\374\\005\\005\\006\\371\\002\\005\\373\\000\\377\\377\\003\\005\\003\\004\\376\\372\\000\\005\\004\\371\\372\\376\\371\\005\\375\\000\\001\\001\\000\\006\\005\\006\\002\\002\\000\\003\\006\\374\\005\\000\\373\\372\\376\\002\\372\\006\\003\\375\\373\\373\\375\\002\\004\\001\\007\\373\\377\\004\\005\\004\\375\\005\\376\\376\\004\\003\\000\\004\\376\\006\\001\\376\\003\\376\\007\\006\\002\\376\\001\\376\\006\\371\\006\\375\\375\\004\\003\\006\\377\\374\\004\\003\\375\\372\\374\\375\\006\\377\\000\\004\\373\\002\\006\\373\\377\\374\\372\\000\\000\\376\\006\\373\\372\\004\\001\\006\\003\\377\\006\\371\\006\\006\\004\\004\\005\\371\\376\\001\\003\\372\\005\\001\\002\\373\\001\\372\\375\\004\\372\\006\\373\\375\\001\\003\\375\\377\\003\\372\\374\\374\\373\\006\\005\\373\\002\\000\\004\\376\\377\\004\\374\\006\\374\\006\\373\\004\\375\\373\\006\\376\\006\\002\\002\\377\\372\\372\\005\\004\\375\\000\\002\\374\\002\\376\\007\\373\\376\\371\\377\\005\\376\\006\\002\\006\\376\\004\\372\\000\\005\\002\\002\\003\\006\\004\\377\\007\\374\\372\\372\\002\\375\\377\\001\\375\\005\\374\\377\\003\\007\\002\\005\\006\\006\\000\\001\\004\\000\\376\\371\\001\\000\\005\\004\\375\\372\\375\\004\\007\\371\\374\\002\\005\\000\\002\\002\\004\\004\\007\\005\\006\\373\\006\\004\\002\\005\\004\\376\\375\\000\\372\\004\\377\\003\\374\\003\\376\\006\\376\\006\\006\\005\\002\\006\\007\\002\\372\\372\\377\\373\\004\\373\\375\\004\\004\\003\\006\\002\\000\\002\\376\\000\\000\\005\\006\\005\\372\\003\\372\\006\\001\\007\\372\\002\\372\\004\\001\\005\\002\\005\\374\\372\\372\\002\\372\\001\\377\\002\\006\\005\\000\\005\\372\\375\\007\\377\\375\\004\\005\\003\\372\\004\\005\\376\\373\\001\\372\\003\\371\\371\\374\\005\\002\\005\\374\\377\\004\\002\\376\\004\\373\\377\\377\\377\\001\\005\\372\\003\\373\\375\\006\\374\\007\\376\\372\\006\\005\\371\\377\\005\\001\\003\\005\\002\\006\\003\\001\\377\\374\\004\\376\\374\\375\\376\\001\\001\\001\\004\\007\\007\\000\\005\\001\\376\\003\\376\\000\\000\\001\\001\\375\\371\\006\\002\\001\\373\\000\\377\\007\\004\\002\\374\\000\\001\\377\\003\\374\\003\\007\\373\\371\\373\\001\\005\\373\\372\\005\\373\\375\\005\\006\\372\\000\\005\\007\\003\\003\\377\\005\\006\\004\\374\\372\\375\\003\\004\\000\\005\\376\\374\\374\\375\\375\\377\\372\\000\\004\\002\\005\\002\\000\\374\\376\\373\\373\\376\\002\\374\\000\\376\\373\\000\\371\\373\\372\\006\\000\\376\\002\\375\\376\\005\\372\\004\\376\\375\\005\\006\\006\\004\\003\\002\\002\\002\\002\\375\\006\\377\\000\\004\\375\\004\\007\\004\\005\\372\\374\\004\\377\\003\\377\\000\\375\\372\\374\\372\\000\\004\\007\\002\\007\\372\\376\\004\\371\\375\\001\\001\\007\\003\\000\\004\\373\\001\\001\\376\\002\\377\\377\\006\\002\\003\\373\\373\\004\\372\\372\\376\\372\\002\\002\\002\\373\\001\\375\\374\\000\\004\\003\\376\\003\\376\\002\\373\\374\\003\\372\\371\\001\\375\\004\\371\\374\\004\\005\\002\\374\\371\\001\\373\\377\\374\\006\\373\\006\\000\\005\\005\\006\\006\\002\\375\\002\\001\\001\\005\\375\\000\\372\\371\\003\\004\\375\\376\\003\\377\\374\\005\\007\\007\\377\\374\\375\\374\\376\\373\\003\\002\\002\\374\\377\\373\\004\\375\\372\\374\\003\\374\\005\\376\\002\\373\\376\\006\\005\\374\\002\\371\\005\\004\\001\\373\\000\\377\\374\\003\\000\\001\\001\\003\\372\\005\\001\\371\\371\\000\\375\\001\\375\\372\\374\\003\\373\\376\\001\\371\\006\\005\\004\\377\\004\\376\\377\\377\\003\\373\\001\\372\\376\\006\\372\\372\\005\\374\\0
01\\374\\004\\001\\004\\375\\002\\002\\373\\006\\000\\001\\002\\377\\371\\005\\005\\374\\374\\006\\003\\001\\002\\001\\374\\377\\372\\000\\377\\374\\373\\371\\007\\003\\375\\373\\374\\373\\374\\005\\004\\005\\006\\002\\374\\000\\372\\001\\376\\002\\373\\371\\372\\374\\374\\377\\005\\375\\371\\002\\374\\374\\005\\377\\007\\004\\376\\007\\373\\372\\007\\007\\377\\004\\002\\002\\007\\377\\375\\002\\005\\006\\003\\002\\006\\376\\003\\004\\003\\000\\371\\002\\002\\374\\006\\373\\005\\003\\003\\002\\003\\376\\002\\004\\377\\377\\371\\007\\001\\373\\376\\003\\002\\007\\376\\002\\005\\004\\374\\003\\377\\374\\003\\007\\004\\377\\002\\001\\003\\005\\373\\377\\374\\002\\377\\004\\000\\000\\005\\007\\002\\003\\376\\371\\377\\006\\372\\372\\002\\372\\371\\375\\000\\376\\005\\372\\000\\373\\372\\007\\002\\001\\372\\374\\375\\005\\005\\004\\001\\002\\002\\006\\372\\001\\007\\373\\375\\000\\372\\005\\003\\000\\375\\377\\001\\003\\006\\000\\376\\374\\002\\375\\375\\003\\001\\007\\376\\377\\003\\000\\005\\376\\374\\005\\373\\004\\377\\000\\375\\002\\005\\001\\001\\000\\001\\375\\374\\001\\006\\372\\375\\376\\372\\371\\001\\372\\005\\004\\376\\373\\006\\005\\375\\006\\377\\001\\001\\000\\006\\000\\006\\007\\003\\372\\004\\375\\373\\372\\372\\000\\374\\001\\006\\007\\376\\374\\371\\373\\372\\375\\003\\377\\372\\377\\005\\002\\006\\372\\006\\004\\005\\000\\376\\007\\003\\372\\004\\377\\006\\001\\373\\375\\374\\373\\373\\004\\004\\375\\373\\005\\376\\000\\001\\375\\371\\372\\005\\375\\000\\002\\372\\003\\004\\372\\003\\374\\005\\002\\374\\377\\001\\005\\376\\377\\374\\376\\005\\376\\372\\003\\373\\372\\006\\372\\377\\373\\006\\372\\004\\006\\373\\005\\375\\375\\007\\374\\005\\002\\374\\374\\002\\002\\377\\375\\376\\372\\005\\375\\371\\003\\005\\003\\372\\377\\375\\372\\002\\005\\000\\006\\372\\005\\371\\376\\000\\001\\377\\004\\004\\006\\000\\377\\007\\002\\006\\000\\371\\375\\374\\374\\001\\373\\371\\002\\376\\002\\000\\374\\006\\001\\374\\006\\005\\001\\003\\376\\003\\374\\003\\374\\002\\007\\373\\002\\004\\007\\005\\374\\376\\372\\372\\001\\371\\002\\005\\373\\376\\006\\375\\372\\376\\004\\003\\001\\004\\376\\002\\373\\006\\006\\371\\372\\003\\004\\006\\375\\004\\007\\371\\000\\000\\001\\000\\374\\001\\006\\002\\006\\002\\000\\002\\373\\372\\372\\000\\372\\005\\006\\004\\000\\376\\372\\373\\006\\007\\373\\006\\373\\377\\003\\375\\373\\001\\377\\001\\002\\376\\003\\373\\002\\376\\007\\371\\371\\374\\006\\377\\001\\002\\005\\001\\376\\375\\000\\377\\371\\005\\372\\002\\377\\375\\375\\002\\375\\376\\003\\003\\373\\373\\005\\004\\004\\373\\000\\000\\007\\003\\372\\375\\004\\003\\376\\377\\373\\376\\004\\372\\004\\377\\376\\007\\002\\005\\003\\001\\006\\006\\002\\005\\373\\000\\004\\000\\004\\374\\372\\376\\007\\002\\003\\006\\002\\000\\372\\001\\374\\005\\376\\006\\007\\373\\001\\375\\004\\377\\374\\375\\377\\001\\377\\003\\375\\005\\000\\003\\376\\375\\003\\377\\372\\002\\006\\003\\007\\005\\374\\003\\006\\003\\000\\375\\000\\001\\000\\001\\002\\374\\377\\372\\004\\372\\377\\377\\003\\377\\007\\006\\371\\003\\005\\004\\007\\006\\371\\006\\001\\375\\001\\001\\376\\002\\374\\006\\375\\375\\376\\377\\002\\002\\007\\373\\373\\374\\373\\377\\001\\006\\375\\375\\001\\375\\373\\375\\373\\372\\376\\003\\371\\006\\376\\376\\375\\007\\377\\374\\376\\377\\006\\377\\001\\371\\377\\007\\375\\371\\005\\002\\373\\003\\005\\002\\371\\375\\003\\003\\003\\374\\000\\377\\375\\003\\002\\006\\006\\375\\006\\002\\000\\374\\373\\374\\002\\003\\373\\002\\375\\377\\004\\006\\003\\006\\000\\377\\372\\375\\375\\002\\002\\003\\006\\003\\0
03\\377\\373\\003\\003\\003\\003\\377\\004\\004\\372\\377\\000\\374\\375\\005\\004\\005\\003\\002\\375\\376\\001\\376\\003\\374\\002\\007\\002\\376\\377\\007\\006\\376\\372\\374\\004\\371\\004\\006\\006\\374\\374\\377\\374\\003\\006\\371\\377\\007\\372\\375\\006\\374\\374\\005\\372\\006\\372\\371\\001\\000\\375\\372\\374\\373\\374\\374\\374\\005\\004\\002\\375\\004\\007\\004\\006\\002\\005\\005\\372\\375\\000\\004\\000\\377\\004\\004\\001\\374\\377\\006\\003\\377\\374\\000\\376\\372\\376\\373\\377\\006\\377\\376\\002\\005\\005\\372\\004\\000\\001\\004\\005\\373\\005\\003\\371\\374\\373\\000\\375\\002\\375\\006\\003\\001\\004\\377\\374\\372\\005\\006\\005\\005\\005\\005\\007\\372\\006\\004\\006\\372\\372\\002\\373\\371\\001\\004\\006\\374\\005\\373\\004\\006\\001\\005\\006\\377\\006\\373\\001\\373\\373\\376\\375\\007\\372\\374\\372\\377\\004\\006\\004\\375\\374\\000\\007\\005\\000\\002\\377\\002\\372\\002\\001\\377\\372\\006\\002\\001\\000\\376\\375\\374\\003\\376\\371\\005\\001\\000\\002\\372\\373\\375\\004\\376\\371\\374\\376\\000\\004\\004\\376\\375\\007\\374\\377\\375\\377\\001\\003\\005\\372\\002\\376\\003\\003\\375\\001\\004\\001\\001\\000\\002\\004\\375\\375\\372\\003\\003\\372\\002\\375\\372\\377\\373\\000\\002\\371\\005\\003\\001\\001\\376\\372\\374\\001\\001\\376\\000\\001\\376\\001\\376\\005\\002\\374\\002\\004\\004\\000\\374\\007\\000\\000\\006\\003\\371\\376\\371\\006\\005\\006\\007\\002\\371\\373\\005\\372\\375\\006\\003\\373\\005\\375\\375\\373\\002\\000\\375\\005\\001\\372\\377\\377\\373\\375\\375\\374\\000\\376\\372\\000\\374\\001\\001\\372\\375\\373\\004\\374\\000\\006\\375\\004\\001\\006\\000\\373\\001\\375\\003\\372\\000\\373\\376\\003\\374\\005\\007\\377\\373\\007\\006\\002\\371\\373\\377\\004\\373\\001\\374\\000\\001\\004\\001\\005\\375\\372\\002\\376\\377\\371\\374\\375\\371\\373\\005\\376\\374\\001\\377\\376\\371\\375\\371\\000\\375\\373\\377\\006\\002\\003\\005\\372\\003\\004\\005\\005\\004\\000\\376\\372\\371\\006\\000\\377\\373\\003\\376\\005\\007\\006\\372\\004\\007\\374\\375\\376\\374\\000\\001\\001\\375\\003\\371\\001\\006\\374\\376\\006\\377\\000\\001\\375\\006\\004\\372\\371\\001\\377\\377\\377\\376\\006\\375\\372\\000\\371\\376\\002\\374\\372\\006\\372\\002\\006\\005\\001\\376\\004\\374\\002\\376\\000\\004\\376\\375\\000\\376\\004\\000\\006\\372\\005\\007\\006\\002\\004\\373\\373\\006\\003\\007\\001\\375\\007\\007\\372\\004\\005\\376\\005\\376\\007\\002\\376\\004\\373\\373\\376\\004\\372\\375\\373\\374\\001\\000\\375\\004\\375\\375\\377\\004\\001\\377\\002\\376\\004\\377\\001\\001\\374\\376\\374\\377\\377\\001\\000\\000\\377\\373\\374\\002\\006\\001\\375\\376\\000\\000\\374\\006\\004\\004\\004\\375\\001\\376\\001\\002\\373\\006\\006\\376\\002\\005\\005\\374\\373\\377\\376\\004\\005\\374\\000\\376\\002\\375\\376\\004\\373\\001\\377\\377\\002\\377\\373\\372\\371\\003\\003\\372\\006\\000\\002\\003\\005\\375\\371\\375\\004\\376\\374\\007\\375\\371\\002\\374\\000\\375\\005\\006\\374\\373\\004\\371\\000\\007\\376\\001\\375\\377\\372\\372\\373\\005\\005\\001\\372\\377\\371\\377\\375\"\n", - " }\n", - " type: TENSOR\n", - " }\n", - " }\n", - " node {\n", - " output: \"/linear/export_handler/Constant_1_output_0\"\n", - " name: \"/linear/export_handler/Constant_1\"\n", - " op_type: \"Constant\"\n", - " attribute {\n", - " name: \"value\"\n", - " t {\n", - " data_type: 1\n", - " raw_data: \"\\263-\\341<\"\n", - " }\n", - " type: TENSOR\n", - " }\n", - " }\n", - " node {\n", - " output: \"/linear/export_handler/Constant_2_output_0\"\n", - " name: 
\"/linear/export_handler/Constant_2\"\n", - " op_type: \"Constant\"\n", - " attribute {\n", - " name: \"value\"\n", - " t {\n", - " dims: 128\n", - " data_type: 6\n", - " raw_data: \"\\271\\377\\377\\377\\032\\003\\000\\0009\\001\\000\\000\\302\\002\\000\\000;\\375\\377\\377\\031\\000\\000\\000\\024\\003\\000\\000d\\003\\000\\000\\327\\374\\377\\377\\363\\377\\377\\377u\\003\\000\\000\\374\\000\\000\\000t\\000\\000\\000\\321\\002\\000\\000\\236\\377\\377\\377\\241\\377\\377\\377\\237\\375\\377\\377\\010\\000\\000\\000\\350\\002\\000\\000}\\376\\377\\377\\267\\377\\377\\377\\374\\000\\000\\000\\355\\001\\000\\000N\\375\\377\\377\\\\\\002\\000\\000\\346\\002\\000\\000\\317\\000\\000\\000\\207\\001\\000\\000?\\000\\000\\000\\302\\002\\000\\000Y\\377\\377\\377\\326\\376\\377\\377\\\\\\003\\000\\000\\374\\376\\377\\377\\334\\000\\000\\000\\200\\001\\000\\000\\362\\377\\377\\377+\\000\\000\\000\\304\\375\\377\\377u\\000\\000\\000\\340\\000\\000\\000\\275\\001\\000\\000\\324\\377\\377\\377\\332\\000\\000\\000\\026\\001\\000\\000\\333\\001\\000\\000\\371\\375\\377\\377\\363\\000\\000\\000|\\002\\000\\000\\335\\376\\377\\377\\226\\375\\377\\377\\335\\002\\000\\0002\\001\\000\\000F\\377\\377\\377\\006\\003\\000\\000\\310\\375\\377\\377\\344\\377\\377\\377\\177\\376\\377\\377>\\001\\000\\000\\033\\002\\000\\000I\\003\\000\\000\\006\\376\\377\\377\\315\\375\\377\\377\\033\\003\\000\\000\\236\\000\\000\\000@\\376\\377\\377\\031\\002\\000\\000\\321\\002\\000\\000;\\000\\000\\000\\035\\377\\377\\377\\354\\377\\377\\377Z\\001\\000\\000N\\375\\377\\377I\\001\\000\\000\\030\\001\\000\\000w\\377\\377\\377\\303\\002\\000\\000\\022\\000\\000\\000\\377\\001\\000\\000!\\000\\000\\000\\035\\001\\000\\000\\003\\375\\377\\377^\\377\\377\\377\\336\\374\\377\\377p\\377\\377\\377\\351\\002\\000\\000X\\376\\377\\377\\247\\000\\000\\000H\\376\\377\\377}\\000\\000\\000\\225\\374\\377\\3776\\001\\000\\000\\301\\001\\000\\000\\210\\001\\000\\000\\374\\376\\377\\377\\307\\377\\377\\377\\320\\374\\377\\377\\267\\377\\377\\377F\\375\\377\\377\\352\\377\\377\\377=\\377\\377\\3770\\376\\377\\377#\\000\\000\\000\\313\\376\\377\\377\\334\\000\\000\\000\\261\\001\\000\\000\\363\\001\\000\\000\\037\\001\\000\\000\\220\\377\\377\\377\\202\\000\\000\\000d\\377\\377\\377\\013\\002\\000\\000\\266\\002\\000\\000\\347\\374\\377\\377+\\001\\000\\000\\301\\376\\377\\377\\341\\377\\377\\377O\\003\\000\\000\\037\\375\\377\\377\\244\\375\\377\\377\\352\\000\\000\\000\\302\\001\\000\\000I\\002\\000\\000~\\377\\377\\377*\\376\\377\\377\\333\\000\\000\\000\\214\\000\\000\\000\\014\\002\\000\\000\"\n", - " }\n", - " type: TENSOR\n", - " }\n", - " }\n", - " node {\n", - " input: \"/input_quant/export_handler/QuantizeLinear_output_0\"\n", - " input: \"/input_quant/export_handler/Constant_output_0\"\n", - " input: \"/input_quant/export_handler/Constant_1_output_0\"\n", - " input: \"/linear/export_handler/Constant_output_0\"\n", - " input: \"/linear/export_handler/Constant_1_output_0\"\n", - " input: \"/input_quant/export_handler/Constant_1_output_0\"\n", - " input: \"/input_quant/export_handler/Constant_output_0\"\n", - " input: \"/input_quant/export_handler/Constant_1_output_0\"\n", - " input: \"/linear/export_handler/Constant_2_output_0\"\n", - " output: \"/linear/export_handler/QLinearConv_output_0\"\n", - " name: \"/linear/export_handler/QLinearConv\"\n", - " op_type: \"QLinearConv\"\n", - " attribute {\n", - " name: \"dilations\"\n", - " ints: 1\n", - " ints: 1\n", - " type: INTS\n", - " }\n", - " attribute {\n", - " name: \"group\"\n", - " i: 
1\n", - " type: INT\n", - " }\n", - " attribute {\n", - " name: \"kernel_shape\"\n", - " ints: 3\n", - " ints: 3\n", - " type: INTS\n", - " }\n", - " attribute {\n", - " name: \"pads\"\n", - " ints: 0\n", - " ints: 0\n", - " ints: 0\n", - " ints: 0\n", - " type: INTS\n", - " }\n", - " attribute {\n", - " name: \"strides\"\n", - " ints: 1\n", - " ints: 1\n", - " type: INTS\n", - " }\n", - " }\n", - " node {\n", - " input: \"/linear/export_handler/QLinearConv_output_0\"\n", - " input: \"/input_quant/export_handler/Constant_output_0\"\n", - " input: \"/input_quant/export_handler/Constant_1_output_0\"\n", - " output: \"10\"\n", - " name: \"/linear/export_handler/DequantizeLinear\"\n", - " op_type: \"DequantizeLinear\"\n", - " }\n", - " name: \"torch_jit\"\n", - " input {\n", - " name: \"inp.1\"\n", - " type {\n", - " tensor_type {\n", - " elem_type: 1\n", - " shape {\n", - " dim {\n", - " dim_value: 1\n", - " }\n", - " dim {\n", - " dim_value: 3\n", - " }\n", - " dim {\n", - " dim_value: 128\n", - " }\n", - " dim {\n", - " dim_value: 128\n", - " }\n", - " }\n", - " }\n", - " }\n", - " }\n", - " output {\n", - " name: \"10\"\n", - " type {\n", - " tensor_type {\n", - " elem_type: 1\n", - " shape {\n", - " dim {\n", - " dim_value: 1\n", - " }\n", - " dim {\n", - " dim_value: 128\n", - " }\n", - " dim {\n", - " dim_value: 126\n", - " }\n", - " dim {\n", - " dim_value: 126\n", - " }\n", - " }\n", - " }\n", - " }\n", - " }\n", - "}\n", - "opset_import {\n", - " domain: \"\"\n", - " version: 13\n", - "}" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from brevitas.quant.scaled_int import Int8ActPerTensorFloat, Int32Bias\n", - "from brevitas.export import export_onnx_qop\n", - "\n", - "IN_CH = 3\n", - "IMG_SIZE = 128\n", - "OUT_CH = 128\n", - "BATCH_SIZE = 1\n", - "\n", - "class Model(torch.nn.Module):\n", - " def __init__(self) -> None:\n", - " super().__init__()\n", - " self.input_quant = qnn.QuantIdentity(return_quant_tensor=True)\n", - " self.linear = qnn.QuantConv2d(\n", - " IN_CH, OUT_CH, kernel_size=3, bias=True,\n", - " weight_bit_width=4, bias_quant=Int32Bias,\n", - " output_quant=Int8ActPerTensorFloat)\n", - " \n", - " def forward(self, inp):\n", - " inp = self.input_quant(inp)\n", - " inp = self.linear(inp)\n", - " return inp\n", - "\n", - "inp = torch.randn(BATCH_SIZE, IN_CH, IMG_SIZE, IMG_SIZE)\n", - "model = Model() \n", - "model.eval()\n", - "\n", - "\n", - "export_onnx_qop(\n", - " model, args=inp, export_path=\"quant_model_qop.onnx\")" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - }, - "tags": [ - "skip-execution" - ] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Serving 'quant_model_qop.onnx' at http://localhost:8085\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "show_netron(\"quant_model_qop.onnx\", 8085)" - ] - }, { "cell_type": "markdown", "metadata": { @@ -893,92 +581,6 @@ "source": [ "Unfortunately ONNX Runtime does not provide a built-in way to log whether execution goes through unoptimized floating-point GEMM, or int8 QGEMM." 
] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "### QOps\n", - "\n", - "As for the QCDQ case, also in this case we are using only standard ONNX operations, thus we can use ONNX Runtime for executing our exported models." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\users\\alessand\\documents\\brevitas\\src\\brevitas\\export\\onnx\\standard\\manager.py:23: UserWarning: ONNX opset version set to 13, override with opset_version=\n", - " warnings.warn(f\"ONNX opset version set to {DEFAULT_OPSET}, override with {ka}=\")\n" - ] - } - ], - "source": [ - "import onnxruntime as ort\n", - "\n", - "class Model(torch.nn.Module):\n", - " def __init__(self) -> None:\n", - " super().__init__()\n", - " self.input_quant = qnn.QuantIdentity(return_quant_tensor=True)\n", - " self.conv = qnn.QuantConv2d(\n", - " IN_CH, OUT_CH, kernel_size=3, bias=False,\n", - " weight_bit_width=4,\n", - " output_quant=Int8ActPerTensorFloat, \n", - " output_bit_width=4,\n", - " return_quant_tensor=True)\n", - " \n", - " def forward(self, inp):\n", - " inp = self.input_quant(inp)\n", - " inp = self.conv(inp)\n", - " return inp\n", - "\n", - "model = Model()\n", - "model.eval()\n", - "inp = torch.randn(BATCH_SIZE, IN_CH, IMG_SIZE, IMG_SIZE)\n", - "path = 'quant_model_qops_4b_4b.onnx'\n", - "\n", - "exported_model = export_onnx_qop(model, args=inp, export_path=path)\n", - "\n", - "sess_opt = ort.SessionOptions()\n", - "sess = ort.InferenceSession(path, sess_opt)\n", - "input_name = sess.get_inputs()[0].name\n", - "pred_onx = sess.run(None, {input_name: inp.numpy()})[0]\n", - "\n", - "\n", - "out_brevitas = model(inp).int()\n", - "out_ort = torch.tensor(pred_onx, dtype=torch.int8)\n", - "\n", - "assert torch.allclose(out_brevitas, out_ort, atol=1)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "Note a few things. `QuantConv2d` defines `return_quant_tensor=True` so that the exported ONNX model doesn't have a dequantize node at the end. Vecause we set `output_bit_width=4` (overriding the 8 bit bit-width in the `output_quant` quantizer), we have a `Clip` node at the end. At the same time, within Brevitas, `return_quant_tensor=True` means the PyTorch model returns a Brevitas `QuantTensor`, from which we are taking the `int` representation.\n", - "\n", - "Due to differences in how the computation is performed between Brevitas and ONNX Runtime, it might happen the two results are slightly different (since Brevitas uses a style closer to QCDQ, rather than operating between integers), thus we added a tolerance for off-by-1 errors." - ] } ], "metadata": { diff --git a/docs/tutorials/tvmcon2021.ipynb b/docs/tutorials/tvmcon2021.ipynb index efd9421f0..cdaaed038 100644 --- a/docs/tutorials/tvmcon2021.ipynb +++ b/docs/tutorials/tvmcon2021.ipynb @@ -1882,93 +1882,6 @@ " return IFrame(src=f\"http://localhost:{port}/\", width=\"100%\", height=400)" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Export to ONNX QOps\n", - "\n", - "Say we want to export a QuantConv1d with 4b symmetric weights, 8b symmetric inputs and outputs, and 16 biases. \n", - "We can export it to a ONNX's `QLinearConv`, but some information will be lost. 
In particular, weights will be represented as 8b and bias as 32b, even though they are respectively 4b and 16b. This is because ONNX does not provide a standardized way to represent them as such:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "torch.manual_seed(0)\n", - "\n", - "from brevitas.nn import QuantConv1d\n", - "from brevitas.quant import Int8WeightPerTensorFloat, Int8ActPerTensorFloat, Int16Bias\n", - "from brevitas.export import export_onnx_qop\n", - "\n", - "float_inp = torch.randn(1, 2, 5)\n", - "\n", - "quant_conv_4b8b = QuantConv1d(\n", - " 2, 4, 3, bias=True, weight_bit_width=4,\n", - " input_quant=Int8ActPerTensorFloat,\n", - " output_quant=Int8ActPerTensorFloat,\n", - " bias_quant=Int16Bias)\n", - "\n", - "output_path = 'qop_onnx_conv_4b8b.onnx'\n", - "export_onnx_qop(quant_conv_4b8b, input_t=float_inp, export_path=output_path)" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": { - "tags": [ - "skip-execution" - ] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Serving 'qop_onnx_conv_4b8b.onnx' at http://localhost:8082\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 39, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "show_netron(output_path, 8082)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In general the standard ONNX opset doesn't support representing quantization below 8b. Additionally, ONNX QOp representation requires an output quantizer to be set at part of of the layer. \n", - "\n", - "The constraint of always having an output quantizer is relaxed in the more recently introduced QDQ style of representation (for which there is support in Brevitas starting from version 0.8), which uses only `QuantizeLinear` and `DequantizeLinear` to represent quantization, but even with that support is still limited to 8b quantization." - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -2112,93 +2025,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The custom format shown above can integrated into ONNX-based toolchains, e.g. it's supported by our own FINN toolchain for low-precision dataflow style custom FPGAs implementations, and would be a starting point for direct integration with TVM.\n", - "\n", - "## Export to TorchScript quantization backend\n", - "\n", - "It's also possible to export to TorchScript own quantized functional operators, which come with their own set of restrictions. In particular, weights should be 7b and unsigned, which requires a zero-point. 
We can model that with appropriate quantizers:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from brevitas.quant import ShiftedUint8ActPerTensorFloat\n", - "from brevitas.export import export_torch_qop\n", - "\n", - "\n", - "quant_conv_8b7b = QuantConv1d(\n", - " 2, 4, 3, bias=True,\n", - " input_quant=ShiftedUint8ActPerTensorFloat,\n", - " output_quant=ShiftedUint8ActPerTensorFloat,\n", - " weight_bit_width=7,\n", - " bias_quant=Int16Bias)\n", - "\n", - "output_path = 'pytorch_qf_conv_8b7b.pt'\n", - "export_torch_qop(quant_conv_8b7b, input_t=float_inp, export_path=output_path)" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "metadata": { - "tags": [ - "skip-execution" - ] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\users\\alessandro\\documenti\\brevitas_tvmcon\\src\\brevitas\\quant_tensor\\__init__.py:74: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", - " training = torch.tensor(training, dtype=torch.bool)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Serving 'pytorch_qf_conv_8b7b.pt' at http://localhost:8085\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 42, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "show_netron(output_path, 8085)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As we can see though information on the fact that activations are 7b is lost, and they simply marked as 8b.\n", - "\n", - "Additionally, because bias quantization is not represented explicitly (although it is performed implicitly at 32b at runtime in the backend), any information around that is lost.\n", - "As with standard ONNX, representing precisions below 8b is not possible." + "The custom format shown above can integrated into ONNX-based toolchains, e.g. it's supported by our own FINN toolchain for low-precision dataflow style custom FPGAs implementations, and would be a starting point for direct integration with TVM." ] }, { diff --git a/notebooks/Brevitas_TVMCon2021.ipynb b/notebooks/Brevitas_TVMCon2021.ipynb index 425822b02..a6f24ea72 100644 --- a/notebooks/Brevitas_TVMCon2021.ipynb +++ b/notebooks/Brevitas_TVMCon2021.ipynb @@ -1903,102 +1903,6 @@ " return IFrame(src=f\"http://localhost:{port}/\", width=\"100%\", height=400)" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Export to ONNX QOps\n", - "\n", - "Say we want to export a QuantConv1d with 4b symmetric weights, 8b symmetric inputs and outputs, and 16 biases. \n", - "We can export it to a ONNX's `QLinearConv`, but some information will be lost. In particular, weights will be represented as 8b and bias as 32b, even though they are respectively 4b and 16b. 
This is because ONNX does not provide a standardized way to represent them as such:" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/scratch/fabian/brevitas/src/brevitas/export/onnx/standard/manager.py:26: UserWarning: ONNX opset version set to 13, override with opset_version=\n", - " warnings.warn(f\"ONNX opset version set to {DEFAULT_OPSET}, override with {ka}=\")\n" - ] - } - ], - "source": [ - "torch.manual_seed(0)\n", - "\n", - "from brevitas.nn import QuantConv1d\n", - "from brevitas.quant import Int8WeightPerTensorFloat, Int8ActPerTensorFloat, Int16Bias\n", - "from brevitas.export import export_onnx_qop\n", - "\n", - "float_inp = torch.randn(1, 2, 5)\n", - "\n", - "quant_conv_4b8b = QuantConv1d(\n", - " 2, 4, 3, bias=True, weight_bit_width=4,\n", - " input_quant=Int8ActPerTensorFloat,\n", - " output_quant=Int8ActPerTensorFloat,\n", - " bias_quant=Int16Bias)\n", - "\n", - "output_path = 'qop_onnx_conv_4b8b.onnx'\n", - "exported_model = export_onnx_qop(quant_conv_4b8b, input_t=float_inp, export_path=output_path)" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": { - "tags": [ - "skip-execution" - ] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Serving 'qop_onnx_conv_4b8b.onnx' at http://localhost:8082\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 40, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "show_netron(output_path, 8082)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In general the standard ONNX opset doesn't support representing quantization below 8b. Additionally, ONNX QOp representation requires an output quantizer to be set at part of of the layer. \n", - "\n", - "The constraint of always having an output quantizer is relaxed in the more recently introduced QDQ style of representation (for which there is support in Brevitas starting from version 0.8), which uses only `QuantizeLinear` and `DequantizeLinear` to represent quantization, but even with that support is still limited to 8b quantization." - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -2142,85 +2046,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The custom format shown above can integrated into ONNX-based toolchains, e.g. it's supported by our own FINN toolchain for low-precision dataflow style custom FPGAs implementations, and would be a starting point for direct integration with TVM.\n", - "\n", - "## Export to TorchScript quantization backend\n", - "\n", - "It's also possible to export to TorchScript own quantized functional operators, which come with their own set of restrictions. In particular, weights should be 7b and unsigned, which requires a zero-point. 
We can model that with appropriate quantizers:" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "metadata": {}, - "outputs": [], - "source": [ - "from brevitas.quant import ShiftedUint8ActPerTensorFloat\n", - "from brevitas.export import export_torch_qop\n", - "\n", - "\n", - "quant_conv_8b7b = QuantConv1d(\n", - " 2, 4, 3, bias=True,\n", - " input_quant=ShiftedUint8ActPerTensorFloat,\n", - " output_quant=ShiftedUint8ActPerTensorFloat,\n", - " weight_bit_width=7,\n", - " bias_quant=Int16Bias)\n", - "\n", - "output_path = 'pytorch_qf_conv_8b7b.pt'\n", - "exported_model = export_torch_qop(quant_conv_8b7b, input_t=float_inp, export_path=output_path)" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "metadata": { - "tags": [ - "skip-execution" - ] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Serving 'pytorch_qf_conv_8b7b.pt' at http://localhost:8085\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 46, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "show_netron(output_path, 8085)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As we can see though information on the fact that activations are 7b is lost, and they simply marked as 8b.\n", - "\n", - "Additionally, because bias quantization is not represented explicitly (although it is performed implicitly at 32b at runtime in the backend), any information around that is lost.\n", - "As with standard ONNX, representing precisions below 8b is not possible." + "The custom format shown above can integrated into ONNX-based toolchains, e.g. it's supported by our own FINN toolchain for low-precision dataflow style custom FPGAs implementations, and would be a starting point for direct integration with TVM." ] }, { diff --git a/notebooks/ONNX_export_tutorial.ipynb b/notebooks/ONNX_export_tutorial.ipynb index 1417946a3..65ef2ea58 100644 --- a/notebooks/ONNX_export_tutorial.ipynb +++ b/notebooks/ONNX_export_tutorial.ipynb @@ -489,155 +489,6 @@ "Similarly, the output of the QuantReLU is clipped between 0 and 15, since in this case we are doing an unsigned 4 bit quantization." 
] }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - }, - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "## QOps Export\n", - "Another supported style for exporting quantized operation in ONNX is represented by QOps.\n", - "\n", - "Compared to QCDQ, where it is possible to re-use standard floating point nodes (e.g., GEMM or Conv2d) preceeded by QCDQ nodes, with QOps the entire layer is replaced by its quantized counterpart.\n", - "\n", - "Opposite to what happens with QCDQ, all elements of the computation in this case have to be quantized: Input, Weight, Bias (if present), and Output tensors.\n", - "\n", - "This introduces some contraints on how we define our quantized layers through Brevitas.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - }, - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/scratch/fabian/brevitas/src/brevitas/export/onnx/standard/manager.py:26: UserWarning: ONNX opset version set to 13, override with opset_version=\n", - " warnings.warn(f\"ONNX opset version set to {DEFAULT_OPSET}, override with {ka}=\")\n" - ] - } - ], - "source": [ - "from brevitas.quant.scaled_int import Int8ActPerTensorFloat, Int32Bias\n", - "from brevitas.export import export_onnx_qop\n", - "\n", - "torch.manual_seed(1)\n", - "\n", - "IN_CH = 3\n", - "IMG_SIZE = 128\n", - "OUT_CH = 128\n", - "BATCH_SIZE = 1\n", - "\n", - "class Model(torch.nn.Module):\n", - " def __init__(self) -> None:\n", - " super().__init__()\n", - " self.input_quant = qnn.QuantIdentity(return_quant_tensor=True)\n", - " self.linear = qnn.QuantConv2d(\n", - " IN_CH, OUT_CH, kernel_size=3, bias=True,\n", - " weight_bit_width=4, bias_quant=Int32Bias,\n", - " output_quant=Int8ActPerTensorFloat)\n", - " \n", - " def forward(self, inp):\n", - " inp = self.input_quant(inp)\n", - " inp = self.linear(inp)\n", - " return inp\n", - "\n", - "inp = torch.randn(BATCH_SIZE, IN_CH, IMG_SIZE, IMG_SIZE)\n", - "model = Model() \n", - "model.eval()\n", - "\n", - "\n", - "exported_model = export_onnx_qop(model, args=inp, export_path=\"quant_model_qop.onnx\")" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - }, - "pycharm": { - "name": "#%%\n" - }, - "tags": [ - "skip-execution" - ] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Serving 'quant_model_qop.onnx' at http://localhost:8085\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "show_netron(\"quant_model_qop.onnx\", 8085)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - }, - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "In this case, we need to make sure that our input to `QuantLinear` is quantized. Using the approach shown above, with a standalone `QuantIdentity`, Brevitas will add a `QuantizeLinear` node. If `return_quant_tensor=True` is specified in `QuantIdentity`, a `DeQuantizeLinear` node won't be added. 
Setting `input_quant` in `QuantConv2d` is also an option.\n", - "\n", - "Note that the way `return_quant_tensor=True` is interpreted differs between QCDQ export and QOps export. With QCDQ, it doesn't affect the export, as a dequantize node is always generated. With QOps export, it prevents a quantization node from being inserted, so that an integer tensor is passed to the next layer.\n", - "\n", - "Moreover, our `QuantLinear` layer has to specify how to re-quantize the output - in this case, with the `Int8ActPerTensorFloat` activation quantizer, otherwise an error will be raised during export-time.\n", - "\n", - "Similarly, if the bias is present, it has to be quantized or an error will be raised.\n" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -811,108 +662,6 @@ "Unfortunately ONNX Runtime does not provide a built-in way to log whether execution goes through unoptimized floating-point GEMM, or int8 QGEMM." ] }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - }, - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "### QOps\n", - "\n", - "As for the QCDQ case, also in this case we are using only standard ONNX operations, thus we can use ONNX Runtime for executing our exported models." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - }, - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/scratch/fabian/brevitas/src/brevitas/export/onnx/standard/manager.py:26: UserWarning: ONNX opset version set to 13, override with opset_version=\n", - " warnings.warn(f\"ONNX opset version set to {DEFAULT_OPSET}, override with {ka}=\")\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n" - ] - } - ], - "source": [ - "import onnxruntime as ort\n", - "\n", - "class Model(torch.nn.Module):\n", - " def __init__(self) -> None:\n", - " super().__init__()\n", - " self.input_quant = qnn.QuantIdentity(return_quant_tensor=True)\n", - " self.conv = qnn.QuantConv2d(\n", - " IN_CH, OUT_CH, kernel_size=3, bias=False,\n", - " weight_bit_width=4,\n", - " output_quant=Int8ActPerTensorFloat, \n", - " output_bit_width=4,\n", - " return_quant_tensor=True)\n", - " \n", - " def forward(self, inp):\n", - " inp = self.input_quant(inp)\n", - " inp = self.conv(inp)\n", - " return inp\n", - "\n", - "model = Model()\n", - "model.eval()\n", - "inp = torch.randn(BATCH_SIZE, IN_CH, IMG_SIZE, IMG_SIZE)\n", - "path = 'quant_model_qops_4b_4b.onnx'\n", - "\n", - "exported_model = export_onnx_qop(model, args=inp, export_path=path)\n", - "\n", - "sess_opt = ort.SessionOptions()\n", - "sess = ort.InferenceSession(path, sess_opt)\n", - "input_name = sess.get_inputs()[0].name\n", - "pred_onx = sess.run(None, {input_name: inp.numpy()})[0]\n", - "\n", - "\n", - "out_brevitas = model(inp).int()\n", - "out_ort = torch.tensor(pred_onx, dtype=torch.int8)\n", - "\n", - "assert_with_message(torch.allclose(out_brevitas, out_ort, atol=1))" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - }, - "pycharm": { - "name": "#%% md\n" - } - }, - "source": [ - "Note a few things. `QuantConv2d` defines `return_quant_tensor=True` so that the exported ONNX model doesn't have a dequantize node at the end. 
Vecause we set `output_bit_width=4` (overriding the 8 bit bit-width in the `output_quant` quantizer), we have a `Clip` node at the end. At the same time, within Brevitas, `return_quant_tensor=True` means the PyTorch model returns a Brevitas `QuantTensor`, from which we are taking the `int` representation.\n", - "\n", - "Due to differences in how the computation is performed between Brevitas and ONNX Runtime, it might happen the two results are slightly different (since Brevitas uses a style closer to QCDQ, rather than operating between integers), thus we added a tolerance for off-by-1 errors." - ] - }, { "cell_type": "markdown", "metadata": {}, diff --git a/src/brevitas/export/__init__.py b/src/brevitas/export/__init__.py index fbea3cd50..ca5efdbd6 100644 --- a/src/brevitas/export/__init__.py +++ b/src/brevitas/export/__init__.py @@ -6,9 +6,7 @@ from .onnx.debug import enable_debug from .onnx.qonnx.manager import QONNXManager from .onnx.standard.qcdq.manager import StdQCDQONNXManager -from .onnx.standard.qoperator.manager import StdQOpONNXManager from .torch.qcdq.manager import TorchQCDQManager -from .torch.qoperator.manager import TorchQOpManager @wraps(QONNXManager.export) @@ -21,21 +19,11 @@ def export_qonnx(*args, **kwargs): return QONNXManager.export(*args, **kwargs) -@wraps(StdQOpONNXManager.export) -def export_onnx_qop(*args, **kwargs): - return StdQOpONNXManager.export(*args, **kwargs) - - @wraps(StdQCDQONNXManager.export) def export_onnx_qcdq(*args, **kwargs): return StdQCDQONNXManager.export(*args, **kwargs) -@wraps(TorchQOpManager.export) -def export_torch_qop(*args, **kwargs): - return TorchQOpManager.export(*args, **kwargs) - - @wraps(TorchQCDQManager.export) def export_torch_qcdq(*args, **kwargs): return TorchQCDQManager.export(*args, **kwargs) diff --git a/src/brevitas/export/manager.py b/src/brevitas/export/manager.py index bc62920f3..8ad069642 100644 --- a/src/brevitas/export/manager.py +++ b/src/brevitas/export/manager.py @@ -170,61 +170,12 @@ class BaseManager(ABC): def export(cls, *args, **kwargs): return - @classmethod - def _gen_patches(cls, fn_dispatcher): - patches = [] - for fn in cls._fn_to_cache: - dispatcher = partial(fn_dispatcher, fn) - p = patch(torch.nn.functional, fn.__name__, dispatcher) - patches.append(p) - return patches - - @classmethod - def _trace_patches(cls): - patches = cls._gen_patches(cls._trace_fn_dispatcher) - return patches - - @classmethod - def _cache_patches(cls): - return cls._gen_patches(cls._cache_fn_dispatcher) - - @classmethod - def _restore_fn_patches(cls): - return [patch(torch.nn.functional, fn.__name__, fn) for fn in cls._fn_to_cache] - @classmethod def _trace_fn_dispatcher(cls, fn, input, *args, **kwargs): # baseline impl cls._fn_to_cache.pop(0) return fn(input, *args, **kwargs) - @classmethod - def _cache_fn_dispatcher(cls, fn, input, *args, **kwargs): - with ExitStack() as stack: - # disable recursing into this patch - for mgr in cls._restore_fn_patches(): - stack.enter_context(mgr) - if isinstance(input, QuantTensor): - inp_cache = None - out_cache = None - inp_cache = _CachedIO(input, metadata_only=True) - output = fn(input, *args, **kwargs) - if isinstance(output, QuantTensor): - out_cache = _CachedIO(output, metadata_only=True) - cached_io = (inp_cache, out_cache) - if fn in cls._cached_io_handler_map: - cached_io = cls._cached_io_handler_map[fn](cached_io) - cls._fn_cache.append(cached_io) - else: - # could be a fn invoked within a quant module on a dequant tensor - # or a function invoked on a float tensor. 
The former won't show - # up during jit tracing as they are replaced by symbolic functions, - # but the latter will, so we have to account for them in the _fn_cache - output = fn(input, *args, **kwargs) - if not config._IS_INSIDE_QUANT_LAYER: - cls._fn_cache.append(None) - return output - @classmethod def handler_from_module(cls, module: Module, no_inheritance=False): for handler in cls.handlers: @@ -253,10 +204,7 @@ def _cache_inp_out(cls, module, *args, **kwargs): module.apply(lambda m: _override_bias_caching_mode(m, enabled=True)) module.apply(lambda m: _override_inp_caching_mode(m, enabled=True)) module.apply(lambda m: _override_out_caching_mode(m, enabled=True)) - with ExitStack() as stack: - for mgr in cls._cache_patches(): - stack.enter_context(mgr) - _ = module.forward(*args, **kwargs) + _ = module.forward(*args, **kwargs) # Restore previous caching properties module.apply(lambda m: _restore_quant_metadata_caching_mode(m)) module.apply(lambda m: _restore_bias_caching_mode(m)) @@ -279,12 +227,9 @@ def jit_inference_trace( cls.set_export_mode(module, enabled=True) # force requires_grad to False to let the wrapped model lambda go through tracing requires_grad_backup_dict = _force_requires_grad_false(module) - with ExitStack() as stack: - for mgr in cls._trace_patches(): - stack.enter_context(mgr) - # wrapping with a lambda forces inlining during tracing, - # converts everything to const and removes unused params/buffers - traced_model = torch.jit.trace(_JitTraceExportWrapper(module), args) + # wrapping with a lambda forces inlining during tracing, + # converts everything to const and removes unused params/buffers + traced_model = torch.jit.trace(_JitTraceExportWrapper(module), args) # Hack to clone the function, otherwise restoring requires_grad # on module will break traced_model with BytesIO() as tmp: diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py index 68be5e8da..56d5bf753 100644 --- a/src/brevitas/export/onnx/manager.py +++ b/src/brevitas/export/onnx/manager.py @@ -122,15 +122,12 @@ def export_onnx( # temporarily disable input caching to avoid collectives empty debug values module.apply(lambda m: _override_inp_caching_mode(m, enabled=False)) # perform export pass - with ExitStack() as stack: - for mgr in cls._trace_patches(): - stack.enter_context(mgr) - if export_path is not None: - export_target = export_path - else: - model_bytes = BytesIO() - export_target = model_bytes - torch.onnx.export(module, args, export_target, **onnx_export_kwargs) + if export_path is not None: + export_target = export_path + else: + model_bytes = BytesIO() + export_target = model_bytes + torch.onnx.export(module, args, export_target, **onnx_export_kwargs) # restore the model to previous properties module.apply(lambda m: _restore_inp_caching_mode(m)) diff --git a/src/brevitas/export/onnx/standard/qoperator/__init__.py b/src/brevitas/export/onnx/standard/qoperator/__init__.py deleted file mode 100644 index b10a7efee..000000000 --- a/src/brevitas/export/onnx/standard/qoperator/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause diff --git a/src/brevitas/export/onnx/standard/qoperator/function.py b/src/brevitas/export/onnx/standard/qoperator/function.py deleted file mode 100644 index aa0395b71..000000000 --- a/src/brevitas/export/onnx/standard/qoperator/function.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (C) 2023, Advanced Micro Devices, Inc. 
All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -import torch -from torch.autograd import Function - - -class QLinearConvFn(Function): - - @staticmethod - def symbolic( - g, - int_x, - input_scale, - input_zero_point, - int_weight, - weight_scale, - weight_zero_point, - output_scale, - ouput_zero_point, - output_dtype, - int_bias, - out_shape, - kernel_size, - padding, - stride, - groups, - dilation): - if int_bias is not None: - ret = g.op( - 'QLinearConv', - int_x, - input_scale, - input_zero_point, - int_weight, - weight_scale, - weight_zero_point, - output_scale, - ouput_zero_point, - int_bias, - kernel_shape_i=kernel_size, - pads_i=padding, - strides_i=stride, - group_i=groups, - dilations_i=dilation) - else: - ret = g.op( - 'QLinearConv', - int_x, - input_scale, - input_zero_point, - int_weight, - weight_scale, - weight_zero_point, - output_scale, - ouput_zero_point, - kernel_shape_i=kernel_size, - pads_i=padding, - strides_i=stride, - group_i=groups, - dilations_i=dilation) - return ret - - @staticmethod - def forward( - ctx, - int_x, - input_scale, - input_zero_point, - int_weight, - weight_scale, - weight_zero_point, - output_scale, - output_zero_point, - output_dtype, - int_bias, - out_shape, - kernel_size, - padding, - stride, - groups, - dilation): - return torch.empty(out_shape, dtype=output_dtype, device=int_x.device) - - -class QLinearMatMulFn(Function): - - @staticmethod - def symbolic( - g, - int_x, - input_scale, - input_zero_point, - int_weight, - weight_scale, - weight_zero_point, - output_scale, - ouput_zero_point, - output_dtype, - out_shape): - ret = g.op( - 'QLinearMatMul', - int_x, - input_scale, - input_zero_point, - int_weight, - weight_scale, - weight_zero_point, - output_scale, - ouput_zero_point) - return ret - - @staticmethod - def forward( - ctx, - int_x, - input_scale, - input_zero_point, - int_weight, - weight_scale, - weight_zero_point, - output_scale, - output_zero_point, - output_dtype, - out_shape): - return torch.empty(out_shape, dtype=output_dtype, device=int_x.device) diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/__init__.py b/src/brevitas/export/onnx/standard/qoperator/handler/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/act.py b/src/brevitas/export/onnx/standard/qoperator/handler/act.py deleted file mode 100644 index b961e05cf..000000000 --- a/src/brevitas/export/onnx/standard/qoperator/handler/act.py +++ /dev/null @@ -1,124 +0,0 @@ -from abc import ABC - -import torch -from torch import Tensor - -from brevitas.export.onnx.standard.function import DequantizeLinearFn -from brevitas.export.onnx.standard.function import IntClipFn -from brevitas.export.onnx.standard.function import QuantizeLinearFn -from brevitas.nn import QuantHardTanh -from brevitas.nn import QuantIdentity -from brevitas.nn import QuantReLU -from brevitas.nn import QuantSigmoid -from brevitas.nn import QuantTanh -from brevitas.nn.quant_layer import QuantNonLinearActLayer as QuantNLAL - -from .base import StdQOpONNXQuantLayerHandler - - -class StdQOpONNXQuantNLALHandler(StdQOpONNXQuantLayerHandler, ABC): - - @classmethod - def validate(cls, module: QuantNLAL): - if cls.input_quant_supported and module.is_input_quant_enabled: - assert not module.is_quant_input_narrow_range, "Narrow range quant not supported." 
- elif not cls.input_quant_supported and module.is_input_quant_enabled: - raise RuntimeError("Input quant not supported.") - if module.is_act_quant_enabled: - assert not module.is_quant_act_narrow_range, "Narrow range quant not supported." - input_bit_width = module.quant_input_bit_width() - act_bit_width = module.quant_act_bit_width() - if input_bit_width is not None: - cls.validate_8b_bit_width(input_bit_width, le_then=True) - if act_bit_width is not None: - cls.validate_8b_bit_width(act_bit_width, le_then=True) - - def prepare_for_export(self, module: QuantNLAL): - self.validate(module) - if not module.return_quant_tensor: - output_dequant_symbolic_kwargs = self.output_dequant_symbolic_kwargs(module) - else: - output_dequant_symbolic_kwargs = None - output_quant_symbolic_kwargs = self.output_quant_symbolic_kwargs(module) - output_clip_symbolic_kwargs = self.output_clip_symbolic_kwargs(module) - input_dequant_symbolic_kwargs = self.input_dequant_symbolic_kwargs(module) - input_quant_symbolic_kwargs = self.input_quant_symbolic_kwargs(module) - input_clip_symbolic_kwargs = self.input_clip_symbolic_kwargs(module) - if input_quant_symbolic_kwargs is not None: - input_redequant_symbolic_kwargs = { - 'input_scale': input_quant_symbolic_kwargs['output_scale'], - 'input_zero_point': input_quant_symbolic_kwargs['output_zero_point'], - 'input_axis': input_quant_symbolic_kwargs['input_axis']} - else: - input_redequant_symbolic_kwargs = None - - self.symbolic_kwargs = { - 'input_quant_symbolic_kwargs': input_quant_symbolic_kwargs, - 'input_clip_symbolic_kwargs': input_clip_symbolic_kwargs, - 'input_dequant_symbolic_kwargs': input_dequant_symbolic_kwargs, - 'input_redequant_symbolic_kwargs': input_redequant_symbolic_kwargs, - 'output_quant_symbolic_kwargs': output_quant_symbolic_kwargs, - 'output_clip_symbolic_kwargs': output_clip_symbolic_kwargs, - 'output_dequant_symbolic_kwargs': output_dequant_symbolic_kwargs} - - def input_symbolic_execution(self, inp: Tensor): - input_quant_symbolic_kwargs = self.symbolic_kwargs['input_quant_symbolic_kwargs'] - input_clip_symbolic_kwargs = self.symbolic_kwargs['input_clip_symbolic_kwargs'] - input_dequant_symbolic_kwargs = self.symbolic_kwargs['input_dequant_symbolic_kwargs'] - input_redequant_symbolic_kwargs = self.symbolic_kwargs['input_redequant_symbolic_kwargs'] - if input_dequant_symbolic_kwargs is not None: - inp = DequantizeLinearFn.apply(inp, *input_dequant_symbolic_kwargs.values()) - if input_quant_symbolic_kwargs is not None: - inp = QuantizeLinearFn.apply(inp, *input_quant_symbolic_kwargs.values()) - inp = DequantizeLinearFn.apply(inp, *input_redequant_symbolic_kwargs.values()) - if input_clip_symbolic_kwargs is not None: - inp = IntClipFn.apply(inp, *input_clip_symbolic_kwargs.values()) - return inp - - def output_symbolic_execution(self, out: Tensor): - output_quant_symbolic_kwargs = self.symbolic_kwargs['output_quant_symbolic_kwargs'] - output_dequant_symbolic_kwargs = self.symbolic_kwargs['output_dequant_symbolic_kwargs'] - output_clip_symbolic_kwargs = self.symbolic_kwargs['output_clip_symbolic_kwargs'] - if output_quant_symbolic_kwargs is not None: - out = QuantizeLinearFn.apply(out, *output_quant_symbolic_kwargs.values()) - if output_clip_symbolic_kwargs is not None: - out = IntClipFn.apply(out, *output_clip_symbolic_kwargs.values()) - if output_dequant_symbolic_kwargs is not None: - out = DequantizeLinearFn.apply(out, *output_dequant_symbolic_kwargs.values()) - return out - - -class StdQOpONNXQuantReLUHandler(StdQOpONNXQuantNLALHandler): - 
handled_layer = QuantReLU - input_quant_supported = True - - def op_symbolic_execution(self, inp: Tensor): - return torch.relu(inp) - - -class StdQOpONNXQuantTanhHandler(StdQOpONNXQuantNLALHandler): - handled_layer = QuantTanh - input_quant_supported = True - - def op_symbolic_execution(self, inp: Tensor): - return torch.tanh(inp) - - -class StdQOpONNXQuantSigmoidHandler(StdQOpONNXQuantNLALHandler): - handled_layer = QuantSigmoid - input_quant_supported = True - - def op_symbolic_execution(self, inp: Tensor): - return torch.sigmoid(inp) - - -class StdQOpONNXQuantIdentityHandler(StdQOpONNXQuantNLALHandler): - handled_layer = QuantIdentity - input_quant_supported = False - - def op_symbolic_execution(self, inp: Tensor): - return inp - - -class StdQOpONNXQuantHardTanhHandler(StdQOpONNXQuantIdentityHandler): - handled_layer = QuantHardTanh diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/base.py b/src/brevitas/export/onnx/standard/qoperator/handler/base.py deleted file mode 100644 index e614d2ed5..000000000 --- a/src/brevitas/export/onnx/standard/qoperator/handler/base.py +++ /dev/null @@ -1,185 +0,0 @@ -from abc import ABC -from abc import abstractmethod - -import torch -from torch import Tensor - -from brevitas.export.common.handler.base import BitWidthHandlerMixin -from brevitas.export.common.handler.base import ClipMixin -from brevitas.export.common.handler.base import QuantAxisMixin -from brevitas.export.common.handler.base import ScaleHandlerMixin -from brevitas.export.common.handler.base import ZeroPointHandlerMixin -from brevitas.export.onnx.handler import ONNXBaseHandler -from brevitas.export.onnx.standard.function import DequantizeLinearFn -from brevitas.export.onnx.standard.function import IntClipFn -from brevitas.export.onnx.standard.function import QuantizeLinearFn - - -class StdQOpONNXQuantLayerHandler(ONNXBaseHandler, - QuantAxisMixin, - ScaleHandlerMixin, - ClipMixin, - BitWidthHandlerMixin, - ZeroPointHandlerMixin, - ABC): - - @abstractmethod - def op_symbolic_execution(self, inp: Tensor): - pass - - @abstractmethod - def input_symbolic_execution(self, inp: Tensor): - pass - - @abstractmethod - def output_symbolic_execution(self, out: Tensor): - pass - - @classmethod - def op_symbolic_kwargs(cls, module): - raise NotImplementedError # optional method - - @classmethod - def torch_8b_dtype(cls, is_signed): - if is_signed: - return torch.int8 - else: - return torch.uint8 - - @classmethod - def quant_output_shape(cls, module): - cached_out = module._cached_out - if cached_out is None: - raise RuntimeError("Caching of outputs is required to export") - return cached_out.shape - - @classmethod - def output_quant_symbolic_kwargs(cls, module): - if module.is_output_quant_enabled: - return { - 'output_scale': module.quant_output_scale(), - 'output_zero_point': cls.quant_output_zero_point(module), - 'output_dtype': cls.torch_8b_dtype(module.is_quant_output_signed), - 'output_axis': cls.quant_axis(module.quant_output_scale())} - else: - return None - - @classmethod - def output_clip_symbolic_kwargs(cls, module): - if module.is_output_quant_enabled: - narrow = module.is_quant_output_narrow_range - signed = module.is_quant_output_signed - bit_width = module.quant_output_bit_width() - return cls.int_clip_symbolic_kwargs(narrow, signed, bit_width) - else: - return None - - @classmethod - def input_clip_symbolic_kwargs(cls, module): - if module.is_input_quant_enabled: - narrow = module.is_quant_input_narrow_range - signed = module.is_quant_input_signed - bit_width = 
module.quant_input_bit_width() - return cls.int_clip_symbolic_kwargs(narrow, signed, bit_width) - else: - return None - - @classmethod - def output_dequant_symbolic_kwargs(cls, module): - return { - 'input_scale': module.quant_output_scale(), - 'input_zero_point': cls.quant_output_zero_point(module), - 'input_axis': cls.quant_axis(module.quant_output_scale())} - - @classmethod - def input_quant_symbolic_kwargs(cls, module): - if module.is_input_quant_enabled: - return { - 'output_scale': module.quant_input_scale(), - 'output_zero_point': cls.quant_input_zero_point(module), - 'output_dtype': cls.torch_8b_dtype(module.is_quant_input_signed), - 'output_axis': cls.quant_axis(module.quant_input_scale())} - else: - return None - - @classmethod - def input_dequant_symbolic_kwargs(cls, module): - if module._cached_inp is not None: - return cls.dequant_symbolic_kwargs_from_cached_io(module._cached_inp) - else: - return None - - @classmethod - def dequant_symbolic_kwargs_from_cached_io(cls, cached_io): - cls.validate_8b_bit_width(cached_io.bit_width, le_then=True) - return { - 'input_scale': - cached_io.scale, - 'input_zero_point': - cls.zero_point_with_dtype( - cached_io.signed, cached_io.bit_width, cached_io.zero_point), - 'input_axis': - cls.quant_axis(cached_io.scale)} - - @classmethod - def quant_symbolic_kwargs_from_cached_io(cls, cached_io): - cls.validate_8b_bit_width(cached_io.bit_width, le_then=True) - q_kwargs = { - 'output_scale': - cached_io.scale, - 'output_zero_point': - cls.zero_point_with_dtype( - cached_io.signed, cached_io.bit_width, cached_io.zero_point), - 'output_dtype': - cls.torch_8b_dtype(cached_io.signed), - 'output_axis': - cls.quant_axis(cached_io.scale)} - # TODO support narrow caching - # Assume narrow is False since we are preventing it everywhere else - int_clip_kwargs = cls.clip_symbolic_kwargs(False, cached_io.signed, cached_io.bit_width) - return q_kwargs, int_clip_kwargs - - def symbolic_execution(self, inp: Tensor): - inp = self.input_symbolic_execution(inp) - out = self.op_symbolic_execution(inp) - ret = self.output_symbolic_execution(out) - return ret - - -class StdQOpONNXQuantWrapperHandler(StdQOpONNXQuantLayerHandler, ABC): - - @classmethod - def validate(cls, module): - cls.validate_8b_bit_width(module.quant_input_bit_width(), le_then=True) - cls.validate_8b_bit_width(module.quant_output_bit_width(), le_then=True) - - def prepare_for_export(self, module): - self.validate(module) - op_symbolic_kwargs = self.op_symbolic_kwargs(module) - input_dequant_symbolic_kwargs = self.input_dequant_symbolic_kwargs(module) - if module.return_quant_tensor: - output_quant_symbolic_kwargs = self.output_quant_symbolic_kwargs(module) - output_clip_symbolic_kwargs = self.output_clip_symbolic_kwargs(module) - else: - output_quant_symbolic_kwargs = None - output_clip_symbolic_kwargs = None - - self.symbolic_kwargs = { - 'op_symbolic_kwargs': op_symbolic_kwargs, - 'input_dequant_symbolic_kwargs': input_dequant_symbolic_kwargs, - 'output_quant_symbolic_kwargs': output_quant_symbolic_kwargs, - 'output_clip_symbolic_kwargs': output_clip_symbolic_kwargs} - - def input_symbolic_execution(self, inp: Tensor): - input_dequant_symbolic_kwargs = self.symbolic_kwargs['input_dequant_symbolic_kwargs'] - inp = DequantizeLinearFn.apply(inp, *input_dequant_symbolic_kwargs.values()) - return inp - - def output_symbolic_execution(self, out: Tensor): - output_quant_symbolic_kwargs = self.symbolic_kwargs['output_quant_symbolic_kwargs'] - output_clip_symbolic_kwargs = 
self.symbolic_kwargs['output_clip_symbolic_kwargs'] - if output_quant_symbolic_kwargs is not None: - out = QuantizeLinearFn.apply(out, *output_quant_symbolic_kwargs.values()) - if output_clip_symbolic_kwargs is not None: - out = IntClipFn.apply(out, *output_clip_symbolic_kwargs.values()) - return out diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py b/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py deleted file mode 100644 index 6540b6f2d..000000000 --- a/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py +++ /dev/null @@ -1,176 +0,0 @@ -from abc import ABC -from abc import abstractmethod -from typing import Union - -import torch -from torch import Tensor - -from brevitas.export.common import to_0dim_if_scalar -from brevitas.export.onnx.handler import Kernel1dApplHandlerMixin -from brevitas.export.onnx.handler import Kernel2dApplHandlerMixin -from brevitas.export.onnx.handler import Kernel3dApplHandlerMixin -from brevitas.export.onnx.standard.function import DequantizeLinearFn -from brevitas.export.onnx.standard.function import IntClipFn -from brevitas.export.onnx.standard.function import QuantizeLinearFn -from brevitas.nn import QuantConv1d -from brevitas.nn import QuantConv2d -from brevitas.nn import QuantConv3d -from brevitas.nn import QuantLinear -from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL - -from ..function import QLinearConvFn -from ..function import QLinearMatMulFn -from .base import StdQOpONNXQuantLayerHandler - - -class StdQOpONNXQuantWBIOLHandler(StdQOpONNXQuantLayerHandler, ABC): - - @staticmethod - def int_weight(module: QuantWBIOL): - int_weight = module.int_weight(float_datatype=False).detach() - if module.is_quant_weight_signed: - return int_weight.type(torch.int8) - else: - return int_weight.type(torch.uint8) - - @staticmethod - def int_bias(module: QuantWBIOL): - if module.bias is not None: - int_bias = module.int_bias(float_datatype=False).detach() - return int_bias.type(torch.int32) - else: - return None - - @classmethod - def validate(cls, module: QuantWBIOL, requires_quant_bias=True): - assert module.is_weight_quant_enabled, 'Weight quant required' - assert module.is_output_quant_enabled, 'Output quant required' - # Handling narrow_range is across the network is difficult do to the fact that - # it's not part of QuantTensor, and so it can't be cached - assert not module.is_quant_output_narrow_range, 'Narrow output quant not supported' - if module.is_input_quant_enabled: - assert not module.is_quant_input_narrow_range, 'Narrow output quant not supported' - cls.validate_8b_bit_width(module.quant_weight_bit_width(), le_then=True) - cls.validate_8b_bit_width(module.quant_input_bit_width(), le_then=True) - cls.validate_8b_bit_width(module.quant_output_bit_width(), le_then=True) - if module.bias is not None and requires_quant_bias: - assert module.is_bias_quant_enabled - assert module.is_quant_bias_signed - cls.validate_32b_bit_width(module.quant_bias_bit_width(), le_then=True) - - def prepare_for_export(self, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]): - self.validate(module) - - op_symbolic_kwargs = self.op_symbolic_kwargs(module) - if not module.return_quant_tensor: - output_dequant_symbolic_kwargs = self.output_dequant_symbolic_kwargs(module) - else: - output_dequant_symbolic_kwargs = None - input_quant_symbolic_kwargs = self.input_quant_symbolic_kwargs(module) - if input_quant_symbolic_kwargs is not None: - input_dequant_symbolic_kwargs = 
self.input_dequant_symbolic_kwargs(module) - input_clip_symbolic_kwargs = self.input_clip_symbolic_kwargs(module) - else: - input_dequant_symbolic_kwargs = None - input_clip_symbolic_kwargs = None - output_clip_symbolic_kwargs = self.output_clip_symbolic_kwargs(module) - - self.symbolic_kwargs = { - 'op_symbolic_kwargs': op_symbolic_kwargs, - 'input_quant_symbolic_kwargs': input_quant_symbolic_kwargs, - 'input_clip_symbolic_kwargs': input_clip_symbolic_kwargs, - 'input_dequant_symbolic_kwargs': input_dequant_symbolic_kwargs, - 'output_dequant_symbolic_kwargs': output_dequant_symbolic_kwargs, - 'output_clip_symbolic_kwargs': output_clip_symbolic_kwargs} - - def input_symbolic_execution(self, inp: Tensor): - input_dequant_symbolic_kwargs = self.symbolic_kwargs['input_dequant_symbolic_kwargs'] - input_quant_symbolic_kwargs = self.symbolic_kwargs['input_quant_symbolic_kwargs'] - input_clip_symbolic_kwargs = self.symbolic_kwargs['input_clip_symbolic_kwargs'] - if input_dequant_symbolic_kwargs is not None: - assert input_quant_symbolic_kwargs is not None - inp = DequantizeLinearFn.apply(inp, *input_dequant_symbolic_kwargs.values()) - if input_quant_symbolic_kwargs is not None: - inp = QuantizeLinearFn.apply(inp, *input_quant_symbolic_kwargs.values()) - if input_clip_symbolic_kwargs is not None: - inp = IntClipFn.apply(inp, *input_clip_symbolic_kwargs.values()) - return inp - - def output_symbolic_execution(self, out: Tensor): - output_clip_symbolic_kwargs = self.symbolic_kwargs['output_clip_symbolic_kwargs'] - if output_clip_symbolic_kwargs is not None: - out = IntClipFn.apply(out, *output_clip_symbolic_kwargs.values()) - output_dequant_symbolic_kwargs = self.symbolic_kwargs['output_dequant_symbolic_kwargs'] - if output_dequant_symbolic_kwargs is not None: - out = DequantizeLinearFn.apply(out, *output_dequant_symbolic_kwargs.values()) - return out - - -class StdQOpONNXQuantConvNdHandler(StdQOpONNXQuantWBIOLHandler, ABC): - - def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]): - conv_symbolic_kwargs = { - 'input_scale': module.quant_input_scale(), - 'input_zero_point': self.quant_input_zero_point(module), - 'int_weight': self.int_weight(module), - 'weight_scale': to_0dim_if_scalar(module.quant_weight_scale().flatten()), - 'weight_zero_point': to_0dim_if_scalar(self.quant_weight_zero_point(module).flatten()), - 'output_scale': module.quant_output_scale(), - 'output_zero_point': self.quant_output_zero_point(module), - 'output_dtype': self.torch_8b_dtype(module.is_quant_output_signed), - 'int_bias': self.int_bias(module), - 'out_shape': self.quant_output_shape(module), - 'kernel_size': list(module.kernel_size), - 'padding': self.padding(module), - 'stride': self.stride(module), - 'groups': module.groups, - 'dilation': self.dilation(module)} - return conv_symbolic_kwargs - - def op_symbolic_execution(self, inp: Tensor): - conv_symbolic_kwargs = self.symbolic_kwargs['op_symbolic_kwargs'] - out = QLinearConvFn.apply(inp, *conv_symbolic_kwargs.values()) - return out - - -class StdQOpONNXQuantConv3dHandler(StdQOpONNXQuantConvNdHandler, Kernel3dApplHandlerMixin): - handled_layer = QuantConv3d - - -class StdQOpONNXQuantConv2dHandler(StdQOpONNXQuantConvNdHandler, Kernel2dApplHandlerMixin): - handled_layer = QuantConv2d - - -class StdQOpONNXQuantConv1dHandler(StdQOpONNXQuantConvNdHandler, Kernel1dApplHandlerMixin): - handled_layer = QuantConv1d - - -class StdQOpONNXQuantLinearHandler(StdQOpONNXQuantWBIOLHandler): - handled_layer = QuantLinear - - # Convert linear to conv1d to 
handle bias - def op_symbolic_kwargs(self, module: QuantLinear): - conv_symbolic_kwargs = { - 'input_scale': module.quant_input_scale(), - 'input_zero_point': self.quant_input_zero_point(module), - 'int_weight': self.int_weight(module).view(module.out_features, module.in_features, 1), - 'weight_scale': to_0dim_if_scalar(module.quant_weight_scale().flatten()), - 'weight_zero_point': to_0dim_if_scalar(self.quant_weight_zero_point(module).flatten()), - 'output_scale': module.quant_output_scale(), - 'output_zero_point': self.quant_output_zero_point(module), - 'output_dtype': self.torch_8b_dtype(module.is_quant_output_signed), - 'int_bias': self.int_bias(module), - 'out_shape': self.quant_output_shape(module) + (1,), - 'kernel_size': [1], - 'padding': [0, 0], - 'stride': [1], - 'groups': 1, - 'dilation': [1]} - return conv_symbolic_kwargs - - def op_symbolic_execution(self, inp): - linear_symbolic_kwargs = self.symbolic_kwargs['op_symbolic_kwargs'] - inp = inp.view(inp.size(0), -1, 1) - out = QLinearConvFn.apply(inp, *linear_symbolic_kwargs.values()) - out = out.view(out.size(0), -1) - return out diff --git a/src/brevitas/export/onnx/standard/qoperator/manager.py b/src/brevitas/export/onnx/standard/qoperator/manager.py deleted file mode 100644 index be894e65b..000000000 --- a/src/brevitas/export/onnx/standard/qoperator/manager.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -from torch.nn import functional as F -from torch.nn import Module - -from brevitas.export.manager import _set_layer_export_handler -from brevitas.export.manager import _set_layer_export_mode - -from ..function import DequantizeLinearFn -from ..function import IntClipFn -from ..function import QuantizeLinearFn -from ..manager import StdONNXBaseManager -from .handler.act import StdQOpONNXQuantHardTanhHandler -from .handler.act import StdQOpONNXQuantIdentityHandler -from .handler.act import StdQOpONNXQuantReLUHandler -from .handler.act import StdQOpONNXQuantSigmoidHandler -from .handler.act import StdQOpONNXQuantTanhHandler -from .handler.base import StdQOpONNXQuantLayerHandler -from .handler.parameter import StdQOpONNXQuantConv1dHandler -from .handler.parameter import StdQOpONNXQuantConv2dHandler -from .handler.parameter import StdQOpONNXQuantConv3dHandler -from .handler.parameter import StdQOpONNXQuantLinearHandler - - -class StdQOpONNXManager(StdONNXBaseManager): - target_name = 'StdQOpONNX' - - _fn_to_cache = [ - F.relu, - F.relu6, - F.hardtanh, - F.max_pool1d, - F.max_pool2d, - F.max_pool3d, - F.adaptive_max_pool1d, - F.adaptive_max_pool2d, - F.adaptive_max_pool3d] - - handlers = [ - StdQOpONNXQuantConv1dHandler, - StdQOpONNXQuantConv2dHandler, - StdQOpONNXQuantConv3dHandler, - StdQOpONNXQuantLinearHandler, - StdQOpONNXQuantReLUHandler, - StdQOpONNXQuantHardTanhHandler, - StdQOpONNXQuantIdentityHandler, - StdQOpONNXQuantTanhHandler, - StdQOpONNXQuantSigmoidHandler] - - onnx_passes = [ - # remove unused graph inputs & initializers - "eliminate_unused_initializer"] - - @classmethod - def _trace_fn_dispatcher(cls, fn, input, *args, **kwargs): - cached_io = cls._fn_cache.pop(0) - if cached_io is not None: - cached_inp, cached_out = cached_io - if cached_inp is not None: - deq_kwargs = StdQOpONNXQuantLayerHandler.dequant_symbolic_kwargs_from_cached_io( - cached_inp) - input = DequantizeLinearFn.apply(input, *deq_kwargs.values()) - output = fn(input, *args, **kwargs) - if cached_out is not None: - out_kwargs = 
StdQOpONNXQuantLayerHandler.quant_symbolic_kwargs_from_cached_io( - cached_out) - q_kwargs, int_clip_kwargs = out_kwargs - output = QuantizeLinearFn.apply(output, *q_kwargs.values()) - if int_clip_kwargs is not None: - output = IntClipFn.apply(output, *int_clip_kwargs.values()) - else: - output = fn(input, *args, **kwargs) - return output - - @classmethod - def set_export_mode(cls, module: Module, enabled: bool): - _set_layer_export_mode(module, enabled) - - @classmethod - def set_export_handler(cls, module: Module): - _set_layer_export_handler(cls, module) diff --git a/src/brevitas/export/torch/qoperator/__init__.py b/src/brevitas/export/torch/qoperator/__init__.py deleted file mode 100644 index b10a7efee..000000000 --- a/src/brevitas/export/torch/qoperator/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause diff --git a/src/brevitas/export/torch/qoperator/handler/__init__.py b/src/brevitas/export/torch/qoperator/handler/__init__.py deleted file mode 100644 index b135a9c43..000000000 --- a/src/brevitas/export/torch/qoperator/handler/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -try: - import torch.nn.quantized.functional as qF - -# Skip for pytorch 1.1.0 -except: - qF = None diff --git a/src/brevitas/export/torch/qoperator/handler/act.py b/src/brevitas/export/torch/qoperator/handler/act.py deleted file mode 100644 index 4a0d9cf2e..000000000 --- a/src/brevitas/export/torch/qoperator/handler/act.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -from abc import ABC - -import torch -from torch import Tensor - -from brevitas.nn import QuantHardTanh -from brevitas.nn import QuantIdentity -from brevitas.nn import QuantReLU -from brevitas.nn.quant_layer import QuantNonLinearActLayer as QuantNLAL - -from . 
import qF -from .base import PytorchQuantLayerHandler - - -class PytorchQuantNLALHandler(PytorchQuantLayerHandler, ABC): - - @classmethod - def explicit_output_dtype(cls) -> bool: - return True - - @classmethod - def validate(cls, module: QuantNLAL): - assert not module.is_input_quant_enabled, 'Input quantization not supported' - cls.validate_8b_bit_width(module.quant_act_bit_width(), le_then=False) - - def prepare_for_export(self, module: QuantNLAL): - self.validate(module) - self.qf_impl, self.qf_kwargs = self.prepare_qf(module) - if module.is_act_quant_enabled: - self.output_quant_impl, self.output_quant_kwargs = self.prepare_output_quant(module) - self.return_quant_tensor = module.return_quant_tensor - - def forward(self, inp: Tensor, **kwargs): - q_out = inp - if self.qf_impl is not None: - q_out = self.qf_impl(q_out, **self.qf_kwargs) - if self.output_quant_impl is not None: - if q_out.is_quantized: - q_out = q_out.dequantize() - q_out = self.output_quant_impl(q_out, **self.output_quant_kwargs) - if not self.return_quant_tensor: - q_out = q_out.dequantize() - return q_out - - -class PytorchQuantReLUHandler(PytorchQuantNLALHandler): - handled_layer = QuantReLU - - @classmethod - def prepare_qf(cls, module): - return torch.nn.functional.relu, {} - - -class PytorchQuantIdentityHandler(PytorchQuantNLALHandler): - handled_layer = QuantIdentity - - @classmethod - def prepare_qf(cls, module): - return None, None - - -class PytorchQuantHardTanhHandler(PytorchQuantIdentityHandler): - handled_layer = QuantHardTanh diff --git a/src/brevitas/export/torch/qoperator/handler/base.py b/src/brevitas/export/torch/qoperator/handler/base.py deleted file mode 100644 index c5b569ff6..000000000 --- a/src/brevitas/export/torch/qoperator/handler/base.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
-# SPDX-License-Identifier: BSD-3-Clause - -from abc import ABC -from abc import abstractmethod - -import torch -from torch import Tensor - -from brevitas.export.common.handler.base import BaseHandler -from brevitas.export.common.handler.base import BitWidthHandlerMixin -from brevitas.export.common.handler.base import ZeroPointHandlerMixin - -SCALAR_SHAPE = () - - -def _is_scalar(x: Tensor): - return x.shape == SCALAR_SHAPE - - -class PytorchQuantLayerHandler(BaseHandler, BitWidthHandlerMixin, ZeroPointHandlerMixin, ABC): - - @classmethod - @abstractmethod - def explicit_output_dtype(cls) -> bool: - pass - - @classmethod - @abstractmethod - def prepare_qf(cls, module): - pass - - @classmethod - @abstractmethod - def validate(cls, module): - pass - - @classmethod - def gen_quant_impl_kwargs( - cls, scale: Tensor, zero_point: Tensor, signed: bool, include_dtype=True): - if _is_scalar(scale): - assert _is_scalar(zero_point), 'Scalar zero point required' - scale, zero_point = scale.item(), zero_point.item() - quant_impl = torch.quantize_per_tensor - else: - if _is_scalar(zero_point): - zero_point = zero_point.expand_as(scale) - quant_impl = torch.quantize_per_channel - quant_kwargs = {'scale': scale, 'zero_point': zero_point} - if include_dtype and signed: - quant_kwargs['dtype'] = torch.qint8 - elif include_dtype and not signed: - quant_kwargs['dtype'] = torch.quint8 - return quant_impl, quant_kwargs - - @classmethod - def prepare_input_quant(cls, module): - scale = module.quant_input_scale() - zero_point = cls.quant_input_zero_point(module) - signed = module.is_quant_input_signed - quant_impl, quant_kwargs = cls.gen_quant_impl_kwargs(scale, zero_point, signed) - return quant_impl, quant_kwargs - - @classmethod - def prepare_output_quant(cls, module): - scale = module.quant_output_scale() - zero_point = cls.quant_output_zero_point(module) - signed = module.is_quant_output_signed - incl_dtype = cls.explicit_output_dtype() - quant_impl, quant_kwargs = cls.gen_quant_impl_kwargs(scale, zero_point, signed, incl_dtype) - return quant_impl, quant_kwargs diff --git a/src/brevitas/export/torch/qoperator/handler/parameter.py b/src/brevitas/export/torch/qoperator/handler/parameter.py deleted file mode 100644 index 802a5a053..000000000 --- a/src/brevitas/export/torch/qoperator/handler/parameter.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
-# SPDX-License-Identifier: BSD-3-Clause - -from abc import ABC -from typing import Union -import warnings - -import torch -from torch import Tensor - -from brevitas.nn import QuantConv1d -from brevitas.nn import QuantConv2d -from brevitas.nn import QuantConv3d -from brevitas.nn import QuantLinear -from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL - -from .base import PytorchQuantLayerHandler - - -class PytorchQuantWBIOLHandler(PytorchQuantLayerHandler): - - def __init__(self): - super().__init__() - self.qf_impl = None - self.input_quant_impl = None - self.weight_quant_impl = None - self.weight_quant_args = None - self.input_quant_kwargs = None - self.weight_quant_kwargs = None - self.output_quant_kwargs = None - self.qf_kwargs = None - self.return_quant_tensor = None - - @classmethod - def validate(cls, module: QuantWBIOL): - assert module.is_weight_quant_enabled, 'Weight quantization required' - assert module.is_output_quant_enabled, 'Output quantization required' - - @classmethod - def prepare_bias(cls, module: QuantWBIOL): - if module.bias is not None and not module.is_bias_quant_enabled: - bias = module.bias.detach() - elif module.bias is not None and module.is_bias_quant_enabled: - # export the dequantized value - bias = module.quant_bias().value - else: - bias = module.bias - return bias - - @classmethod - def prepare_weight_quant(cls, module: QuantWBIOL): - cls.validate_bit_width(module.quant_weight_bit_width(), 7, le_then=True) - cls.validate_8b_bit_width(module.quant_input_bit_width(), le_then=False) - cls.validate_8b_bit_width(module.quant_output_bit_width(), le_then=False) - scale = module.quant_weight_scale() - zero_point = cls.quant_weight_zero_point(module) - signed = module.is_quant_weight_signed - weight = module.weight.detach() - quant_impl, quant_kwargs = cls.gen_quant_impl_kwargs(scale, zero_point, signed) - return quant_impl, (weight,), quant_kwargs - - def prepare_for_export(self, module: QuantWBIOL): - self.validate(module) - if module.is_input_quant_enabled: - self.input_quant_impl, self.input_quant_kwargs = self.prepare_input_quant(module) - weight_quant_pack = self.prepare_weight_quant(module) - self.weight_quant_impl, self.weight_quant_args, self.weight_quant_kwargs = weight_quant_pack - self.qf_impl, self.qf_kwargs = self.prepare_qf(module) - _, self.output_quant_kwargs = self.prepare_output_quant(module) - self.return_quant_tensor = module.return_quant_tensor - - def q_weight(self): - q_weight = self.weight_quant_impl(*self.weight_quant_args, **self.weight_quant_kwargs) - return q_weight - - def forward(self, q_inp: Tensor): - if self.input_quant_impl is not None: - # If the input is quantized from a previous layer, - # we have to dequant and requant - if q_inp.is_quantized: - q_inp.dequantize() - q_inp = self.input_quant_impl(q_inp, **self.input_quant_kwargs) - assert q_inp.is_quantized, 'Input needs to be quantized' - q_out = self.qf_impl(q_inp, self.q_weight(), **self.qf_kwargs, **self.output_quant_kwargs) - if not self.return_quant_tensor: - q_out = q_out.dequantize() - return q_out - - -class PytorchQuantConvNdHandler(PytorchQuantWBIOLHandler, ABC): - - @classmethod - def explicit_output_dtype(cls): - return True - - @classmethod - def prepare_qf_kwargs(cls, module: Union[QuantConv1d, QuantConv2d, QuantConv3d]): - return { - 'bias': cls.prepare_bias(module), - 'stride': module.stride, - 'padding': module.padding, - 'dilation': module.dilation, - 'groups': module.groups, - 'padding_mode': module.padding_mode} - - -class 
PytorchQuantConv1dHandler(PytorchQuantConvNdHandler): - handled_layer = QuantConv1d - - @classmethod - def prepare_qf(cls, module: QuantConv1d): - return torch.nn.quantized.functional.conv1d, cls.prepare_qf_kwargs(module) - - -class PytorchQuantConv2dHandler(PytorchQuantConvNdHandler): - handled_layer = QuantConv2d - - @classmethod - def prepare_qf(cls, module: QuantConv2d): - return torch.nn.quantized.functional.conv2d, cls.prepare_qf_kwargs(module) - - -class PytorchQuantConv3dHandler(PytorchQuantConvNdHandler): - handled_layer = QuantConv3d - - @classmethod - def prepare_qf(cls, module: QuantConv3d): - return torch.nn.quantized.functional.conv3d, cls.prepare_qf_kwargs(module) - - -class PytorchQuantLinearHandler(PytorchQuantWBIOLHandler): - handled_layer = QuantLinear - - @classmethod - def explicit_output_dtype(cls): - return False - - @classmethod - def prepare_qf(cls, module: QuantLinear): - return torch.nn.quantized.functional.linear, {'bias': cls.prepare_bias(module)} diff --git a/src/brevitas/export/torch/qoperator/manager.py b/src/brevitas/export/torch/qoperator/manager.py deleted file mode 100644 index 5b534843a..000000000 --- a/src/brevitas/export/torch/qoperator/manager.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -from typing import Optional, Tuple, Union - -import torch -from torch import Tensor -from torch.nn import Module - -from brevitas.export.manager import _set_layer_export_handler -from brevitas.export.manager import _set_layer_export_mode -from brevitas.export.manager import BaseManager -from brevitas.export.manager import ExportContext -from brevitas.quant_tensor import QuantTensor - -from .handler import qF -from .handler.act import PytorchQuantHardTanhHandler -from .handler.act import PytorchQuantIdentityHandler -from .handler.act import PytorchQuantReLUHandler -from .handler.parameter import PytorchQuantConv1dHandler -from .handler.parameter import PytorchQuantConv2dHandler -from .handler.parameter import PytorchQuantConv3dHandler -from .handler.parameter import PytorchQuantLinearHandler - - -class TorchQOpManager(BaseManager): - target_name = 'torch' - - handlers = [ - PytorchQuantHardTanhHandler, - PytorchQuantIdentityHandler, - PytorchQuantReLUHandler, - PytorchQuantConv1dHandler, - PytorchQuantConv2dHandler, - PytorchQuantConv3dHandler, - PytorchQuantLinearHandler] - - @classmethod - def set_export_mode(cls, module: Module, enabled: bool): - _set_layer_export_mode(module, enabled) - - @classmethod - def set_export_handler(cls, module: Module): - _set_layer_export_handler(cls, module) - - @classmethod - def export( - cls, - module: Module, - input_shape: Optional[Tuple[int, ...]] = None, - export_path: Optional[str] = None, - input_t: Optional[Union[Tensor, QuantTensor]] = None): - if qF is None: - raise RuntimeError("torch.nn.quantized.functional cannot be imported.") - if input_shape is None and input_t is None: - raise RuntimeError("Export requires to pass in either input_shape or input_t") - if input_shape is not None and input_t is not None: - raise RuntimeError("Export accepts either an input shape or an input tensor, not both") - if input_t is None and input_shape is not None: - input_t = torch.empty(*input_shape) - with ExportContext(cls): - traced_module = cls.jit_inference_trace(module, input_t) - if export_path is not None: - traced_module.save(export_path) - return traced_module diff --git a/src/brevitas/onnx/__init__.py 
b/src/brevitas/onnx/__init__.py
index 1db5cfebe..a31712c0b 100644
--- a/src/brevitas/onnx/__init__.py
+++ b/src/brevitas/onnx/__init__.py
@@ -6,5 +6,4 @@
 from brevitas.export import enable_debug  # noqa
 from brevitas.export import export_brevitas_onnx  # noqa
 from brevitas.export import export_onnx_qcdq  # noqa
-from brevitas.export import export_onnx_qop  # noqa
 from brevitas.export import export_qonnx  # noqa
diff --git a/tests/brevitas/export/test_torch_qop.py b/tests/brevitas/export/test_torch_qop.py
deleted file mode 100644
index e01bd93cb..000000000
--- a/tests/brevitas/export/test_torch_qop.py
+++ /dev/null
@@ -1,204 +0,0 @@
-# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
-# SPDX-License-Identifier: BSD-3-Clause
-
-import torch
-
-from brevitas.export import export_torch_qop
-from brevitas.nn import QuantConv2d
-from brevitas.nn import QuantIdentity
-from brevitas.nn import QuantLinear
-from brevitas.nn import QuantReLU
-from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
-from brevitas.quant.scaled_int import Int16Bias
-from brevitas.quant.scaled_int import Uint8ActPerTensorFloat
-from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat
-from tests.marker import jit_disabled_for_export
-from tests.marker import requires_pt_ge
-
-OUT_CH = 50
-IN_CH = 40
-TOLERANCE = 1.1
-RANDN_MEAN = 1
-RANDN_STD = 3
-
-
-@requires_pt_ge('9999', 'Darwin')
-@jit_disabled_for_export()
-def test_pytorch_quant_conv_export():
-    IN_SIZE = (2, IN_CH, IN_CH, IN_CH)
-    KERNEL_SIZE = (3, 3)
-
-    class Model(torch.nn.Module):
-
-        def __init__(self):
-            super().__init__()
-            self.conv = QuantConv2d(
-                out_channels=OUT_CH,
-                in_channels=IN_CH,
-                kernel_size=KERNEL_SIZE,
-                bias=False,
-                weight_bit_width=7,
-                input_bit_width=8,
-                output_bit_width=8,
-                weight_quant=Int8WeightPerTensorFloat,
-                input_quant=ShiftedUint8ActPerTensorFloat,
-                output_quant=ShiftedUint8ActPerTensorFloat,
-                return_quant_tensor=False)
-            self.conv.weight.data.uniform_(-0.01, 0.01)
-
-        def forward(self, x):
-            return self.conv(x)
-
-    inp = torch.randn(IN_SIZE)
-    model = Model()
-    model(inp)  # collect scale factors
-    model.eval()
-    inp = torch.randn(IN_SIZE) * RANDN_STD + RANDN_MEAN  # New input with bigger range
-    brevitas_out = model(inp)
-    pytorch_qf_model = export_torch_qop(model, input_t=inp)
-    pytorch_out = pytorch_qf_model(inp)
-    atol = model.conv.quant_output_scale().item() * TOLERANCE
-    assert pytorch_out.isclose(brevitas_out, rtol=0.0, atol=atol).all()
-
-
-@requires_pt_ge('9999', 'Darwin')
-@jit_disabled_for_export()
-def test_pytorch_quant_linear_export():
-    IN_SIZE = (IN_CH, IN_CH)
-
-    class Model(torch.nn.Module):
-
-        def __init__(self):
-            super().__init__()
-            self.linear = QuantLinear(
-                in_features=IN_CH,
-                out_features=OUT_CH,
-                bias=False,
-                weight_quant=Int8WeightPerTensorFloat,
-                weight_bit_width=7,
-                input_bit_width=8,
-                output_bit_width=8,
-                input_quant=ShiftedUint8ActPerTensorFloat,
-                output_quant=ShiftedUint8ActPerTensorFloat,
-                return_quant_tensor=False)
-            self.linear.weight.data.uniform_(-0.01, 0.01)
-
-        def forward(self, x):
-            return self.linear(x)
-
-    inp = torch.randn(IN_SIZE)
-    model = Model()
-    model(inp)  # collect scale factors
-    model.eval()
-    inp = torch.randn(IN_SIZE) * RANDN_STD + RANDN_MEAN  # New input with bigger range
-    brevitas_out = model(inp)
-    pytorch_qf_model = export_torch_qop(model, input_t=inp)
-    pytorch_out = pytorch_qf_model(inp)
-    atol = model.linear.quant_output_scale().item() * TOLERANCE
-    assert pytorch_out.isclose(brevitas_out, rtol=0.0,
atol=atol).all() - - -@requires_pt_ge('9999', 'Darwin') -@jit_disabled_for_export() -def test_pytorch_quant_linear_bias_quant_export(): - IN_SIZE = (IN_CH, IN_CH) - - class Model(torch.nn.Module): - - def __init__(self): - super().__init__() - self.linear = QuantLinear( - in_features=IN_CH, - out_features=OUT_CH, - bias=True, - weight_quant=Int8WeightPerTensorFloat, - weight_bit_width=7, - input_bit_width=8, - output_bit_width=8, - input_quant=ShiftedUint8ActPerTensorFloat, - output_quant=ShiftedUint8ActPerTensorFloat, - bias_quant=Int16Bias, - return_quant_tensor=False) - self.linear.weight.data.uniform_(-0.01, 0.01) - - def forward(self, x): - return self.linear(x) - - inp = torch.randn(IN_SIZE) - model = Model() - model(inp) # collect scale factors - model.eval() - inp = torch.randn(IN_SIZE) * RANDN_STD + RANDN_MEAN # New input with bigger range - brevitas_out = model(inp) - pytorch_qf_model = export_torch_qop(model, input_t=inp) - pytorch_out = pytorch_qf_model(inp) - atol = model.linear.quant_output_scale().item() * TOLERANCE - assert pytorch_out.isclose(brevitas_out, rtol=0.0, atol=atol).all() - - -@requires_pt_ge('9999', 'Darwin') -@jit_disabled_for_export() -def test_pytorch_quant_conv_bias_quant_export(): - IN_SIZE = (2, IN_CH, IN_CH, IN_CH) - KERNEL_SIZE = (3, 3) - - class Model(torch.nn.Module): - - def __init__(self): - super().__init__() - self.conv = QuantConv2d( - out_channels=OUT_CH, - in_channels=IN_CH, - kernel_size=KERNEL_SIZE, - bias=False, - weight_bit_width=7, - input_bit_width=8, - output_bit_width=8, - weight_quant=Int8WeightPerTensorFloat, - bias_quant=Int16Bias, - input_quant=ShiftedUint8ActPerTensorFloat, - output_quant=ShiftedUint8ActPerTensorFloat, - return_quant_tensor=False) - self.conv.weight.data.uniform_(-0.01, 0.01) - - def forward(self, x): - return self.conv(x) - - inp = torch.randn(IN_SIZE) - model = Model() - model(inp) # collect scale factors - model.eval() - inp = torch.randn(IN_SIZE) * RANDN_STD + RANDN_MEAN # New input with bigger range - brevitas_out = model(inp) - pytorch_qf_model = export_torch_qop(model, input_t=inp) - pytorch_out = pytorch_qf_model(inp) - atol = model.conv.quant_output_scale().item() * TOLERANCE - assert pytorch_out.isclose(brevitas_out, rtol=0.0, atol=atol).all() - - -@requires_pt_ge('9999', 'Darwin') -@jit_disabled_for_export() -def test_quant_act_export(): - IN_SIZE = (IN_CH, IN_CH) - - class Model(torch.nn.Module): - - def __init__(self): - super().__init__() - self.act1 = QuantIdentity( - bit_width=8, act_quant=ShiftedUint8ActPerTensorFloat, return_quant_tensor=True) - self.act2 = QuantReLU(act_quant=Uint8ActPerTensorFloat) - - def forward(self, x): - return self.act2(self.act1(x)) - - inp = torch.randn(IN_SIZE) - model = Model() - model(inp) # collect scale factors - model.eval() - inp = torch.randn(IN_SIZE) * RANDN_STD + RANDN_MEAN # New input with bigger range - brevitas_out = model(inp) - pytorch_qf_model = export_torch_qop(model, input_t=inp) - pytorch_out = pytorch_qf_model(inp) - atol = model.act2.quant_output_scale().item() * TOLERANCE - assert pytorch_out.isclose(brevitas_out, rtol=0.0, atol=atol).all() diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index a6e00a5ca..4d1e16679 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -10,7 +10,6 @@ import torch from brevitas.export import export_onnx_qcdq -from brevitas.export import export_onnx_qop from brevitas.export import export_qonnx from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d @@ 
-143,10 +142,7 @@ def is_brevitas_ort_close(
         odict = oxe.execute_onnx(exported_model, idict, True)
         ort_output = odict[exported_model.graph.output[0].name]
     else:
-        if export_type == 'qop':
-            export_onnx_qop(model, input_t, export_path=export_name)
-            computed_out = brevitas_output.int(float_datatype=False)
-        elif export_type == 'qcdq':
+        if export_type == 'qcdq':
             export_onnx_qcdq(model, input_t, export_path=export_name)
         elif export_type == 'qcdq_opset14':
             export_onnx_qcdq(model, input_t, opset_version=14, export_path=export_name)
diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py
index fce766d2d..2b7b6b1cf 100644
--- a/tests/brevitas_ort/test_quant_module.py
+++ b/tests/brevitas_ort/test_quant_module.py
@@ -19,7 +19,7 @@
 
 
 @parametrize_with_cases('model', cases=QuantWBIOLCases)
-@pytest.mark.parametrize('export_type', ['qcdq', 'qonnx', 'qop'])
+@pytest.mark.parametrize('export_type', ['qcdq', 'qonnx'])
 @requires_pt_ge('1.8.1')
 def test_ort_wbiol(model, export_type, current_cases):
     cases_generator_func = current_cases['model'][1]
@@ -30,9 +30,6 @@ def test_ort_wbiol(model, export_type, current_cases):
     o_bit_width = case_id.split('-')[-5]
     i_bit_width = case_id.split('-')[-3]
 
-    if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d',
-                'QuantConvTranspose3d') and export_type == 'qop':
-        pytest.skip('Export of ConvTranspose is not supported for QOperation')
     if 'per_channel' in quantizer and 'asymmetric' in quantizer:
         pytest.skip('Per-channel zero-point is not well supported in ORT.')
     if 'QuantLinear' in impl and 'asymmetric' in quantizer: