Skip to content

Commit

Permalink
Fix (ptq): change momentum to attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 3, 2023
1 parent bff4c82 commit 449764f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/brevitas/core/scaling/standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,6 @@ class ParameterFromRuntimeStatsScaling(brevitas.jit.ScriptModule):
Maps to scaling_impl_type == ScalingImplType.PARAMETER_FROM_STATS == 'PARAMETER_FROM_STATS'
== 'parameter_from_stats' when applied to runtime values (inputs/outputs/activations) in higher-level APIs.
"""
__constants__ = ['momentum']

def __init__(
self,
Expand All @@ -301,7 +300,8 @@ def __init__(
self.counter: int = brevitas.jit.Attribute(0, int)
self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
self.stats = _Stats(scaling_stats_impl, scaling_shape)
self.momentum = scaling_stats_momentum
self.momentum: Optional[float] = brevitas.jit.Attribute(
scaling_stats_momentum, Optional[float])
self.register_buffer('buffer', torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
self.restrict_scaling = _RestrictValue(restrict_scaling_impl)
Expand Down
5 changes: 3 additions & 2 deletions src/brevitas/core/zero_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor:


class ParameterFromRuntimeZeroPoint(brevitas.jit.ScriptModule):
__constants__ = ['stats_permute_dims', 'zero_point_shape', 'momentum']
__constants__ = ['stats_permute_dims', 'zero_point_shape']

def __init__(
self,
Expand All @@ -105,7 +105,8 @@ def __init__(
self.counter: int = brevitas.jit.Attribute(0, int)
self.zero_point_shape = zero_point_shape
self.stats_input_view_shape_impl = zero_point_stats_input_view_shape_impl
self.momentum = zero_point_stats_momentum
self.momentum: Optional[float] = brevitas.jit.Attribute(
zero_point_stats_momentum, Optional[float])
self.value = Parameter(torch.full(zero_point_shape, 0.0, dtype=dtype, device=device))
self.register_buffer(
'buffer', torch.full(zero_point_shape, 0.0, dtype=dtype, device=device))
Expand Down

0 comments on commit 449764f

Please sign in to comment.