Skip to content

Commit

Permalink
Add missing initializer argument of 0 to tree_reduce in tree_vdot and…
Browse files Browse the repository at this point in the history
… tree_sum.
  • Loading branch information
carlosgmartin committed Sep 20, 2024
1 parent ee63e45 commit dfb10c8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
4 changes: 2 additions & 2 deletions optax/tree_utils/_tree_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def tree_vdot(tree_x: Any, tree_y: Any) -> chex.Numeric:
numerical issues.
"""
vdots = jtu.tree_map(_vdot_safe, tree_x, tree_y)
return jtu.tree_reduce(operator.add, vdots)
return jtu.tree_reduce(operator.add, vdots, 0)


def tree_sum(tree: Any) -> chex.Numeric:
Expand All @@ -165,7 +165,7 @@ def tree_sum(tree: Any) -> chex.Numeric:
a scalar value.
"""
sums = jtu.tree_map(jnp.sum, tree)
return jtu.tree_reduce(operator.add, sums)
return jtu.tree_reduce(operator.add, sums, 0)


def _square(leaf):
Expand Down
6 changes: 6 additions & 0 deletions optax/tree_utils/_tree_math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,12 @@ def test_bias_correction_bf16(self):
tu.tree_bias_correction(m, decay, count),
custom_message=f'failed with decay={decay}, count={count}')

def test_empty_tree_reduce(self):
# assert False
for tree in [{}, (), [], None, {"key": [None, [None]]}]:
tu.tree_sum(tree)
tu.tree_vdot(tree, tree)


if __name__ == '__main__':
absltest.main()

0 comments on commit dfb10c8

Please sign in to comment.