diff --git a/src/brevitas/quant/solver/parameter.py b/src/brevitas/quant/solver/parameter.py index 198505ec8..67c3a56c1 100644 --- a/src/brevitas/quant/solver/parameter.py +++ b/src/brevitas/quant/solver/parameter.py @@ -137,7 +137,7 @@ def expanded_scaling_shape(module, input_channel_dim, group_size=None): @value def input_channel_dim(module): - return 1 if not module.transposed else 0 + return 1 if not hasattr(module, 'transposed') or not module.transposed else 0 @value def padding(module, input_channel_dim, group_size): @@ -151,7 +151,7 @@ def padding(module, input_channel_dim, group_size): @value def group_dim(module, group_size=None): if group_size is not None: - return 1 if not module.transposed else 0 + return 1 if not hasattr(module, 'transposed') or not module.transposed else 0 class SolveInputViewImpl(ExtendedInjector):