diff --git a/jax/custom_batching.py b/jax/custom_batching.py index a4850f04c2ec..9b8dc8f8709a 100644 --- a/jax/custom_batching.py +++ b/jax/custom_batching.py @@ -13,6 +13,6 @@ # limitations under the License. from jax._src.custom_batching import ( - custom_vmap, - sequential_vmap, + custom_vmap as custom_vmap, + sequential_vmap as sequential_vmap, ) diff --git a/jax/custom_transpose.py b/jax/custom_transpose.py index 311139da2567..314163c4684a 100644 --- a/jax/custom_transpose.py +++ b/jax/custom_transpose.py @@ -13,5 +13,5 @@ # limitations under the License. from jax._src.custom_transpose import ( - custom_transpose, + custom_transpose as custom_transpose, ) diff --git a/jax/distributed.py b/jax/distributed.py index 284ae6f95f48..cf39b81f423a 100644 --- a/jax/distributed.py +++ b/jax/distributed.py @@ -12,4 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src.distributed import (initialize, shutdown) +from jax._src.distributed import ( + initialize as initialize, + shutdown as shutdown, +) diff --git a/jax/dlpack.py b/jax/dlpack.py index 707e966ee243..a65496ec0cbf 100644 --- a/jax/dlpack.py +++ b/jax/dlpack.py @@ -12,4 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src.dlpack import (to_dlpack, from_dlpack, SUPPORTED_DTYPES) +from jax._src.dlpack import ( + to_dlpack as to_dlpack, + from_dlpack as from_dlpack, + SUPPORTED_DTYPES as SUPPORTED_DTYPES, +)