Skip to content

Commit

Permalink
Added pytype None checks to savers.py
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 595412439
Change-Id: I1367163d0c2cd08c22446bc9e439b946628d4e4c
  • Loading branch information
Acme Contributor authored and copybara-github committed Jan 3, 2024
1 parent 4c6351e commit 148331f
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions acme/tf/savers.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,23 +150,38 @@ def save(self, force: bool = False) -> bool:
time.time() - self._last_saved < 60 * self._time_delta_minutes):
return False

checkpoint_manager: tf.train.CheckpointManager = self.checkpoint_manager
# Save any checkpoints.
logging.info('Saving checkpoint: %s', self._checkpoint_manager.directory)
self._checkpoint_manager.save()
logging.info('Saving checkpoint: %s', checkpoint_manager.directory)
checkpoint_manager.save()
self._last_saved = time.time()

return True

def restore(self):
"""Restore from most recent checkpoint."""

# Restore from the most recent checkpoint (if it exists).
checkpoint_to_restore = self._checkpoint_manager.latest_checkpoint
checkpoint_to_restore = self.checkpoint_manager.latest_checkpoint
logging.info('Attempting to restore checkpoint: %s',
checkpoint_to_restore)
self._checkpoint.restore(checkpoint_to_restore)

@property
def directory(self):
return self._checkpoint_manager.directory
return self.checkpoint_manager.directory

@property
def checkpoint_manager(self) -> tf.train.CheckpointManager:
if not self._enable_checkpointing:
raise ValueError(
'Check-point not enabled. No checkpoint manager available.'
)

# At this point, _enable_checkpointing is true, so _checkpoint_manager
# should not be None.
assert self._checkpoint_manager is not None
return self._checkpoint_manager


class CheckpointingRunner(core.Worker):
Expand Down Expand Up @@ -332,6 +347,8 @@ def __init__(self):

@tf.function
def __call__(self, *args, **kwargs):
if self._module is None:
raise ValueError('_module not set')
return self._module(*args, **kwargs)

@property
Expand Down

0 comments on commit 148331f

Please sign in to comment.