diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index eaa96c940d18..7f130dd4cdcf 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -247,10 +247,21 @@ def index_array(x, indexers): continue if indexer is None: continue + + if all(isinstance(i, indexing.Slice) for i in indexer.indices): + t = [] + for i in indexer.indices: + if isinstance(i, indexing.Slice): + start = i.start + size = i.size * i.stride + stride = i.stride + t.append((start, size, stride)) + + result = lax_slicing.slice(result, *zip(*t)) # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. # We need to squeeze out the 1-sized slices at the end. - if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): + elif maybe_slice := _maybe_convert_to_dynamic_slice(indexer): starts, sizes, squeeze_dims = maybe_slice y = lax_slicing.dynamic_slice(result, starts, sizes) result = lax.squeeze(y, squeeze_dims)