diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py
index 2bd65f592..eb8f7778f 100644
--- a/src/brevitas/core/scaling/standalone.py
+++ b/src/brevitas/core/scaling/standalone.py
@@ -282,7 +282,6 @@ class ParameterFromRuntimeStatsScaling(brevitas.jit.ScriptModule):
     Maps to scaling_impl_type == ScalingImplType.PARAMETER_FROM_STATS == 'PARAMETER_FROM_STATS'
     == 'parameter_from_stats' when applied to runtime values (inputs/outputs/activations) in higher-level APIs.
     """
-    __constants__ = ['momentum']
 
     def __init__(
             self,
@@ -301,7 +300,8 @@ def __init__(
         self.counter: int = brevitas.jit.Attribute(0, int)
         self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
         self.stats = _Stats(scaling_stats_impl, scaling_shape)
-        self.momentum = scaling_stats_momentum
+        self.momentum: Optional[float] = brevitas.jit.Attribute(
+            scaling_stats_momentum, Optional[float])
         self.register_buffer('buffer', torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
         self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
         self.restrict_scaling = _RestrictValue(restrict_scaling_impl)
diff --git a/src/brevitas/core/zero_point.py b/src/brevitas/core/zero_point.py
index 2a699a77c..7f8ad106d 100644
--- a/src/brevitas/core/zero_point.py
+++ b/src/brevitas/core/zero_point.py
@@ -86,7 +86,7 @@ def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor:
 
 
 class ParameterFromRuntimeZeroPoint(brevitas.jit.ScriptModule):
-    __constants__ = ['stats_permute_dims', 'zero_point_shape', 'momentum']
+    __constants__ = ['stats_permute_dims', 'zero_point_shape']
 
     def __init__(
             self,
@@ -105,7 +105,8 @@ def __init__(
         self.counter: int = brevitas.jit.Attribute(0, int)
         self.zero_point_shape = zero_point_shape
         self.stats_input_view_shape_impl = zero_point_stats_input_view_shape_impl
-        self.momentum = zero_point_stats_momentum
+        self.momentum: Optional[float] = brevitas.jit.Attribute(
+            zero_point_stats_momentum, Optional[float])
         self.value = Parameter(torch.full(zero_point_shape, 0.0, dtype=dtype, device=device))
         self.register_buffer(
             'buffer', torch.full(zero_point_shape, 0.0, dtype=dtype, device=device))
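In both modules the change is the same: `momentum` is dropped from `__constants__` and declared through `brevitas.jit.Attribute(..., Optional[float])` instead, the same pattern already used for `self.counter`. Below is a minimal sketch, not Brevitas code, of that pattern under the assumption that `brevitas.jit.Attribute` forwards to `torch.jit.Attribute` when JIT compilation is enabled: `__constants__` entries are baked in as compile-time constants and cannot express a value that may be `None`, while `torch.jit.Attribute` records an explicit `Optional[float]` type for the TorchScript compiler and stores the plain value on the instance. The module name, buffer shape, and update policy here are made up for illustration.

```python
from typing import Optional

import torch


class RunningMean(torch.jit.ScriptModule):  # hypothetical example module

    def __init__(self, momentum: Optional[float] = 0.1):
        super().__init__()
        # __constants__ would treat momentum as a compile-time constant and cannot
        # type it as "float or None"; torch.jit.Attribute registers the attribute
        # with an Optional[float] type and unwraps the value for eager access.
        self.momentum: Optional[float] = torch.jit.Attribute(momentum, Optional[float])
        self.register_buffer('buffer', torch.zeros(()))

    @torch.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        momentum = self.momentum  # local copy so TorchScript can refine the Optional
        if momentum is None:
            # No momentum given: fall back to a simple averaging rule (illustrative).
            self.buffer = 0.5 * (self.buffer + x.detach().mean())
        else:
            self.buffer = (1.0 - momentum) * self.buffer + momentum * x.detach().mean()
        return self.buffer
```

With this declaration, `RunningMean(momentum=None)` scripts and runs just like `RunningMean(momentum=0.1)`, which is what the patched modules need for a `scaling_stats_momentum` / `zero_point_stats_momentum` of `None`.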