Separate out and add tests for Pallas indexing logic.
Breaking changes:
* This changes `pl.load`/`pl.store`'s broadcasting semantics. These functions will *no longer* automatically insert dummy axes for broadcasting on behalf of the user, though they will still broadcast values against each other. This brings the behavior more in line with regular NumPy indexing (sketched below).
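
For illustration, a minimal sketch of the new calling convention, assuming hypothetical block sizes and mirroring the updated layer_norm/rms_norm kernels in this diff (`kernel`, `block_m`, and `block_n` are illustrative names, not part of the commit):

import jax
import jax.numpy as jnp
import jax_triton.pallas as pl

block_m, block_n = 8, 16  # assumed block sizes

def kernel(x_ref, o_ref):
  row_idx = jnp.arange(block_m)  # shape (block_m,)
  col_idx = jnp.arange(block_n)  # shape (block_n,)
  # Before this change, dummy axes were spelled out explicitly:
  #   a = pl.load(x_ref, (row_idx[:, None], col_idx[None]))
  # Now the 1-D indexers are passed directly and combined into a single
  # (block_m, block_n) result shape.
  a = pl.load(x_ref, (row_idx, col_idx))
  pl.store(o_ref, (row_idx, col_idx), a * 2)

x = jnp.ones((block_m, block_n))
out = pl.pallas_call(kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)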

PiperOrigin-RevId: 548826315
sharadmv authored and The jax_triton Authors committed Jul 17, 2023
1 parent 163e1d4 commit 695afcf
Showing 8 changed files with 147 additions and 402 deletions.
5 changes: 2 additions & 3 deletions jax_triton/pallas/__init__.py
@@ -14,9 +14,6 @@

"""Module for pallas, a jaxpr "dialect" for Triton."""
from jax_triton.pallas.core import BlockSpec
from jax_triton.pallas.indexing import ds
from jax_triton.pallas.indexing import dslice
from jax_triton.pallas.indexing import broadcast_to
from jax_triton.pallas.pallas_call import pallas_call
from jax_triton.pallas.pallas_call import pallas_call_p
from jax_triton.pallas.primitives import atomic_add
@@ -28,6 +25,8 @@
from jax_triton.pallas.primitives import atomic_xchg
from jax_triton.pallas.primitives import atomic_xor
from jax_triton.pallas.primitives import dot
from jax_triton.pallas.primitives import ds
from jax_triton.pallas.primitives import dslice
from jax_triton.pallas.primitives import load
from jax_triton.pallas.primitives import max_contiguous
from jax_triton.pallas.primitives import multiple_of
158 changes: 0 additions & 158 deletions jax_triton/pallas/indexing.py

This file was deleted.

10 changes: 3 additions & 7 deletions jax_triton/pallas/ops/layer_norm.py
@@ -162,12 +162,8 @@ def body(i, acc_ref):
row_idx = i * block_m + jnp.arange(block_m)
row_mask = row_idx < m
mask = row_mask[:, None] & col_mask[None, :]
a = pl.load(
x_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0
).astype(jnp.float32)
dout = pl.load(
do_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0
).astype(jnp.float32)
a = pl.load(x_ref, (row_idx, col_idx), mask=mask, other=0.).astype(jnp.float32)
dout = pl.load(do_ref, (row_idx, col_idx), mask=mask, other=0.).astype(jnp.float32)
mean = pl.load(mean_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32)
rstd = pl.load(rstd_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32)
a_hat = (a - mean[:, None]) * rstd[:, None]
@@ -279,4 +275,4 @@ def layer_norm_reference(x, weight, bias, *, eps: float = 1e-5):
var = jnp.maximum(0., mean2 - jnp.square(mean))
y = x - mean[:, None]
mul = lax.rsqrt(var + eps)
return y * mul[:, None] * weight[None] + bias[None]
8 changes: 2 additions & 6 deletions jax_triton/pallas/ops/rms_norm.py
@@ -145,12 +145,8 @@ def body(i, acc_ref):
row_idx = i * block_m + jnp.arange(block_m)
row_mask = row_idx < m
mask = row_mask[:, None] & col_mask[None, :]
a = pl.load(
x_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0
).astype(jnp.float32)
dout = pl.load(
do_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0
).astype(jnp.float32)
a = pl.load(x_ref, (row_idx, col_idx), mask=mask, other=0.).astype(jnp.float32)
dout = pl.load(do_ref, (row_idx, col_idx), mask=mask, other=0.).astype(jnp.float32)
rstd = pl.load(rstd_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32)
a_hat = a * rstd[:, None]
dw_acc_ref, db_acc_ref = acc_ref
123 changes: 114 additions & 9 deletions jax_triton/pallas/primitives.py
@@ -14,10 +14,11 @@

"""Module for pallas-specific JAX primitives and functions."""
from __future__ import annotations
import dataclasses
import enum
import functools

from typing import Any, Optional, Tuple
from typing import Any, List, Optional, Tuple, Union

import jax
from jax import lax
@@ -26,24 +27,40 @@
from jax._src import core as jax_core
from jax._src import pretty_printer as pp
from jax._src import state
from jax._src.util import (safe_map, safe_zip)
from jax._src.util import (safe_map, safe_zip, split_list, merge_lists,
partition_list)
from jax._src.state import primitives as state_primitives
from jax._src.state import discharge as state_discharge
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import xla
import jax.numpy as jnp
import numpy as np

from jax_triton.pallas import core as pallas_core
from jax_triton.pallas import indexing

partial = functools.partial
Slice = indexing.Slice
NDIndexer = indexing.NDIndexer

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

def _process_idx(idx, ref_shape):
if any(isinstance(i, slice) and i != slice(None) for i in idx):
raise NotImplementedError("Non-`slice(None)` slices not supported yet.")
if len(idx) != len(ref_shape):
raise ValueError("Must provide indexer for each dimension of `Ref`.")
is_int_indexing = [isinstance(i, (jnp.ndarray, int)) for i in idx]
other_indexers, int_indexers = partition_list(is_int_indexing, idx)
int_indexers = [np.array(i, np.int32) if isinstance(i, int) else i for i in
int_indexers]
indexer_shapes = [jnp.shape(i) for i in int_indexers]
bcast_shape = tuple(s for i in indexer_shapes for s in i)
idx_iter = iter(range(len(bcast_shape)))
int_indexers = [
lax.broadcast_in_dim(i, bcast_shape, tuple(next(idx_iter) for _ in
range(len(i.shape))))
for i in int_indexers
]
return merge_lists(is_int_indexing, other_indexers, int_indexers)
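
To make the broadcast step above concrete: with two integer indexers of shapes (3,) and (4,), `bcast_shape` is (3, 4) and each indexer is placed along its own run of consecutive dimensions, producing an outer-product indexing pattern. A small assumed example:

import jax.numpy as jnp
from jax import lax

i0, i1 = jnp.arange(3), jnp.arange(4)
b0 = lax.broadcast_in_dim(i0, (3, 4), (0,))  # values vary along axis 0
b1 = lax.broadcast_in_dim(i1, (3, 4), (1,))  # values vary along axis 1
# (b0, b1) together address the same positions as x[i0[:, None], i1[None, :]]
# would in NumPy.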

program_id_p = jax_core.Primitive("program_id")

def program_id(axis):
@@ -196,6 +213,94 @@ def _multiple_of_abstract_eval(aval, **_):
return aval
multiple_of_p.def_abstract_eval(_multiple_of_abstract_eval)

@tree_util.register_pytree_node_class
@dataclasses.dataclass
class Slice:
start: Any
size: int

def tree_flatten(self):
if isinstance(self.start, int):
return (), (True, self.start, self.size)
return (self.start,), (False, self.size)

@classmethod
def tree_unflatten(cls, data, xs):
if data[0]:
return Slice(data[1], data[2])
return Slice(xs[0], data[1])

@classmethod
def from_slice(cls, slc: slice, size: int) -> Slice:
start, stop = slc.start, slc.stop
start = 0 if start is None else start
stop = size if stop is None else stop
return Slice(start, stop - start)

def dslice(start: Optional[Union[int, jax.Array]], stop: Optional[int] = None):
if start is None:
return slice(None)
if stop is None:
if not isinstance(start, int):
raise ValueError("Non-static `dslice`")
return Slice(0, start)
return Slice(start, stop)
ds = dslice # Handy alias
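
As a usage sketch, based only on the definitions above, `dslice` normalizes its arguments into either a `Slice` or a full `slice(None)`. Note that although the second parameter is named `stop`, it is passed through as the `Slice`'s `size`:

dslice(8)      # -> Slice(start=0, size=8): static size, start at 0
dslice(2, 8)   # -> Slice(start=2, size=8): explicit start plus size
dslice(None)   # -> slice(None): select the whole dimension
ds(2, 8)       # `ds` is just an alias for `dslice`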

@tree_util.register_pytree_node_class
@dataclasses.dataclass
class NDIndexer:
indices: Tuple[Union[int, Slice, jax.Array]]
shape: Tuple[int, ...]
int_indexer_shape: Tuple[int, ...]

def __post_init__(self):
if len(self.indices) != len(self.shape):
raise ValueError("`indices` must be the same length as `Ref` shape.")

def tree_flatten(self):
indexed_dims = [not isinstance(idx, slice) for idx in self.indices]
slice_idx, non_slice_idx = partition_list(indexed_dims, self.indices)
flat_idx, idx_tree = tree_util.tree_flatten(non_slice_idx)
return flat_idx, (slice_idx, idx_tree, indexed_dims, self.shape,
self.int_indexer_shape)

@classmethod
def tree_unflatten(cls, data, flat_idx):
slice_idx, idx_tree, indexed_dims, shape, int_indexer_shape = data
non_slice_idx = tree_util.tree_unflatten(idx_tree, flat_idx)
indices = merge_lists(indexed_dims, slice_idx, non_slice_idx)
return NDIndexer(tuple(indices), shape, int_indexer_shape)

@classmethod
def from_indices_shape(cls, indices, shape) -> NDIndexer:
indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice)
else i for i, s in zip(indices, shape))
if any(isinstance(i, slice) and i != slice(None) for i in indices):
raise NotImplementedError("Non-`slice(None)` slices not supported yet.")
if len(indices) != len(shape):
raise ValueError("Must provide indexer for each dimension of `Ref`.")
is_int_indexing = [isinstance(i, (jax.Array, int)) for i in indices]
other_indexers, int_indexers = partition_list(is_int_indexing, indices)
int_indexers = [np.array(i, np.int32) if isinstance(i, int) else i for i in
int_indexers]
indexer_shapes = [i.shape for i in int_indexers]
bcast_shape = tuple(s for i in indexer_shapes for s in i)
idx_iter = iter(range(len(bcast_shape)))
int_indexers = [
lax.broadcast_in_dim(i, bcast_shape, tuple(next(idx_iter) for _ in
range(len(i.shape))))
for i in int_indexers
]
indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
return NDIndexer(tuple(indices), shape, bcast_shape)

def get_indexer_shape(self) -> Tuple[int, ...]:
is_int_indexing = [not isinstance(i, Slice) for i in self.indices]
other_indexers, _ = partition_list(is_int_indexing, self.indices)
other_shape = [s.size for s in other_indexers]
return tuple((*self.int_indexer_shape, *other_shape))
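
A hedged sketch of how these pieces compose, with shapes assumed for illustration: integer-indexer shapes are concatenated into `int_indexer_shape`, which then prefixes the sizes of any `Slice` dimensions in the overall indexer shape.

import jax.numpy as jnp

idx = NDIndexer.from_indices_shape((jnp.arange(4), dslice(0, 8)), (16, 128))
idx.int_indexer_shape    # (4,): concatenated integer-indexer shapes
idx.get_indexer_shape()  # (4, 8): integer part first, then slice sizes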

load_p = jax_core.Primitive('masked_load')

def _load_abstract_eval(ref_aval, *all_avals, args_tree,
@@ -268,7 +373,7 @@ def _load_jvp(primals, tangents, *, args_tree, masked, **params: Any):
def _load_discharge_rule(in_avals, out_avals, ref, *args, args_tree,
masked, eviction_policy, cache_modifier, is_volatile):
idx, *masked_other = tree_util.tree_unflatten(args_tree, args)
if all(isinstance(s, Slice) or not s.shape for s in idx.indices):
if all(isinstance(s, Slice) or s.shape == () for s in idx.indices):
indices = idx.indices
scalar_dims = [not isinstance(s, Slice) and s.shape == () for s in indices]
slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
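
For context, a hypothetical sketch (not the committed code) of what this branch of the discharge rule computes: when every indexer is a `Slice` or a scalar, the masked load can be rewritten as a plain `lax.dynamic_slice`, with the scalar dimensions squeezed away afterwards.

from jax import lax

def discharge_load_sketch(x, indices):
  # Assumes every entry of `indices` is a Slice or a scalar.
  starts = [s.start if isinstance(s, Slice) else s for s in indices]
  sizes = [s.size if isinstance(s, Slice) else 1 for s in indices]
  out = lax.dynamic_slice(x, starts, sizes)
  scalar_dims = [d for d, s in enumerate(indices) if not isinstance(s, Slice)]
  return lax.squeeze(out, scalar_dims)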
@@ -404,4 +509,4 @@ def dot(a, b, trans_a: bool = False, trans_b: bool = False,
return jax.lax.dot_general(
a, b, dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())),
precision=precision,
preferred_element_type=None).astype(jnp.float32)