diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index f90caeb2e..269779f68 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -348,6 +348,7 @@ class PerChannelPreNorm(ExtendedInjector): normalize_stats_impl = (this << 1).normalize_stats_impl tracked_parameter_list = (this << 1).tracked_parameter_list pre_scaling_shape = (this << 1).pre_scaling_shape + permute_dims = (this << 1).permute_dims class SolvePostScaleGranularity(ExtendedInjector):