Skip to content

Commit

Permalink
This CL updates the Snapshotter class to allow None as a valid value for `snapshot_ttl_seconds`.
Browse files Browse the repository at this point in the history
This TTL value is passed to `paths.process_path`, which already supports a `None` TTL.

PiperOrigin-RevId: 675680440
Change-Id: If44d8bcb01f53306660bdcd19a23b570b8f4a20e
  • Loading branch information
Acme Contributor authored and copybara-github committed Sep 17, 2024
1 parent 8c2a8c8 commit 766a98b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
5 changes: 3 additions & 2 deletions acme/tf/savers.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,15 +286,16 @@ def __init__(
*,
directory: str = '~/acme/',
time_delta_minutes: float = 30.0,
snapshot_ttl_seconds: int = _DEFAULT_SNAPSHOT_TTL,
snapshot_ttl_seconds: int | None = _DEFAULT_SNAPSHOT_TTL,
):
"""Builds the saver object.
Args:
objects_to_save: Mapping specifying what to snapshot.
directory: Which directory to put the snapshot in.
time_delta_minutes: How often to save the snapshot, in minutes.
snapshot_ttl_seconds: TTL (time to leave) in seconds for snapshots.
snapshot_ttl_seconds: TTL (time to live) in seconds for snapshots. If
`None`, then snapshots will be created in `directory` without a TTL.
"""
objects_to_save = objects_to_save or {}

Expand Down
32 changes: 32 additions & 0 deletions acme/tf/savers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,38 @@ def test_snapshot(self):
assert np.allclose(outputs1, outputs2)
assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))

def test_snapshot_no_ttl(self):
"""Test that snapshotter correctly saves/restores snapshots w/o a TTL."""
# Create a test network and build its variables from the input spec.
net1 = networks.LayerNormMLP([10, 10])
spec = specs.Array([10], dtype=np.float32)
tf2_utils.create_variables(net1, [spec])

# Save the test network, explicitly passing `snapshot_ttl_seconds=None` so
# the snapshot is written without a TTL.
directory = self.get_tempdir()
objects_to_save = {'net': net1}
snapshotter = tf2_savers.Snapshotter(
objects_to_save, directory=directory, snapshot_ttl_seconds=None
)
snapshotter.save()

# Reload the test network from the 'net' entry under the snapshotter's
# directory (the key used in `objects_to_save`).
net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))
# Build the input from the spec (presumably zeros with a leading batch
# dimension -- confirm against tf2_utils).
inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

# Forward/backward pass through the original network.
with tf.GradientTape() as tape:
outputs1 = net1(inputs)
loss1 = tf.math.reduce_sum(outputs1)
grads1 = tape.gradient(loss1, net1.trainable_variables)

# Same forward/backward pass through the restored network.
with tf.GradientTape() as tape:
outputs2 = net2(inputs)
loss2 = tf.math.reduce_sum(outputs2)
grads2 = tape.gradient(loss2, net2.trainable_variables)

# The restored network must match the original in both outputs and
# gradients.
assert np.allclose(outputs1, outputs2)
assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))


def test_snapshot_distribution(self):
"""Test that snapshotter correctly calls saves/restores snapshots."""
# Create a test network.
Expand Down

0 comments on commit 766a98b

Please sign in to comment.