Skip to content

Commit

Permalink
Updated notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored and nickfraser committed Sep 4, 2024
1 parent 8b21ba6 commit 42a34bf
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions notebooks/minifloat_mx_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@
" pass\n",
"\n",
"class MXFloat8Act(MXFloat8e4m3Act, Fp8e4m3Mixin):\n",
" # It is necessary to specify the group dimension for the activation quantization\n",
" group_dim = 1\n",
" # In layerwise quantization, groupdim is automatically determined\n",
" pass\n",
"\n",
"\n",
"class MXModel(nn.Module):\n",
Expand Down Expand Up @@ -221,9 +221,8 @@
" group_size = 8\n",
"\n",
"class MXFloat8ActNoPadding(MXFloat8e4m3Act, Fp8e4m3Mixin):\n",
" # It is necessary to specify the group dimension for the activation quantization\n",
" # In layerwise quantization, groupdim is automatically determined\n",
" group_size = 8\n",
" group_dim = 1\n",
"\n",
"\n",
"class MXModelNoPadding(nn.Module):\n",
Expand Down Expand Up @@ -277,8 +276,8 @@
" pass\n",
"\n",
"class MXFloat8Act(MXFloat8e4m3Act, Fp8e4m3Mixin):\n",
" # It is necessary to specify the group dimension for the activation quantization\n",
" group_dim = 1\n",
" # In layerwise quantization, groupdim is automatically determined\n",
" pass\n",
"\n",
"\n",
"class MXModel(nn.Module):\n",
Expand Down Expand Up @@ -314,12 +313,11 @@
"class MXInt8Weight(MXInt8Weight):\n",
" # The group dimension for the weights it is automatically identified based on the layer type\n",
" # If a new layer type is used, it can be manually specified\n",
" bit_width = 8\n",
" pass\n",
"\n",
"class MXInt8Act(MXInt8Act):\n",
" # It is necessary to specify the group dimension for the activation quantization\n",
" group_dim = 1\n",
" bit_width = 8\n",
" # In layerwise quantization, groupdim is automatically determined\n",
" pass\n",
"\n",
"class MXModel(nn.Module):\n",
" def __init__(self):\n",
Expand Down

0 comments on commit 42a34bf

Please sign in to comment.