Skip to content

Commit

Permalink
Feat (scaling/standalone): flag to retrieve full state dict (#874)
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Feb 23, 2024
1 parent 0186ccd commit 506954c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/brevitas/core/scaling/standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
output_dict = super(ParameterFromStatsFromParameterScaling, self).state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars)
# Avoid saving the init value
if not self.init_done:
if not self.init_done and not config._FULL_STATE_DICT:
del output_dict[prefix + 'value']
return output_dict

Expand Down Expand Up @@ -362,7 +362,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
# Avoid saving the buffer
del output_dict[prefix + 'buffer']
# Avoid saving the init value
if self.counter == 0:
if self.counter == 0 and not config._FULL_STATE_DICT:
del output_dict[prefix + 'value']
# Save buffer into value for any non-zero number of collection steps
elif self.counter <= self.collect_stats_steps:
Expand Down

0 comments on commit 506954c

Please sign in to comment.