Skip to content

Commit

Permalink
simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Aug 26, 2024
1 parent b4c8cf5 commit 7ec08ed
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions src/brevitas/quant/solver/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,23 +128,19 @@ def reshaped_scaling_shape(module):
return module.weight.shape

@value
def expanded_scaling_shape(module, input_channel_dim, group_size=None):
def expanded_scaling_shape(module, group_dim, group_size=None):
assert group_size is not None, "Per Group scaling requires group size"
size = list(module.weight.shape)
size[input_channel_dim] = (size[input_channel_dim] + group_size - 1) // group_size
size.insert(input_channel_dim + 1, group_size)
size[group_dim] = (size[group_dim] + group_size - 1) // group_size
size.insert(group_dim + 1, group_size)
return size

@value
def input_channel_dim(module):
return 1 if not hasattr(module, 'transposed') or not module.transposed else 0

@value
def padding(module, input_channel_dim, group_size):
def padding(module, group_dim, group_size):
padding = [0, 0] * len(module.weight.shape)
size = list(module.weight.shape)
if size[input_channel_dim] % group_size != 0:
padding[2 * input_channel_dim] = group_size - size[input_channel_dim] % group_size
if size[group_dim] % group_size != 0:
padding[2 * group_dim] = group_size - size[group_dim] % group_size
padding = list(reversed(padding))
return padding

Expand Down

0 comments on commit 7ec08ed

Please sign in to comment.