diff --git a/notebooks/minifloat_mx_tutorial.ipynb b/notebooks/minifloat_mx_tutorial.ipynb index 284a0d4f5..bd43880de 100644 --- a/notebooks/minifloat_mx_tutorial.ipynb +++ b/notebooks/minifloat_mx_tutorial.ipynb @@ -163,8 +163,8 @@ " pass\n", "\n", "class MXFloat8Act(MXFloat8e4m3Act, Fp8e4m3Mixin):\n", - " # It is necessary to specify the group dimension for the activation quantization\n", - " group_dim = 1\n", + " # In layerwise quantization, group_dim is automatically determined\n", + " pass\n", "\n", "\n", "class MXModel(nn.Module):\n", @@ -221,9 +221,8 @@ " group_size = 8\n", "\n", "class MXFloat8ActNoPadding(MXFloat8e4m3Act, Fp8e4m3Mixin):\n", - " # It is necessary to specify the group dimension for the activation quantization\n", + " # In layerwise quantization, group_dim is automatically determined\n", " group_size = 8\n", - " group_dim = 1\n", "\n", "\n", "class MXModelNoPadding(nn.Module):\n", @@ -277,8 +276,8 @@ " pass\n", "\n", "class MXFloat8Act(MXFloat8e4m3Act, Fp8e4m3Mixin):\n", - " # It is necessary to specify the group dimension for the activation quantization\n", - " group_dim = 1\n", + " # In layerwise quantization, group_dim is automatically determined\n", + " pass\n", "\n", "\n", "class MXModel(nn.Module):\n", @@ -314,12 +313,11 @@ "class MXInt8Weight(MXInt8Weight):\n", " # The group dimension for the weights it is automatically identified based on the layer type\n", " # If a new layer type is used, it can be manually specified\n", - " bit_width = 8\n", + " pass\n", "\n", "class MXInt8Act(MXInt8Act):\n", - " # It is necessary to specify the group dimension for the activation quantization\n", - " group_dim = 1\n", - " bit_width = 8\n", + " # In layerwise quantization, group_dim is automatically determined\n", + " pass\n", "\n", "class MXModel(nn.Module):\n", " def __init__(self):\n",