diff --git a/notebooks/minifloat_mx_tutorial.ipynb b/notebooks/minifloat_mx_tutorial.ipynb new file mode 100644 index 000000000..0fc038a66 --- /dev/null +++ b/notebooks/minifloat_mx_tutorial.ipynb @@ -0,0 +1,131 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Minifloat and MX Example" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Work-in-progress examples showing how to use minifloat and MX quantization with Brevitas" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages/torch/nn/modules/conv.py:456: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1712608853099/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n", + " return F.conv2d(input, weight, bias, self.stride,\n" + ] + } + ], + "source": [ + "from brevitas.quant.experimental.float_base import Fp8e4m3Mixin\n", + "from brevitas.quant.experimental.mx_quant_ocp import MXFloatWeight\n", + "from brevitas.quant.experimental.float_quant_ocp import FpOCPWeightPerTensorFloat, FpOCPActPerTensorFloat\n", + "from brevitas.quant.experimental.mx_quant_ocp import MXFloatAct\n", + "import brevitas.nn as qnn\n", + "import torch.nn as nn\n", + "import torch\n", + "from brevitas.quant_tensor import FloatQuantTensor\n", + "\n", + "class OCPFP8Weight(FpOCPWeightPerTensorFloat, Fp8e4m3Mixin):\n", + " pass\n", + "\n", + "\n", + "class OCPFP8Act(FpOCPActPerTensorFloat, Fp8e4m3Mixin):\n", + " pass\n", + "\n", + "\n", + "class FP8Model(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=OCPFP8Weight, input_quant=OCPFP8Act)\n", + " \n", + " def forward(self, x):\n", + " return 
self.conv(x)\n", + "\n", + "ocp_fp8_model = FP8Model()\n", + "x = torch.randn(1, 32, 8, 8)\n", + "ocp_fp8_model.eval()\n", + "o = ocp_fp8_model(x)\n", + "\n", + "intermediate_input = ocp_fp8_model.conv.input_quant(x)\n", + "assert isinstance(intermediate_input, FloatQuantTensor)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from brevitas.quant_tensor import GroupwiseFloatQuantTensor\n", + "\n", + "\n", + "class MXFloat8Weight(MXFloatWeight, Fp8e4m3Mixin):\n", + " # The group dimension for the weights is automatically identified based on the layer type\n", + " # If a new layer type is used, it can be manually specified\n", + " pass\n", + "\n", + "class MXFloat8Act(MXFloatAct, Fp8e4m3Mixin):\n", + " # It is necessary to specify the group dimension for the activation quantization\n", + " group_dim = 1\n", + "\n", + "\n", + "class MXModel(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n", + " \n", + " def forward(self, x):\n", + " return self.conv(x)\n", + "\n", + "mx_model = MXModel()\n", + "x = torch.randn(1, 32, 8, 8)\n", + "mx_model.eval()\n", + "o = mx_model(x)\n", + "\n", + "intermediate_input = mx_model.conv.input_quant(x)\n", + "assert isinstance(intermediate_input, GroupwiseFloatQuantTensor)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "brevitas_dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}