Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support strided load / store in interpret mode #22719

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions jax/_src/pallas/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,10 +464,6 @@ def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
raise NotImplementedError("Only one indexer supported in discharge rule.")
idx = indexers[0]
if all((isinstance(s, Slice) or not s.shape) for s in idx.indices):
# TODO(b/329733289): support strided load/store in interpret mode.
for s in idx.indices:
if isinstance(s, Slice) and s.stride > 1:
raise NotImplementedError("Unimplemented stride support.")
indices = idx.indices
scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices]
slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
Expand Down Expand Up @@ -576,10 +572,6 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
raise NotImplementedError("Only one indexer supported in discharge rule.")
idx = indexers[0]
if all((isinstance(s, Slice) or not s.shape) for s in idx.indices):
# TODO(b/329733289): support strided load/store in interpret mode.
for s in idx.indices:
if isinstance(s, Slice) and s.stride > 1:
raise NotImplementedError("Unimplemented stride support.")
indices = idx.indices
scalar_dims = [
i
Expand Down
19 changes: 15 additions & 4 deletions jax/_src/state/discharge.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def _convert_to_array_indexer(indexer: indexing.NDIndexer
assert isinstance(idx, indexing.Slice)
slice_indices = lax.broadcasted_iota(
np.dtype("int32"), total_shape, next(slice_dim_iter)
) + idx.start
) * idx.stride + idx.start
slice_indexer.append(slice_indices)
integer_indexer = tuple(
lax.expand_dims(idx, (-1,)) for idx in integer_indexer
Expand All @@ -198,10 +198,9 @@ def _maybe_convert_to_dynamic_slice(
if not all(isinstance(i, indexing.Slice) or not np.shape(i)
for i in indexer.indices):
return None
# TODO(b/329733289): support strided load/store in interpret mode.
for i in indexer.indices:
if isinstance(i, indexing.Slice) and i.stride > 1:
raise NotImplementedError("Unimplemented stride support.")
return None
_convert_i32 = lambda x: lax.convert_element_type(x, np.dtype("int32"))
starts = tuple(
_convert_i32(i.start) if isinstance(i, indexing.Slice)
Expand Down Expand Up @@ -247,10 +246,22 @@ def index_array(x, indexers):
continue
if indexer is None:
continue

if all(isinstance(i, indexing.Slice) for i in indexer.indices):
t = []
for i in indexer.indices:
if isinstance(i, indexing.Slice):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we convert any scalar indexer i to a single-element slice and then do squeeze?

start = i.start
size = i.size * i.stride
stride = i.stride
t.append((start, size, stride))

result = lax_slicing.slice(result, *zip(*t))

# If everything in the indexer is a slice or ()-shaped, we can also
# use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
# We need to squeeze out the 1-sized slices at the end.
if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
elif maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
starts, sizes, squeeze_dims = maybe_slice
y = lax_slicing.dynamic_slice(result, starts, sizes)
result = lax.squeeze(y, squeeze_dims)
Expand Down
42 changes: 37 additions & 5 deletions jax/_src/state/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,40 @@ def from_indices_shape(cls, indices, shape) -> NDIndexer:
return cls(indices, shape, int_indexer_shape, validate=True)

def get_indexer_shape(self) -> tuple[int | Array, ...]:
_, slice_indexers, _ = unpack_ndindexer(self)
slice_shape = [s.size for s in slice_indexers]
# In NDIndexers, the int_indexer_shape is *always* at the front of the
# result.
return (*self.int_indexer_shape, *slice_shape)
is_int_indexing, slice_indexers, int_indexers = unpack_ndindexer(self)

has_int_indexers = any(is_int_indexing)
has_non_adjacent_int_indexers = has_non_adjacent_true(is_int_indexing)

# shift the int_indexer_shape to the front
if has_non_adjacent_int_indexers:
slice_shape = [s.size for s in slice_indexers]
c = (*self.int_indexer_shape, *slice_shape)

elif has_int_indexers:
slice_shape = [s.size for s in slice_indexers]
int_indexer_shape = self.int_indexer_shape
pos = is_int_indexing.index(True)
c = (*slice_shape[:pos], *int_indexer_shape, *slice_shape[pos:])

else:
c = tuple(i.size for i in self.indices)

return c


# TODO: make this function better
def has_non_adjacent_true(seq: list[bool]) -> bool:
seen_true = False
last_is_true = False

for i in seq:
if i and seen_true and not last_is_true:
return True

if i:
seen_true = True

last_is_true = i

return False
127 changes: 127 additions & 0 deletions tests/pallas/indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
"""Tests for Pallas indexing logic and abstractions."""

from __future__ import annotations
from collections.abc import Sequence
import sys
from typing import NoReturn
import unittest

from absl.testing import absltest
Expand Down Expand Up @@ -620,5 +622,130 @@ class IndexerOpsInterpreterTest(IndexerOpsTest):
INTERPRET = True


class SliceBuilder:
def __getitem__(self, args):
if not isinstance(args, tuple):
args = (args,)
return args

sb = SliceBuilder()
"""
A helper object to create tuples of indexer objects using a more concise
bracket notation.

It allows for the creation of complex indexer tuples that include slices
and other objects in a readable and compact form.

Examples:

a = 5
indexers = sb[::4, a, 1::2, a]
print(indexers) # Output: (slice(None, None, 4), 5, slice(1, None, 2), 5)
"""

def _slice_builder_new(cls) -> NoReturn:
raise RuntimeError('SliceBuilder cannot be instantiated directly')

def _slice_builder_init_subclass(cls, **kwargs) -> NoReturn:
raise RuntimeError('Subclassing SliceBuilder is not allowed')

SliceBuilder.__new__ = _slice_builder_new
SliceBuilder.__init_subclass__ = _slice_builder_init_subclass


class FrozenSet(Sequence):
def __init__(self, iterable):
self._seen_ids = set()
self._items = []
for item in iterable:
if id(item) not in self._seen_ids:
self._seen_ids.add(id(item))
self._items.append(item)
self._items = tuple(self._items) # Freeze the list by converting it to a tuple

def __getitem__(self, index):
return self._items[index]

def __len__(self):
return len(self._items)

def __repr__(self):
return f'FrozenList({self._items})'

def __contains__(self, item):
return id(item) in self._seen_ids


_ADVANCED_INDEXER_TEST_CASES = [
# ((16, 3, 6, 2), "::4, a, 1::2, b"),
# ((16, 3), "a, a"),
# ((16, 16), "::4, ::4"),
# ((16, 16), "1:14:2, 2:13:4"),
# ((16, 3), "a, :"),
((16, 3), ":, a"),
# ((16, 3), "a, ::4"),
# ((16, 3), "::4, a"),
# ((8, 8, 3), "::4, ::2, a"),
# ((8, 8, 3), "::4, a, ::2"),
# ((8, 8, 3, 7), "::4, b, ::2, ::2"),
# ((8, 8, 3, 7), "b, ::4, ::2, ::2"),
# ((8, 8, 3, 7), "b, ::4, a, ::2"),
# ((3, 8, 8, 7), "b, a, ::4, ::2"),
# ((8, 8, 3, 7), "::4, b, a, ::2"),
# ((8, 8, 3, 6), "b, ::4, a, c"),
# ((8, 6, 4), "a"),
# ((6, 8, 4), "c, ::3"),
# ((6, 8, 4), "c, c"),
# ((8, 6, 4), "::3, c"),
# ((6, 2), "d"),
# ((8, 6), "::4, d"),
]


class AdvancedIndexerOpsTest(PallasBaseTest):

def setUp(self):
self.a = jnp.array([1,1,1,1,1], dtype=jnp.int32)
self.b = jnp.array([1,2,2,2,2], dtype=jnp.int32)
self.c = jnp.array([1,0,2,2,-1,1], dtype=jnp.int32)
self.d = jnp.array([1,0,0,0,0,1], dtype=jnp.bool_)

super().setUp()

@parameterized.parameters(_ADVANCED_INDEXER_TEST_CASES)
def run_test(self, in_shape: tuple[int, ...], indexer_str: str):
a, b, c, d = self.a, self.b, self.c, self.d

x = jnp.arange(np.prod(in_shape), dtype=jnp.float32).reshape(in_shape)
indexer = eval(f"sb[{indexer_str}]")
y = x[indexer]

fancy_indexers = tuple(FrozenSet(i for i in indexer if isinstance(i, jnp.ndarray)))
is_fancy_indexers = [isinstance(i, jnp.ndarray) for i in indexer]
indexer_items = indexer_str.split(", ")
fancy_indexer_names = tuple({n: None for n, b in zip(indexer_items, is_fancy_indexers) if b})

func_str = f'''
def kernel(x_ref, {" ".join(f"{n}_ref," for n in fancy_indexer_names)} o_ref):
{"\n".join(f" {f'{n} = {n}_ref[...]'}" for n in fancy_indexer_names)}
w = x_ref[{indexer_str}]
o_ref[...] = w
'''
# print(func_str)

exec(func_str, globals())

y_ = self.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct(y.shape, jnp.float32),
)(x, *fancy_indexers)

np.testing.assert_array_equal(y_, y)


class AdvancedIndexerOpsInterpreterTest(AdvancedIndexerOpsTest):
INTERPRET = True


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
4 changes: 0 additions & 4 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1161,10 +1161,6 @@ def masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref):
np.testing.assert_array_equal(out, o_new)

def test_strided_load(self):
if self.INTERPRET:
# TODO(b/329733289): Remove this once the bug is fixed.
self.skipTest("Strided load not yet supported in interpreter mode")

# Reproducer from https://github.com/google/jax/issues/20895.
@functools.partial(
self.pallas_call,
Expand Down
Loading