diff --git a/optax/tree_utils/_tree_math.py b/optax/tree_utils/_tree_math.py index 96636ec4b..0d671acad 100644 --- a/optax/tree_utils/_tree_math.py +++ b/optax/tree_utils/_tree_math.py @@ -152,7 +152,7 @@ def tree_vdot(tree_x: Any, tree_y: Any) -> chex.Numeric: numerical issues. """ vdots = jtu.tree_map(_vdot_safe, tree_x, tree_y) - return jtu.tree_reduce(operator.add, vdots) + return jtu.tree_reduce(operator.add, vdots, 0) def tree_sum(tree: Any) -> chex.Numeric: @@ -165,7 +165,7 @@ def tree_sum(tree: Any) -> chex.Numeric: a scalar value. """ sums = jtu.tree_map(jnp.sum, tree) - return jtu.tree_reduce(operator.add, sums) + return jtu.tree_reduce(operator.add, sums, 0) def _square(leaf): diff --git a/optax/tree_utils/_tree_math_test.py b/optax/tree_utils/_tree_math_test.py index 792e32e33..574c02769 100644 --- a/optax/tree_utils/_tree_math_test.py +++ b/optax/tree_utils/_tree_math_test.py @@ -232,6 +232,12 @@ def test_bias_correction_bf16(self): tu.tree_bias_correction(m, decay, count), custom_message=f'failed with decay={decay}, count={count}') + def test_empty_tree_reduce(self): + # assert False + for tree in [{}, (), [], None, {'key': [None, [None]]}]: + self.assertEqual(tu.tree_sum(tree), 0) + self.assertEqual(tu.tree_vdot(tree, tree), 0) + if __name__ == '__main__': absltest.main()