Feat (mx): adding padding and transposed support (#1007)
Giuseppe5 authored Aug 26, 2024
1 parent 89fca2f commit 21537ef
Showing 4 changed files with 115 additions and 19 deletions.
82 changes: 80 additions & 2 deletions notebooks/minifloat_mx_tutorial.ipynb
@@ -104,7 +104,8 @@
"o = ocp_fp8_model(x)\n",
"\n",
"intermediate_input = ocp_fp8_model.conv.input_quant(x)\n",
"assert isinstance(intermediate_input, FloatQuantTensor)"
"assert isinstance(intermediate_input, FloatQuantTensor)\n",
"assert isinstance(ocp_fp8_model.conv.quant_weight(), FloatQuantTensor)"
]
},
{
@@ -180,7 +181,84 @@
"o = mx_model(x)\n",
"\n",
"intermediate_input = mx_model.conv.input_quant(x)\n",
"assert isinstance(intermediate_input, GroupwiseFloatQuantTensor)"
"assert isinstance(intermediate_input, GroupwiseFloatQuantTensor)\n",
"assert isinstance(mx_model.conv.quant_weight(), GroupwiseFloatQuantTensor)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the input channel dimension is not divisible by group size, padding will be applied."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Non padding weights shape torch.Size([64, 8, 3, 3])\n",
"Padded weights shape torch.Size([64, 32, 3, 3])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/giuseppe/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages/torch/nn/modules/conv.py:456: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1712608853099/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n",
" return F.conv2d(input, weight, bias, self.stride,\n"
]
}
],
"source": [
"class MXFloat8WeightNoPadding(MXFloat8e4m3Weight, Fp8e4m3Mixin):\n",
" # The group dimension for the weights it is automatically identified based on the layer type\n",
" # If a new layer type is used, it can be manually specified\n",
" group_size = 8\n",
"\n",
"class MXFloat8ActNoPadding(MXFloat8e4m3Act, Fp8e4m3Mixin):\n",
" # It is necessary to specify the group dimension for the activation quantization\n",
" group_size = 8\n",
" group_dim = 1\n",
"\n",
"\n",
"class MXModelNoPadding(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv = qnn.QuantConv2d(8, 64, 3, weight_quant=MXFloat8WeightNoPadding, input_quant=MXFloat8ActNoPadding)\n",
" \n",
" def forward(self, x):\n",
" return self.conv(x)\n",
"\n",
"class MXModel(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv = qnn.QuantConv2d(8, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n",
" \n",
" def forward(self, x):\n",
" return self.conv(x)\n",
"\n",
"mx_model_no_padding = MXModelNoPadding()\n",
"mx_model = MXModel()\n",
"# Make sure that the modules are the same\n",
"mx_model_no_padding.load_state_dict(mx_model.state_dict())\n",
"\n",
"x = torch.randn(1, 8, 8, 8)\n",
"mx_model.eval()\n",
"mx_model_no_padding.eval()\n",
"o_no_padding = mx_model_no_padding(x)\n",
"o = mx_model(x)\n",
"\n",
"# The quant weight of the padded model is different from the non padding one\n",
"print(f\"Non padding weights shape {mx_model_no_padding.conv.quant_weight().value.shape}\")\n",
"print(f\"Padded weights shape {mx_model.conv.quant_weight().value.shape}\")\n",
"\n",
"# However, results are still the same \n",
"assert torch.allclose(o, o_no_padding)"
]
},
{
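A hedged aside on the shapes printed in the new notebook cell: the default MX quantizers use a group size of 32 (the usual MX format group size, consistent with the printed shapes), while the NoPadding variants above override it to 8, so only the default model needs its 8 input channels rounded up. A minimal sketch of that arithmetic, assuming only the rounding rule used by the solver changes below:

def padded_dim(dim_size: int, group_size: int) -> int:
    # Round the grouped dimension up to the next multiple of group_size
    return (dim_size + group_size - 1) // group_size * group_size

print(padded_dim(8, 32))  # 32 -> padded quant weight shape (64, 32, 3, 3)
print(padded_dim(8, 8))   # 8  -> already divisible, no padding, shape stays (64, 8, 3, 3)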
19 changes: 13 additions & 6 deletions src/brevitas/core/function_wrapper/shape.py
@@ -156,17 +156,14 @@ def forward(self, x: torch.Tensor):
class OverSubChannelBlockView(brevitas.jit.ScriptModule):
__constants__ = ['expanded_scaling_shape']

- def __init__(self, expanded_scaling_shape, permute_dims: Optional[Tuple[int, ...]]) -> None:
+ def __init__(self, expanded_scaling_shape, padding) -> None:
super(OverSubChannelBlockView, self).__init__()
self.expanded_scaling_shape = expanded_scaling_shape
- if permute_dims is not None:
-     self.permute_impl = PermuteDims(permute_dims)
- else:
-     self.permute_impl = torch.nn.Identity()
+ self.padding = padding

@brevitas.jit.script_method
def forward(self, x: torch.Tensor):
- y = self.permute_impl(x)
+ y = torch.nn.functional.pad(x, self.padding, mode='constant', value=0)
y = y.view(self.expanded_scaling_shape)
return y
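As a rough standalone illustration of the reworked view (plain PyTorch, not the Brevitas module; the padding list and expanded shape are taken from the tutorial example above), the weight is first zero-padded with the precomputed, already-reversed padding list and then reshaped to the expanded scaling shape:

import torch
import torch.nn.functional as F

weight = torch.randn(64, 8, 3, 3)            # 8 input channels, group size 32
padding = [0, 0, 0, 0, 0, 24, 0, 0]          # pads dim 1 from 8 to 32; already reversed for F.pad
expanded_scaling_shape = (64, 1, 32, 3, 3)   # (out_ch, num_groups, group_size, kh, kw)

y = F.pad(weight, padding, mode='constant', value=0)
y = y.view(expanded_scaling_shape)
print(y.shape)  # torch.Size([64, 1, 32, 3, 3])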

@@ -181,6 +178,16 @@ def __init__(self, group_size, group_dim) -> None:

@brevitas.jit.script_method
def forward(self, x):

tensor_shape = x.shape
tensor_shape_list = list(tensor_shape)
padding = [0, 0] * len(tensor_shape_list)
if tensor_shape_list[self.group_dim] % self.group_size != 0:
padding[2 * self.group_dim] = self.group_size - tensor_shape_list[
self.group_dim] % self.group_size
padding = list(reversed(padding))
x = torch.nn.functional.pad(x, padding, mode='constant', value=0)

tensor_shape = x.shape
tensor_shape_list = list(tensor_shape)
tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size)
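The hunk above computes the same padding at runtime from the incoming tensor shape; the trailing view step is collapsed in the diff, so the sketch below reconstructs it following the pattern of OverSubChannelBlockView (an assumption, not a copy of the truncated code):

import torch
import torch.nn.functional as F

def pad_and_group(x: torch.Tensor, group_dim: int, group_size: int) -> torch.Tensor:
    # Pad the grouped dimension up to a multiple of group_size, then split it
    # into (num_groups, group_size), mirroring the runtime logic added above.
    shape = list(x.shape)
    pad = [0, 0] * len(shape)
    if shape[group_dim] % group_size != 0:
        pad[2 * group_dim] = group_size - shape[group_dim] % group_size
    pad = list(reversed(pad))  # F.pad expects pairs ordered from the last dimension
    x = F.pad(x, pad, mode='constant', value=0)
    shape = list(x.shape)
    shape[group_dim] = shape[group_dim] // group_size
    shape.insert(group_dim + 1, group_size)
    return x.view(shape)

x = torch.randn(1, 8, 8, 8)  # 8 channels with group size 32 -> padded to 32
print(pad_and_group(x, group_dim=1, group_size=32).shape)  # torch.Size([1, 1, 32, 8, 8])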
4 changes: 2 additions & 2 deletions src/brevitas/quant/solver/common.py
@@ -172,13 +172,13 @@ def int_scaling_impl(restrict_scaling_type):
class SolveStatsReduceDimFromEnum(ExtendedInjector):

@value
- def stats_reduce_dim(scaling_stats_op, scaling_per_output):
+ def stats_reduce_dim(scaling_stats_op, scaling_per_output, group_dim=None):
if scaling_per_output == ScalingPerOutputType.CHANNEL or scaling_stats_op == StatsOp.MAX_AVE:
return SCALING_STATS_REDUCE_DIM
elif scaling_per_output == ScalingPerOutputType.TENSOR:
return None
elif scaling_per_output == ScalingPerOutputType.GROUP:
- return SCALING_STATS_REDUCE_DIM + 1
+ return group_dim + 1

@value
def keepdim(scaling_per_output):
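A small illustrative sketch (plain PyTorch, not Brevitas API) of why the group-wise reduce dim becomes group_dim + 1: the grouped view splits dimension group_dim into (num_groups, group_size), so the per-group statistic is taken over the newly inserted axis right after it:

import torch

group_dim, group_size = 1, 8
w = torch.randn(64, 32, 3, 3)
grouped = w.view(64, 32 // group_size, group_size, 3, 3)   # expanded grouped view
scale_stats = grouped.abs().amax(dim=group_dim + 1, keepdim=True)
print(scale_stats.shape)  # torch.Size([64, 4, 1, 3, 3]) -> one statistic per group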
29 changes: 20 additions & 9 deletions src/brevitas/quant/solver/parameter.py
@@ -111,36 +111,47 @@ def scaling_impl(scaling_impl_type):
class SolveParameterScalingShape(ExtendedInjector):

@value
- def scaling_shape(module, group_size=None, scaling_per_output=None):
+ def scaling_shape(module, group_dim, group_size=None, scaling_per_output=None):
if scaling_per_output == ScalingPerOutputType.TENSOR:
return SCALAR_SHAPE
elif scaling_per_output == ScalingPerOutputType.CHANNEL:
return this.scaling_per_output_channel_shape
elif scaling_per_output == ScalingPerOutputType.GROUP:
assert group_size is not None, "Per Group scaling requires group size"
assert group_dim is not None, "Per Group scaling requires group dim"
size = list(module.weight.shape)
- assert size[1] % group_size == 0, 'Input channel is not divisible by group size'
- size[1] = size[1] // group_size
- size.insert(2, 1)
+ size[group_dim] = (size[group_dim] + group_size - 1) // group_size
+ size.insert(group_dim + 1, 1)
return size

@value
def reshaped_scaling_shape(module):
return module.weight.shape

@value
- def expanded_scaling_shape(module, group_size=None):
+ def expanded_scaling_shape(module, group_dim, group_size=None):
assert group_size is not None, "Per Group scaling requires group size"
size = list(module.weight.shape)
- assert size[1] % group_size == 0, 'Input channel is not divisible by group size'
- size[1] = size[1] // group_size
- size.insert(2, group_size)
+ size[group_dim] = (size[group_dim] + group_size - 1) // group_size
+ size.insert(group_dim + 1, group_size)
return size

@value
def padding(module, group_dim, group_size):
padding = [0, 0] * len(module.weight.shape)
size = list(module.weight.shape)
if size[group_dim] % group_size != 0:
# Pad the group dimension up to the next multiple of group_size
padding[2 * group_dim] = group_size - size[group_dim] % group_size
# Padding takes a list of 2 values per dim in reverse order (N_DIM, N_DIM-1,...,0)
# so we need to reverse the order
padding = list(reversed(padding))
return padding

@value
def group_dim(module, group_size=None):
if group_size is not None:
- return 1
+ return 1 if not hasattr(module, 'transposed') or not module.transposed else 0


class SolveInputViewImpl(ExtendedInjector):
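To make the new shape logic concrete, here is a hedged standalone sketch (helper names are illustrative, not Brevitas functions) of the ceil-division group count and the transposed-layer group dimension for a regular and a transposed convolution, following the solver code above:

import torch.nn as nn

def group_dim(module) -> int:
    # Transposed convolutions store weights as (in_channels, out_channels, ...),
    # so the grouped input-channel dim is 0 instead of 1.
    return 1 if not getattr(module, 'transposed', False) else 0

def grouped_shapes(module, group_size):
    gd = group_dim(module)
    size = list(module.weight.shape)
    num_groups = (size[gd] + group_size - 1) // group_size  # ceil division, padding-aware
    expanded = list(size)
    expanded[gd] = num_groups
    expanded.insert(gd + 1, group_size)   # expanded_scaling_shape
    scaling = list(size)
    scaling[gd] = num_groups
    scaling.insert(gd + 1, 1)             # scaling_shape
    return expanded, scaling

conv = nn.Conv2d(8, 64, 3)
deconv = nn.ConvTranspose2d(8, 64, 3)
print(grouped_shapes(conv, 32))    # ([64, 1, 32, 3, 3], [64, 1, 1, 3, 3])
print(grouped_shapes(deconv, 32))  # ([1, 32, 64, 3, 3], [1, 1, 64, 3, 3])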
