From 3acbd44952b86f54de6c937d9ca0874e47b382f9 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 1 Feb 2022 16:38:12 -0800 Subject: [PATCH] Remove isinstance checks PiperOrigin-RevId: 425745786 --- jax/experimental/global_device_array.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/jax/experimental/global_device_array.py b/jax/experimental/global_device_array.py index 26daae5fdfb2..4a1076475d5c 100644 --- a/jax/experimental/global_device_array.py +++ b/jax/experimental/global_device_array.py @@ -54,7 +54,7 @@ def _canonicalize_mesh_axes(mesh_axes): return pspec def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh, - mesh_axes: MeshAxes) -> Tuple[pxla.Index, ...]: + mesh_axes: MeshAxes) -> Tuple[Index, ...]: # Import here to avoid cyclic import error when importing gda in pjit.py. from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources @@ -66,11 +66,7 @@ def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh, sharding_spec = pxla.mesh_sharding_specs( global_mesh.shape, global_mesh.axis_names)(aval, array_mapping) indices = pxla.spec_to_indices(global_shape, sharding_spec) - for index in indices: - assert isinstance(index, tuple) - for idx in index: - assert isinstance(idx, slice) - return indices + return indices # type: ignore @_convert_list_args_to_tuple