diff --git a/CHANGELOG.md b/CHANGELOG.md index d4e775d59db7..45517190a0b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,8 +15,6 @@ Remember to align the itemized text with the first line of an item within a list * JAX arrays now support NumPy-style scalar boolean indexing, e.g. `x[True]` or `x[False]`. * Added {mod}`jax.tree` module, with a more convenient interface for referencing functions in {mod}`jax.tree_util`. - * {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward - compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`. * {func}`jax.tree.transpose` (i.e. {func}`jax.tree_util.tree_transpose`) now accepts `inner_treedef=None`, in which case the inner treedef will be automatically inferred. diff --git a/jax/__init__.py b/jax/__init__.py index 42aeefec4740..ba9a69ed7ec7 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -136,7 +136,7 @@ ) from jax._src.tree_util import ( - tree_map as _deprecated_tree_map, + tree_map as tree_map, treedef_is_leaf as _deprecated_treedef_is_leaf, tree_flatten as _deprecated_tree_flatten, tree_leaves as _deprecated_tree_leaves, @@ -212,12 +212,6 @@ "or jax.tree_util.tree_unflatten (any JAX version).", _deprecated_tree_unflatten ), - # Added Feb 22, 2024 - "tree_map": ( - "jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) " - "or jax.tree_util.tree_map (any JAX version).", - _deprecated_tree_map - ), } import typing as _typing @@ -225,7 +219,6 @@ from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf from jax._src.tree_util import tree_flatten as tree_flatten from jax._src.tree_util import tree_leaves as tree_leaves - from jax._src.tree_util import tree_map as tree_map from jax._src.tree_util import tree_structure as tree_structure from jax._src.tree_util import tree_transpose as tree_transpose from jax._src.tree_util import tree_unflatten as tree_unflatten