diff --git a/notebooks/minifloat_mx_tutorial.ipynb b/notebooks/minifloat_mx_tutorial.ipynb index 60f00fcd4..284a0d4f5 100644 --- a/notebooks/minifloat_mx_tutorial.ipynb +++ b/notebooks/minifloat_mx_tutorial.ipynb @@ -104,7 +104,8 @@ "o = ocp_fp8_model(x)\n", "\n", "intermediate_input = ocp_fp8_model.conv.input_quant(x)\n", - "assert isinstance(intermediate_input, FloatQuantTensor)" + "assert isinstance(intermediate_input, FloatQuantTensor)\n", + "assert isinstance(ocp_fp8_model.conv.quant_weight(), FloatQuantTensor)" ] }, { @@ -180,7 +181,84 @@ "o = mx_model(x)\n", "\n", "intermediate_input = mx_model.conv.input_quant(x)\n", - "assert isinstance(intermediate_input, GroupwiseFloatQuantTensor)" + "assert isinstance(intermediate_input, GroupwiseFloatQuantTensor)\n", + "assert isinstance(mx_model.conv.quant_weight(), GroupwiseFloatQuantTensor)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If the input channel dimension is not divisible by group size, padding will be applied." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Non padding weights shape torch.Size([64, 8, 3, 3])\n", + "Padded weights shape torch.Size([64, 32, 3, 3])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/giuseppe/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages/torch/nn/modules/conv.py:456: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1712608853099/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n", + " return F.conv2d(input, weight, bias, self.stride,\n" + ] + } + ], + "source": [ + "class MXFloat8WeightNoPadding(MXFloat8e4m3Weight, Fp8e4m3Mixin):\n", + " # The group dimension for the weights it is automatically identified based on the layer type\n", + " # If a new layer type is used, it can be manually specified\n", + " group_size = 8\n", + "\n", + "class MXFloat8ActNoPadding(MXFloat8e4m3Act, Fp8e4m3Mixin):\n", + " # It is necessary to specify the group dimension for the activation quantization\n", + " group_size = 8\n", + " group_dim = 1\n", + "\n", + "\n", + "class MXModelNoPadding(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv = qnn.QuantConv2d(8, 64, 3, weight_quant=MXFloat8WeightNoPadding, input_quant=MXFloat8ActNoPadding)\n", + " \n", + " def forward(self, x):\n", + " return self.conv(x)\n", + "\n", + "class MXModel(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv = qnn.QuantConv2d(8, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n", + " \n", + " def forward(self, x):\n", + " return self.conv(x)\n", + "\n", + "mx_model_no_padding = MXModelNoPadding()\n", + "mx_model = MXModel()\n", + "# Make sure that the modules are the same\n", + "mx_model_no_padding.load_state_dict(mx_model.state_dict())\n", + "\n", + "x = torch.randn(1, 8, 8, 8)\n", + "mx_model.eval()\n", + "mx_model_no_padding.eval()\n", + "o_no_padding = mx_model_no_padding(x)\n", + "o = mx_model(x)\n", + "\n", + "# The quant weight of the padded model is different from the non padding one\n", + "print(f\"Non padding weights shape {mx_model_no_padding.conv.quant_weight().value.shape}\")\n", + "print(f\"Padded weights shape {mx_model.conv.quant_weight().value.shape}\")\n", + "\n", + "# However, results are still the same \n", + "assert torch.allclose(o, o_no_padding)" ] }, { diff --git a/src/brevitas/core/function_wrapper/shape.py b/src/brevitas/core/function_wrapper/shape.py index 1bb77476e..84ee9f355 100644 --- a/src/brevitas/core/function_wrapper/shape.py +++ b/src/brevitas/core/function_wrapper/shape.py @@ -156,17 +156,14 @@ def forward(self, x: torch.Tensor): class OverSubChannelBlockView(brevitas.jit.ScriptModule): __constants__ = ['expanded_scaling_shape'] - def __init__(self, expanded_scaling_shape, permute_dims: Optional[Tuple[int, ...]]) -> None: + def __init__(self, expanded_scaling_shape, padding) -> None: super(OverSubChannelBlockView, self).__init__() self.expanded_scaling_shape = expanded_scaling_shape - if permute_dims is not None: - self.permute_impl = PermuteDims(permute_dims) - else: - self.permute_impl = torch.nn.Identity() + self.padding = padding @brevitas.jit.script_method def forward(self, x: torch.Tensor): - y = self.permute_impl(x) + y = torch.nn.functional.pad(x, self.padding, mode='constant', value=0) y = y.view(self.expanded_scaling_shape) return y @@ -181,6 +178,16 @@ def __init__(self, group_size, group_dim) -> None: @brevitas.jit.script_method def forward(self, x): + + tensor_shape = x.shape + tensor_shape_list = list(tensor_shape) + padding = [0, 0] * len(tensor_shape_list) + if tensor_shape_list[self.group_dim] % self.group_size != 0: + padding[2 * self.group_dim] = self.group_size - tensor_shape_list[ + self.group_dim] % self.group_size + padding = list(reversed(padding)) + x = torch.nn.functional.pad(x, padding, mode='constant', value=0) + tensor_shape = x.shape tensor_shape_list = list(tensor_shape) tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size) diff --git a/src/brevitas/quant/solver/common.py b/src/brevitas/quant/solver/common.py index 2847275e8..4d46cc704 100644 --- a/src/brevitas/quant/solver/common.py +++ b/src/brevitas/quant/solver/common.py @@ -172,13 +172,13 @@ def int_scaling_impl(restrict_scaling_type): class SolveStatsReduceDimFromEnum(ExtendedInjector): @value - def stats_reduce_dim(scaling_stats_op, scaling_per_output): + def stats_reduce_dim(scaling_stats_op, scaling_per_output, group_dim=None): if scaling_per_output == ScalingPerOutputType.CHANNEL or scaling_stats_op == StatsOp.MAX_AVE: return SCALING_STATS_REDUCE_DIM elif scaling_per_output == ScalingPerOutputType.TENSOR: return None elif scaling_per_output == ScalingPerOutputType.GROUP: - return SCALING_STATS_REDUCE_DIM + 1 + return group_dim + 1 @value def keepdim(scaling_per_output): diff --git a/src/brevitas/quant/solver/parameter.py b/src/brevitas/quant/solver/parameter.py index 97137c567..76d3f6f3e 100644 --- a/src/brevitas/quant/solver/parameter.py +++ b/src/brevitas/quant/solver/parameter.py @@ -111,17 +111,17 @@ def scaling_impl(scaling_impl_type): class SolveParameterScalingShape(ExtendedInjector): @value - def scaling_shape(module, group_size=None, scaling_per_output=None): + def scaling_shape(module, group_dim, group_size=None, scaling_per_output=None): if scaling_per_output == ScalingPerOutputType.TENSOR: return SCALAR_SHAPE elif scaling_per_output == ScalingPerOutputType.CHANNEL: return this.scaling_per_output_channel_shape elif scaling_per_output == ScalingPerOutputType.GROUP: assert group_size is not None, "Per Group scaling requires group size" + assert group_dim is not None, "Per Group scaling requires group dim" size = list(module.weight.shape) - assert size[1] % group_size == 0, 'Input channel is not divisible by group size' - size[1] = size[1] // group_size - size.insert(2, 1) + size[group_dim] = (size[group_dim] + group_size - 1) // group_size + size.insert(group_dim + 1, 1) return size @value @@ -129,18 +129,29 @@ def reshaped_scaling_shape(module): return module.weight.shape @value - def expanded_scaling_shape(module, group_size=None): + def expanded_scaling_shape(module, group_dim, group_size=None): assert group_size is not None, "Per Group scaling requires group size" size = list(module.weight.shape) - assert size[1] % group_size == 0, 'Input channel is not divisible by group size' - size[1] = size[1] // group_size - size.insert(2, group_size) + size[group_dim] = (size[group_dim] + group_size - 1) // group_size + size.insert(group_dim + 1, group_size) return size + @value + def padding(module, group_dim, group_size): + padding = [0, 0] * len(module.weight.shape) + size = list(module.weight.shape) + if size[group_dim] % group_size != 0: + # Padding is done on the left side + padding[2 * group_dim] = group_size - size[group_dim] % group_size + # Padding takes a list of 2 values per dim in reverse order (N_DIM, N_DIM-1,...,0) + # so we need to reverse the order + padding = list(reversed(padding)) + return padding + @value def group_dim(module, group_size=None): if group_size is not None: - return 1 + return 1 if not hasattr(module, 'transposed') or not module.transposed else 0 class SolveInputViewImpl(ExtendedInjector):