From dfb10c89480adcb4aa8119e8f55a0f7c6fe7dd72 Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Thu, 19 Sep 2024 21:58:42 -0400 Subject: [PATCH] Add missing initializer argument of 0 to tree_reduce in tree_vdot and tree_sum. --- optax/tree_utils/_tree_math.py | 4 ++-- optax/tree_utils/_tree_math_test.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/optax/tree_utils/_tree_math.py b/optax/tree_utils/_tree_math.py index 96636ec4b..0d671acad 100644 --- a/optax/tree_utils/_tree_math.py +++ b/optax/tree_utils/_tree_math.py @@ -152,7 +152,7 @@ def tree_vdot(tree_x: Any, tree_y: Any) -> chex.Numeric: numerical issues. """ vdots = jtu.tree_map(_vdot_safe, tree_x, tree_y) - return jtu.tree_reduce(operator.add, vdots) + return jtu.tree_reduce(operator.add, vdots, 0) def tree_sum(tree: Any) -> chex.Numeric: @@ -165,7 +165,7 @@ def tree_sum(tree: Any) -> chex.Numeric: a scalar value. """ sums = jtu.tree_map(jnp.sum, tree) - return jtu.tree_reduce(operator.add, sums) + return jtu.tree_reduce(operator.add, sums, 0) def _square(leaf): diff --git a/optax/tree_utils/_tree_math_test.py b/optax/tree_utils/_tree_math_test.py index 792e32e33..467611990 100644 --- a/optax/tree_utils/_tree_math_test.py +++ b/optax/tree_utils/_tree_math_test.py @@ -232,6 +232,12 @@ def test_bias_correction_bf16(self): tu.tree_bias_correction(m, decay, count), custom_message=f'failed with decay={decay}, count={count}') + def test_empty_tree_reduce(self): + # assert False + for tree in [{}, (), [], None, {"key": [None, [None]]}]: + tu.tree_sum(tree) + tu.tree_vdot(tree, tree) + if __name__ == '__main__': absltest.main()