-
Notifications
You must be signed in to change notification settings - Fork 197
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
131 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Minifloat and MX Example" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Work-in-progress examples showing how to use minifloat and MX quantization with Brevitas." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/home/giuseppe/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages/torch/nn/modules/conv.py:456: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1712608853099/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n", | ||
" return F.conv2d(input, weight, bias, self.stride,\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from brevitas.quant.experimental.float_base import Fp8e4m3Mixin\n", | ||
"from brevitas.quant.experimental.mx_quant_ocp import MXFloatWeight\n", | ||
"from brevitas.quant.experimental.float_quant_ocp import FpOCPWeightPerTensorFloat, FpOCPActPerTensorFloat\n", | ||
"from brevitas.quant.experimental.mx_quant_ocp import MXFloatAct\n", | ||
"import brevitas.nn as qnn\n", | ||
"import torch.nn as nn\n", | ||
"import torch\n", | ||
"from brevitas.quant_tensor import FloatQuantTensor\n", | ||
"\n", | ||
"class OCPFP8Weight(FpOCPWeightPerTensorFloat, Fp8e4m3Mixin):\n", | ||
" pass\n", | ||
"\n", | ||
"\n", | ||
"class OCPFP8Act(FpOCPActPerTensorFloat, Fp8e4m3Mixin):\n", | ||
" pass\n", | ||
"\n", | ||
"\n", | ||
"class FP8Model(nn.Module):\n", | ||
" def __init__(self):\n", | ||
" super().__init__()\n", | ||
" self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=OCPFP8Weight, input_quant=OCPFP8Act)\n", | ||
" \n", | ||
" def forward(self, x):\n", | ||
" return self.conv(x)\n", | ||
"\n", | ||
"ocp_fp8_model = FP8Model()\n", | ||
"x = torch.randn(1, 32, 8, 8)\n", | ||
"ocp_fp8_model.eval()\n", | ||
"o = ocp_fp8_model(x)\n", | ||
"\n", | ||
"intermediate_input = ocp_fp8_model.conv.input_quant(x)\n", | ||
"assert isinstance(intermediate_input, FloatQuantTensor)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from brevitas.quant_tensor import GroupwiseFloatQuantTensor\n", | ||
"\n", | ||
"\n", | ||
"class MXFloat8Weight(MXFloatWeight, Fp8e4m3Mixin):\n", | ||
" # The group dimension for the weights is automatically identified based on the layer type.\n", | ||
" # If a new layer type is used, it can be manually specified\n", | ||
" pass\n", | ||
"\n", | ||
"class MXFloat8Act(MXFloatAct, Fp8e4m3Mixin):\n", | ||
" # It is necessary to specify the group dimension for the activation quantization\n", | ||
" group_dim = 1\n", | ||
"\n", | ||
"\n", | ||
"class MXModel(nn.Module):\n", | ||
" def __init__(self):\n", | ||
" super().__init__()\n", | ||
" self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n", | ||
" \n", | ||
" def forward(self, x):\n", | ||
" return self.conv(x)\n", | ||
"\n", | ||
"mx_model = MXModel()\n", | ||
"x = torch.randn(1, 32, 8, 8)\n", | ||
"mx_model.eval()\n", | ||
"o = mx_model(x)\n", | ||
"\n", | ||
"intermediate_input = mx_model.conv.input_quant(x)\n", | ||
"assert isinstance(intermediate_input, GroupwiseFloatQuantTensor)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "brevitas_dev", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.9" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |