Skip to content

Commit

Permalink
fix simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Aug 26, 2024
1 parent 7ec08ed commit a85c733
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/brevitas/quant/solver/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,17 @@ def scaling_impl(scaling_impl_type):
class SolveParameterScalingShape(ExtendedInjector):

@value
def scaling_shape(module, input_channel_dim, group_size=None, scaling_per_output=None):
def scaling_shape(module, group_dim, group_size=None, scaling_per_output=None):
if scaling_per_output == ScalingPerOutputType.TENSOR:
return SCALAR_SHAPE
elif scaling_per_output == ScalingPerOutputType.CHANNEL:
return this.scaling_per_output_channel_shape
elif scaling_per_output == ScalingPerOutputType.GROUP:
assert group_size is not None, "Per Group scaling requires group size"
assert group_dim is not None, "Per Group scaling requires group dim"
size = list(module.weight.shape)
size[input_channel_dim] = (size[input_channel_dim] + group_size - 1) // group_size
size.insert(input_channel_dim + 1, 1)
size[group_dim] = (size[group_dim] + group_size - 1) // group_size
size.insert(group_dim + 1, 1)
return size

@value
Expand Down

0 comments on commit a85c733

Please sign in to comment.