Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
ayaka14732 committed Aug 19, 2024
1 parent fcb04f9 commit 5d691cd
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 6 deletions.
14 changes: 13 additions & 1 deletion jax/_src/state/discharge.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,22 @@ def index_array(x, indexers):
continue
if indexer is None:
continue

if all(isinstance(i, indexing.Slice) for i in indexer.indices):
t = []
for i in indexer.indices:
if isinstance(i, indexing.Slice):
start = i.start
size = i.size * i.stride
stride = i.stride
t.append((start, size, stride))

result = lax_slicing.slice(result, *zip(*t))

# If everything in the indexer is a slice or ()-shaped, we can also
# use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
# We need to squeeze out the 1-sized slices at the end.
if maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
elif maybe_slice := _maybe_convert_to_dynamic_slice(indexer):
starts, sizes, squeeze_dims = maybe_slice
y = lax_slicing.dynamic_slice(result, starts, sizes)
result = lax.squeeze(y, squeeze_dims)
Expand Down
42 changes: 37 additions & 5 deletions jax/_src/state/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,40 @@ def from_indices_shape(cls, indices, shape) -> NDIndexer:
return cls(indices, shape, int_indexer_shape, validate=True)

def get_indexer_shape(self) -> tuple[int | Array, ...]:
_, slice_indexers, _ = unpack_ndindexer(self)
slice_shape = [s.size for s in slice_indexers]
# In NDIndexers, the int_indexer_shape is *always* at the front of the
# result.
return (*self.int_indexer_shape, *slice_shape)
is_int_indexing, slice_indexers, int_indexers = unpack_ndindexer(self)

has_int_indexers = any(is_int_indexing)
has_non_adjacent_int_indexers = has_non_adjacent_true(is_int_indexing)

# shift the int_indexer_shape to the front
if has_non_adjacent_int_indexers:
slice_shape = [s.size for s in slice_indexers]
c = (*self.int_indexer_shape, *slice_shape)

elif has_int_indexers:
slice_shape = [s.size for s in slice_indexers]
int_indexer_shape = self.int_indexer_shape
pos = is_int_indexing.index(True)
c = (*slice_shape[:pos], *int_indexer_shape, *slice_shape[pos:])

else:
c = tuple(i.size for i in self.indices)

return c


# TODO: make this function better
def has_non_adjacent_true(seq: list[bool]) -> bool:
seen_true = False
last_is_true = False

for i in seq:
if i and seen_true and not last_is_true:
return True

if i:
seen_true = True

last_is_true = i

return False
127 changes: 127 additions & 0 deletions tests/pallas/indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
"""Tests for Pallas indexing logic and abstractions."""

from __future__ import annotations
from collections.abc import Sequence
import sys
from typing import NoReturn
import unittest

from absl.testing import absltest
Expand Down Expand Up @@ -620,5 +622,130 @@ class IndexerOpsInterpreterTest(IndexerOpsTest):
INTERPRET = True


class SliceBuilder:
def __getitem__(self, args):
if not isinstance(args, tuple):
args = (args,)
return args

sb = SliceBuilder()
"""
A helper object to create tuples of indexer objects using a more concise
bracket notation.
It allows for the creation of complex indexer tuples that include slices
and other objects in a readable and compact form.
Examples:
a = 5
indexers = sb[::4, a, 1::2, a]
print(indexers) # Output: (slice(None, None, 4), 5, slice(1, None, 2), 5)
"""

def _slice_builder_new(cls) -> NoReturn:
raise RuntimeError('SliceBuilder cannot be instantiated directly')

def _slice_builder_init_subclass(cls, **kwargs) -> NoReturn:
raise RuntimeError('Subclassing SliceBuilder is not allowed')

SliceBuilder.__new__ = _slice_builder_new
SliceBuilder.__init_subclass__ = _slice_builder_init_subclass


class FrozenSet(Sequence):
def __init__(self, iterable):
self._seen_ids = set()
self._items = []
for item in iterable:
if id(item) not in self._seen_ids:
self._seen_ids.add(id(item))
self._items.append(item)
self._items = tuple(self._items) # Freeze the list by converting it to a tuple

def __getitem__(self, index):
return self._items[index]

def __len__(self):
return len(self._items)

def __repr__(self):
return f'FrozenList({self._items})'

def __contains__(self, item):
return id(item) in self._seen_ids


_ADVANCED_INDEXER_TEST_CASES = [
# ((16, 3, 6, 2), "::4, a, 1::2, b"),
# ((16, 3), "a, a"),
# ((16, 16), "::4, ::4"),
# ((16, 16), "1:14:2, 2:13:4"),
# ((16, 3), "a, :"),
((16, 3), ":, a"),
# ((16, 3), "a, ::4"),
# ((16, 3), "::4, a"),
# ((8, 8, 3), "::4, ::2, a"),
# ((8, 8, 3), "::4, a, ::2"),
# ((8, 8, 3, 7), "::4, b, ::2, ::2"),
# ((8, 8, 3, 7), "b, ::4, ::2, ::2"),
# ((8, 8, 3, 7), "b, ::4, a, ::2"),
# ((3, 8, 8, 7), "b, a, ::4, ::2"),
# ((8, 8, 3, 7), "::4, b, a, ::2"),
# ((8, 8, 3, 6), "b, ::4, a, c"),
# ((8, 6, 4), "a"),
# ((6, 8, 4), "c, ::3"),
# ((6, 8, 4), "c, c"),
# ((8, 6, 4), "::3, c"),
# ((6, 2), "d"),
# ((8, 6), "::4, d"),
]


class AdvancedIndexerOpsTest(PallasBaseTest):

def setUp(self):
self.a = jnp.array([1,1,1,1,1], dtype=jnp.int32)
self.b = jnp.array([1,2,2,2,2], dtype=jnp.int32)
self.c = jnp.array([1,0,2,2,-1,1], dtype=jnp.int32)
self.d = jnp.array([1,0,0,0,0,1], dtype=jnp.bool_)

super().setUp()

@parameterized.parameters(_ADVANCED_INDEXER_TEST_CASES)
def run_test(self, in_shape: tuple[int, ...], indexer_str: str):
a, b, c, d = self.a, self.b, self.c, self.d

x = jnp.arange(np.prod(in_shape), dtype=jnp.float32).reshape(in_shape)
indexer = eval(f"sb[{indexer_str}]")
y = x[indexer]

fancy_indexers = tuple(FrozenSet(i for i in indexer if isinstance(i, jnp.ndarray)))
is_fancy_indexers = [isinstance(i, jnp.ndarray) for i in indexer]
indexer_items = indexer_str.split(", ")
fancy_indexer_names = tuple({n: None for n, b in zip(indexer_items, is_fancy_indexers) if b})

func_str = f'''
def kernel(x_ref, {" ".join(f"{n}_ref," for n in fancy_indexer_names)} o_ref):
{"\n".join(f" {f'{n} = {n}_ref[...]'}" for n in fancy_indexer_names)}
w = x_ref[{indexer_str}]
o_ref[...] = w
'''
# print(func_str)

exec(func_str, globals())

y_ = self.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct(y.shape, jnp.float32),
)(x, *fancy_indexers)

np.testing.assert_array_equal(y_, y)


class AdvancedIndexerOpsInterpreterTest(AdvancedIndexerOpsTest):
INTERPRET = True


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 5d691cd

Please sign in to comment.