From 766a98bfac9b595d5b80384e435076ef5e52d9ba Mon Sep 17 00:00:00 2001 From: Acme Contributor Date: Tue, 17 Sep 2024 13:08:07 -0700 Subject: [PATCH] This CL updates the `Snapshotter` class to allow `None` as a valid value for `snapshot_ttl_seconds`. This TTL value is passed to `paths.process_path` which already supports a `None` TTL. PiperOrigin-RevId: 675680440 Change-Id: If44d8bcb01f53306660bdcd19a23b570b8f4a20e --- acme/tf/savers.py | 5 +++-- acme/tf/savers_test.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/acme/tf/savers.py b/acme/tf/savers.py index 1b33478afd..53d8ed8394 100644 --- a/acme/tf/savers.py +++ b/acme/tf/savers.py @@ -286,7 +286,7 @@ def __init__( *, directory: str = '~/acme/', time_delta_minutes: float = 30.0, - snapshot_ttl_seconds: int = _DEFAULT_SNAPSHOT_TTL, + snapshot_ttl_seconds: int | None = _DEFAULT_SNAPSHOT_TTL, ): """Builds the saver object. @@ -294,7 +294,8 @@ def __init__( objects_to_save: Mapping specifying what to snapshot. directory: Which directory to put the snapshot in. time_delta_minutes: How often to save the snapshot, in minutes. - snapshot_ttl_seconds: TTL (time to leave) in seconds for snapshots. + snapshot_ttl_seconds: TTL (time to live) in seconds for snapshots. If + `None`, then snapshots will be created in `directory` without a TTL. """ objects_to_save = objects_to_save or {} diff --git a/acme/tf/savers_test.py b/acme/tf/savers_test.py index cd075a1818..e0ca33f583 100644 --- a/acme/tf/savers_test.py +++ b/acme/tf/savers_test.py @@ -201,6 +201,38 @@ def test_snapshot(self): assert np.allclose(outputs1, outputs2) assert all(tree.map_structure(np.allclose, list(grads1), list(grads2))) + def test_snapshot_no_ttl(self): + """Test that snapshotter correctly saves/restores snapshots w/o a TTL.""" + # Create a test network. 
+ net1 = networks.LayerNormMLP([10, 10]) + spec = specs.Array([10], dtype=np.float32) + tf2_utils.create_variables(net1, [spec]) + + # Save the test network. + directory = self.get_tempdir() + objects_to_save = {'net': net1} + snapshotter = tf2_savers.Snapshotter( + objects_to_save, directory=directory, snapshot_ttl_seconds=None + ) + snapshotter.save() + + # Reload the test network. + net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net')) + inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec)) + + with tf.GradientTape() as tape: + outputs1 = net1(inputs) + loss1 = tf.math.reduce_sum(outputs1) + grads1 = tape.gradient(loss1, net1.trainable_variables) + + with tf.GradientTape() as tape: + outputs2 = net2(inputs) + loss2 = tf.math.reduce_sum(outputs2) + grads2 = tape.gradient(loss2, net2.trainable_variables) + + assert np.allclose(outputs1, outputs2) + assert all(tree.map_structure(np.allclose, list(grads1), list(grads2))) + def test_snapshot_distribution(self): """Test that snapshotter correctly calls saves/restores snapshots.""" # Create a test network.