Skip to content

Commit

Permalink
Merge pull request #1065 from carlosgmartin:add_missing_tree_reduce_i…
Browse files Browse the repository at this point in the history
…nitializer

PiperOrigin-RevId: 680875068
  • Loading branch information
OptaxDev committed Oct 1, 2024
2 parents 93784c4 + d2d30b8 commit 04d79e5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
4 changes: 2 additions & 2 deletions optax/tree_utils/_tree_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def tree_vdot(tree_x: Any, tree_y: Any) -> chex.Numeric:
numerical issues.
"""
vdots = jax.tree.map(_vdot_safe, tree_x, tree_y)
return jax.tree.reduce(operator.add, vdots)
return jax.tree.reduce(operator.add, vdots, initializer=0)


def tree_sum(tree: Any) -> chex.Numeric:
Expand All @@ -164,7 +164,7 @@ def tree_sum(tree: Any) -> chex.Numeric:
a scalar value.
"""
sums = jax.tree.map(jnp.sum, tree)
return jax.tree.reduce(operator.add, sums)
return jax.tree.reduce(operator.add, sums, initializer=0)


def _square(leaf):
Expand Down
5 changes: 5 additions & 0 deletions optax/tree_utils/_tree_math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,11 @@ def test_bias_correction_bf16(self):
tu.tree_bias_correction(m, decay, count),
custom_message=f'failed with decay={decay}, count={count}')

def test_empty_tree_reduce(self):
for tree in [{}, (), [], None, {'key': [None, [None]]}]:
self.assertEqual(tu.tree_sum(tree), 0)
self.assertEqual(tu.tree_vdot(tree, tree), 0)


if __name__ == '__main__':
absltest.main()

0 comments on commit 04d79e5

Please sign in to comment.