Skip to content

Commit

Permalink
Support strided load / store in interpret mode
Browse files Browse the repository at this point in the history
  • Loading branch information
ayaka14732 committed Jul 29, 2024
1 parent 6a7822a commit 73493de
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion jax/_src/state/discharge.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,21 @@ def index_array(x, indexers):
continue
if indexer is None:
continue

if all(isinstance(i, indexing.Slice) for i in indexer.indices):
t = []
for i in indexer.indices:
if isinstance(i, indexing.Slice):
start = i.start
size = i.size * i.stride
stride = i.stride
t.append((start, size, stride))

result = lax_slicing.slice(result, *zip(*t))
# If everything in the indexer is a slice or ()-shaped, we can also
# use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
# We need to squeeze out the 1-sized slices at the end.
if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
elif maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
starts, sizes, squeeze_dims = maybe_slice
y = lax_slicing.dynamic_slice(result, starts, sizes)
result = lax.squeeze(y, squeeze_dims)
Expand Down

0 comments on commit 73493de

Please sign in to comment.