Skip to content

Commit

Permalink
Remove the deprecation of jax.tree_map for the release of 0.4.25
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 610014256
  • Loading branch information
yashk2810 authored and jax authors committed Feb 24, 2024
1 parent 22996bb commit f4045dc
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 10 deletions.
2 changes: 0 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ Remember to align the itemized text with the first line of an item within a list
* JAX arrays now support NumPy-style scalar boolean indexing, e.g. `x[True]` or `x[False]`.
* Added {mod}`jax.tree` module, with a more convenient interface for referencing functions
in {mod}`jax.tree_util`.
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
* {func}`jax.tree.transpose` (i.e. {func}`jax.tree_util.tree_transpose`) now accepts
`inner_treedef=None`, in which case the inner treedef will be automatically inferred.

Expand Down
9 changes: 1 addition & 8 deletions jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@
)

from jax._src.tree_util import (
tree_map as _deprecated_tree_map,
tree_map as tree_map,
treedef_is_leaf as _deprecated_treedef_is_leaf,
tree_flatten as _deprecated_tree_flatten,
tree_leaves as _deprecated_tree_leaves,
Expand Down Expand Up @@ -212,20 +212,13 @@
"or jax.tree_util.tree_unflatten (any JAX version).",
_deprecated_tree_unflatten
),
# Added Feb 22, 2024
"tree_map": (
"jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) "
"or jax.tree_util.tree_map (any JAX version).",
_deprecated_tree_map
),
}

import typing as _typing
if _typing.TYPE_CHECKING:
from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf
from jax._src.tree_util import tree_flatten as tree_flatten
from jax._src.tree_util import tree_leaves as tree_leaves
from jax._src.tree_util import tree_map as tree_map
from jax._src.tree_util import tree_structure as tree_structure
from jax._src.tree_util import tree_transpose as tree_transpose
from jax._src.tree_util import tree_unflatten as tree_unflatten
Expand Down

0 comments on commit f4045dc

Please sign in to comment.