Commit

notebook
Giuseppe5 committed Jun 19, 2024
1 parent 2b82908 commit ee3fbdd
Showing 1 changed file with 131 additions and 0 deletions.
notebooks/minifloat_mx_tutorial.ipynb (131 additions, 0 deletions)
@@ -0,0 +1,131 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Minifloat and MX Example"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Work in progress examples to show how to use minifloat and MX with Brevitas"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/giuseppe/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages/torch/nn/modules/conv.py:456: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1712608853099/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n",
" return F.conv2d(input, weight, bias, self.stride,\n"
]
}
],
"source": [
"from brevitas.quant.experimental.float_base import Fp8e4m3Mixin\n",
"from brevitas.quant.experimental.mx_quant_ocp import MXFloatWeight\n",
"from brevitas.quant.experimental.float_quant_ocp import FpOCPWeightPerTensorFloat, FpOCPActPerTensorFloat\n",
"from brevitas.quant.experimental.mx_quant_ocp import MXFloatAct\n",
"import brevitas.nn as qnn\n",
"import torch.nn as nn\n",
"import torch\n",
"from brevitas.quant_tensor import FloatQuantTensor\n",
"\n",
"class OCPFP8Weight(FpOCPWeightPerTensorFloat, Fp8e4m3Mixin):\n",
" pass\n",
"\n",
"\n",
"class OCPFP8Act(FpOCPActPerTensorFloat, Fp8e4m3Mixin):\n",
" pass\n",
"\n",
"\n",
"class FP8Model(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=OCPFP8Weight, input_quant=OCPFP8Act)\n",
" \n",
" def forward(self, x):\n",
" return self.conv(x)\n",
"\n",
"ocp_fp8_model = FP8Model()\n",
"x = torch.randn(1, 32, 8, 8)\n",
"ocp_fp8_model.eval()\n",
"o = ocp_fp8_model(x)\n",
"\n",
"intermediate_input = ocp_fp8_model.conv.input_quant(x)\n",
"assert isinstance(intermediate_input, FloatQuantTensor)"
]
},
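{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is an added illustrative sketch rather than part of the original example: it inspects the `FloatQuantTensor` returned by the OCP FP8 input quantizer. `value` and `scale` come from the standard quant tensor interface; the minifloat metadata names (`exponent_bit_width`, `mantissa_bit_width`, `exponent_bias`) are assumptions and are read defensively in case they differ between Brevitas versions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: inspect the quantized activation produced above.\n",
"print(type(intermediate_input).__name__)\n",
"print('shape:', intermediate_input.value.shape)\n",
"print('scale:', intermediate_input.scale)\n",
"# Assumed minifloat metadata fields; accessed defensively since names may vary.\n",
"for attr in ('exponent_bit_width', 'mantissa_bit_width', 'exponent_bias'):\n",
"    if hasattr(intermediate_input, attr):\n",
"        print(attr + ':', getattr(intermediate_input, attr))\n",
"# Error introduced by casting the input to FP8 e4m3.\n",
"print('max abs error:', (intermediate_input.value - x).abs().max().item())"
]
},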
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from brevitas.quant_tensor import GroupwiseFloatQuantTensor\n",
"\n",
"\n",
"class MXFloat8Weight(MXFloatWeight, Fp8e4m3Mixin):\n",
" # The group dimension for the weights it is automatically identified based on the layer type\n",
" # If a new layer type is used, it can be manually specified\n",
" pass\n",
"\n",
"class MXFloat8Act(MXFloatAct, Fp8e4m3Mixin):\n",
" # It is necessary to specify the group dimension for the activation quantization\n",
" group_dim = 1\n",
"\n",
"\n",
"class MXModel(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n",
" \n",
" def forward(self, x):\n",
" return self.conv(x)\n",
"\n",
"mx_model = MXModel()\n",
"x = torch.randn(1, 32, 8, 8)\n",
"mx_model.eval()\n",
"o = mx_model(x)\n",
"\n",
"intermediate_input = mx_model.conv.input_quant(x)\n",
"assert isinstance(intermediate_input, GroupwiseFloatQuantTensor)"
]
},
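{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is an added sketch to estimate the output error introduced by MX quantization: it copies the layer's floating-point weights into a plain `nn.Conv2d` and compares the two outputs. It assumes `QuantConv2d` exposes its parameters through the usual `weight`/`bias` attributes inherited from `nn.Conv2d`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: compare the MX-quantized layer against a float baseline.\n",
"with torch.no_grad():\n",
"    float_conv = nn.Conv2d(32, 64, 3)\n",
"    # Assumes QuantConv2d keeps its floating-point parameters in weight/bias.\n",
"    float_conv.weight.copy_(mx_model.conv.weight)\n",
"    float_conv.bias.copy_(mx_model.conv.bias)\n",
"    o_float = float_conv(x)\n",
"    o_mx = mx_model(x)\n",
"    err = (o_mx - o_float).abs()\n",
"print('mean abs error:', err.mean().item())\n",
"print('max abs error:', err.max().item())"
]
},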
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "brevitas_dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
