# Injector mixin that resolves the shape of a parameter's scaling factor.
class SolveParameterScalingShape(ExtendedInjector):

    @value
    def scaling_shape(module, group_dim, group_size=None, scaling_per_output=None):
        """Resolve the scaling shape for the requested scaling granularity.

        Args:
            module: Object exposing ``weight.shape``; only read in the
                per-group case.
            group_dim: Index of the weight dimension that is split into
                groups. Negative values count from the last dimension.
            group_size: Number of elements per group along ``group_dim``.
                Required for per-group scaling.
            scaling_per_output: A ``ScalingPerOutputType`` member selecting
                the granularity.

        Returns:
            ``SCALAR_SHAPE`` for per-tensor scaling,
            ``this.scaling_per_output_channel_shape`` for per-channel scaling,
            and for per-group scaling a list built from the weight shape in
            which ``group_dim`` holds the number of groups (rounded up) and a
            singleton dimension follows it. ``None`` for any other value of
            ``scaling_per_output``.

        Raises:
            AssertionError: If per-group scaling is requested without
                ``group_size`` or ``group_dim``.
            IndexError: If ``group_dim`` is out of range for the weight shape.
        """
        if scaling_per_output == ScalingPerOutputType.TENSOR:
            return SCALAR_SHAPE
        elif scaling_per_output == ScalingPerOutputType.CHANNEL:
            return this.scaling_per_output_channel_shape
        elif scaling_per_output == ScalingPerOutputType.GROUP:
            assert group_size is not None, "Per Group scaling requires group size"
            assert group_dim is not None, "Per Group scaling requires group dim"
            size = list(module.weight.shape)
            ndim = len(size)
            if not -ndim <= group_dim < ndim:
                raise IndexError(
                    f"group_dim {group_dim} is out of range for a weight with {ndim} dimensions")
            # Normalise to a non-negative index. Without this, group_dim == -1
            # would make insert(group_dim + 1, ...) insert at position 0 and put
            # the singleton axis at the front instead of after the grouped axis.
            dim = group_dim + ndim if group_dim < 0 else group_dim
            # Ceil division: a trailing partial group still needs its own scale.
            size[dim] = (size[dim] + group_size - 1) // group_size
            size.insert(dim + 1, 1)
            return size
        # Any other granularity (including the default None) resolves to None.
        return None

    # NOTE(review): the patch context continues with another `@value` member of
    # this class whose definition is outside this view; it is not reproduced here.