Separate out and add tests for Pallas indexing logic.
Breaking changes:
* This changes `pl.load`/`pl.store`'s broadcasting semantics. These functions will *no longer* automatically insert dummy axes for broadcasting on behalf of the user, though they will still broadcast indexers against each other. This brings the behavior in line with regular NumPy indexing; a before/after sketch follows.
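
For illustration (a minimal sketch based on the layer-norm kernel changes in this commit, not part of the commit message itself):

# Before: pl.load inserted dummy axes for broadcasting automatically.
a = pl.load(x_ref, (row_idx, col_idx), mask=mask, other=0.)

# After: axes are expanded explicitly, as in NumPy advanced indexing;
# row_idx becomes a column vector and col_idx a row vector.
a = pl.load(x_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.)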

PiperOrigin-RevId: 548767185
sharadmv authored and The jax_triton Authors committed Jul 17, 2023
1 parent 1966b57 commit 163e1d4
Showing 8 changed files with 402 additions and 147 deletions.
5 changes: 3 additions & 2 deletions jax_triton/pallas/__init__.py
@@ -14,6 +14,9 @@

"""Module for pallas, a jaxpr "dialect" for Triton."""
from jax_triton.pallas.core import BlockSpec
from jax_triton.pallas.indexing import ds
from jax_triton.pallas.indexing import dslice
from jax_triton.pallas.indexing import broadcast_to
from jax_triton.pallas.pallas_call import pallas_call
from jax_triton.pallas.pallas_call import pallas_call_p
from jax_triton.pallas.primitives import atomic_add
@@ -25,8 +28,6 @@
from jax_triton.pallas.primitives import atomic_xchg
from jax_triton.pallas.primitives import atomic_xor
from jax_triton.pallas.primitives import dot
from jax_triton.pallas.primitives import ds
from jax_triton.pallas.primitives import dslice
from jax_triton.pallas.primitives import load
from jax_triton.pallas.primitives import max_contiguous
from jax_triton.pallas.primitives import multiple_of
158 changes: 158 additions & 0 deletions jax_triton/pallas/indexing.py
@@ -0,0 +1,158 @@
# Copyright 2023 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains shared logic and abstractions for Pallas indexing ops."""

from __future__ import annotations

import dataclasses
from typing import Any, Optional, Tuple, Union

import jax
from jax import core as jax_core
from jax import tree_util
from jax._src.interpreters import mlir
from jax._src.util import merge_lists
from jax._src.util import partition_list
import jax.numpy as jnp
import numpy as np


# Currently, JAX doesn't have a primitive that does an equal-rank broadcast.
# We could use `jnp.broadcast_to` but that lowers to squeezing,
# then broadcast_in_dim. Triton has an equal-rank broadcast (`tl.broadcast_to`)
# so in the lowering, we have to expand out those squeezed dimensions again.
# Having a simple `broadcast_to` primitive allows us to lower directly
# to `tl.broadcast_to`.
broadcast_to_p = jax_core.Primitive('broadcast_to')

def broadcast_to(a: jax.Array, shape: Tuple[int, ...]) -> jax.Array:
  if a.shape == shape:
    return a
  return broadcast_to_p.bind(a, shape=shape)

@broadcast_to_p.def_impl
def _broadcast_to_impl(a, *, shape):
  return jnp.broadcast_to(a, shape)

@broadcast_to_p.def_abstract_eval
def _broadcast_to_abstract_eval(aval, *, shape):
  return jax_core.ShapedArray(shape, aval.dtype)

mlir.register_lowering(
    broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False)
)


@tree_util.register_pytree_node_class
@dataclasses.dataclass
class Slice:
  """Represents a slice with a dynamic start index and a fixed size."""
  start: Any
  size: int

  def tree_flatten(self):
    # If `start` is statically known, we treat it as static information
    if isinstance(self.start, int):
      return (), (True, self.start, self.size)
    return (self.start,), (False, self.size)

  @classmethod
  def tree_unflatten(cls, data, xs):
    is_static = data[0]
    if is_static:
      del xs
      start, size = data[1:]
      return Slice(start, size)
    start, = xs
    size = data[1]
    return Slice(start, size)

  @classmethod
  def from_slice(cls, slc: slice, size: int) -> Slice:
    start, stop = slc.start, slc.stop
    start = 0 if start is None else start
    stop = size if stop is None else stop
    return Slice(start, stop - start)


def dslice(start: Optional[Union[int, jax.Array]], size: Optional[int] = None
           ) -> Union[slice, Slice]:
  """Constructs a `Slice` from a start and a size."""
  if start is None:
    return slice(None)
  if size is None:
    if not isinstance(start, int):
      raise ValueError("Non-static `dslice`")
    return Slice(0, start)
  return Slice(start, size)
ds = dslice  # Handy alias

@tree_util.register_pytree_node_class
@dataclasses.dataclass
class NDIndexer:
  indices: Tuple[Union[int, Slice, jax.Array], ...]
  shape: Tuple[int, ...]
  int_indexer_shape: Tuple[int, ...]

  def __post_init__(self):
    if len(self.indices) != len(self.shape):
      raise ValueError("`indices` must be the same length as `Ref` shape.")

  def tree_flatten(self):
    indexed_dims = [not isinstance(idx, slice) for idx in self.indices]
    slice_idx, non_slice_idx = partition_list(indexed_dims, self.indices)
    flat_idx, idx_tree = tree_util.tree_flatten(non_slice_idx)
    return flat_idx, (slice_idx, idx_tree, indexed_dims, self.shape,
                      self.int_indexer_shape)

  @classmethod
  def tree_unflatten(cls, data, flat_idx):
    slice_idx, idx_tree, indexed_dims, shape, int_indexer_shape = data
    non_slice_idx = tree_util.tree_unflatten(idx_tree, flat_idx)
    indices = merge_lists(indexed_dims, slice_idx, non_slice_idx)
    return NDIndexer(tuple(indices), shape, int_indexer_shape)

  @classmethod
  def from_indices_shape(cls, indices, shape) -> NDIndexer:
    if len(indices) > len(shape):
      raise ValueError("`indices` must be no longer than `shape`.")
    # Pad out indices with slice(None)
    indices = [*indices, *[slice(None)] * (len(shape) - len(indices))]
    # Convert all `slice`s to `Slice`s
    indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice)
                    else i for i, s in zip(indices, shape))
    is_int_indexing = [not isinstance(i, Slice) for i in indices]
    other_indexers, int_indexers = partition_list(is_int_indexing, indices)
    int_indexers = [np.array(i, np.int32) if isinstance(i, int) else i
                    for i in int_indexers]
    indexer_shapes = [i.shape for i in int_indexers]
    if indexer_shapes:
      try:
        bcast_shape = np.broadcast_shapes(*indexer_shapes)
      except ValueError as e:
        # Raise a nicer error than the NumPy one.
        raise ValueError("Cannot broadcast shapes for indexing: "
                         f"{tuple(indexer_shapes)}") from e
    else:
      bcast_shape = ()
    int_indexers = [broadcast_to(i, bcast_shape) for i in int_indexers]
    indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
    return NDIndexer(tuple(indices), shape, bcast_shape)

  def get_indexer_shape(self) -> Tuple[int, ...]:
    is_int_indexing = [not isinstance(i, Slice) for i in self.indices]
    other_indexers, _ = partition_list(is_int_indexing, self.indices)
    other_shape = [s.size for s in other_indexers]
    return tuple((*self.int_indexer_shape, *other_shape))
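
As a quick usage sketch of the new helpers (illustrative only, not part of the committed file; it assumes the module is imported as `jax_triton.pallas.indexing`):

import numpy as np
from jax_triton.pallas import indexing

# Mixed indexing into a Ref of shape (8, 128): an integer-array index for
# the rows and a dynamic slice for the columns.
rows = np.arange(4, dtype=np.int32)  # shape (4,)
idx = indexing.NDIndexer.from_indices_shape(
    (rows, indexing.ds(0, 64)), (8, 128))

# Integer-indexer dimensions come first in the result shape, followed by
# the slice dimensions.
assert idx.int_indexer_shape == (4,)
assert idx.get_indexer_shape() == (4, 64)

# Omitted trailing dimensions are padded with slice(None), i.e. Slice(0, 128).
idx2 = indexing.NDIndexer.from_indices_shape((rows,), (8, 128))
assert idx2.get_indexer_shape() == (4, 128)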
10 changes: 7 additions & 3 deletions jax_triton/pallas/ops/layer_norm.py
@@ -162,8 +162,12 @@ def body(i, acc_ref):
    row_idx = i * block_m + jnp.arange(block_m)
    row_mask = row_idx < m
    mask = row_mask[:, None] & col_mask[None, :]
    a = pl.load(x_ref, (row_idx, col_idx), mask=mask, other=0.).astype(jnp.float32)
    dout = pl.load(do_ref, (row_idx, col_idx), mask=mask, other=0.).astype(jnp.float32)
    a = pl.load(
        x_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0
    ).astype(jnp.float32)
    dout = pl.load(
        do_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0
    ).astype(jnp.float32)
    mean = pl.load(mean_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32)
    rstd = pl.load(rstd_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32)
    a_hat = (a - mean[:, None]) * rstd[:, None]
@@ -275,4 +279,4 @@ def layer_norm_reference(x, weight, bias, *, eps: float = 1e-5):
  var = jnp.maximum(0., mean2 - jnp.square(mean))
  y = x - mean[:, None]
  mul = lax.rsqrt(var + eps)
  return y * mul[:, None] * weight[None] + bias[None]
  return y * mul[:, None] * weight[None] + bias[None]
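
The explicit `[:, None]`/`[None]` expansion above mirrors NumPy advanced indexing; a small NumPy-only sketch of the two patterns (illustrative, unrelated to the committed file):

import numpy as np

x = np.arange(12).reshape(3, 4)
row_idx = np.array([0, 2])
col_idx = np.array([1, 3])

# Paired indexing: indexers broadcast against each other, giving a
# shape-(2,) result containing x[0, 1] and x[2, 3].
paired = x[row_idx, col_idx]

# Outer-product indexing via explicit axes: (2, 1) and (1, 2) broadcast to
# a (2, 2) result -- the submatrix at rows {0, 2} and columns {1, 3}. This
# is the pattern the updated kernels use with pl.load.
outer = x[row_idx[:, None], col_idx[None]]
assert outer.shape == (2, 2)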
8 changes: 6 additions & 2 deletions jax_triton/pallas/ops/rms_norm.py
@@ -145,8 +145,12 @@ def body(i, acc_ref):
    row_idx = i * block_m + jnp.arange(block_m)
    row_mask = row_idx < m
    mask = row_mask[:, None] & col_mask[None, :]
    a = pl.load(x_ref, (row_idx, col_idx), mask=mask, other=0.).astype(jnp.float32)
    dout = pl.load(do_ref, (row_idx, col_idx), mask=mask, other=0.).astype(jnp.float32)
    a = pl.load(
        x_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0
    ).astype(jnp.float32)
    dout = pl.load(
        do_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0
    ).astype(jnp.float32)
    rstd = pl.load(rstd_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32)
    a_hat = a * rstd[:, None]
    dw_acc_ref, db_acc_ref = acc_ref
123 changes: 9 additions & 114 deletions jax_triton/pallas/primitives.py
@@ -14,11 +14,10 @@

"""Module for pallas-specific JAX primitives and functions."""
from __future__ import annotations
import dataclasses
import enum
import functools

from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional, Tuple

import jax
from jax import lax
@@ -27,40 +26,24 @@
from jax._src import core as jax_core
from jax._src import pretty_printer as pp
from jax._src import state
from jax._src.util import (safe_map, safe_zip, split_list, merge_lists,
                           partition_list)
from jax._src.util import (safe_map, safe_zip)
from jax._src.state import primitives as state_primitives
from jax._src.state import discharge as state_discharge
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import xla
import jax.numpy as jnp
import numpy as np

from jax_triton.pallas import core as pallas_core
from jax_triton.pallas import indexing

partial = functools.partial
Slice = indexing.Slice
NDIndexer = indexing.NDIndexer

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

def _process_idx(idx, ref_shape):
  if any(isinstance(i, slice) and i != slice(None) for i in idx):
    raise NotImplementedError("Non-`slice(None)` slices not supported yet.")
  if len(idx) != len(ref_shape):
    raise ValueError("Must provide indexer for each dimension of `Ref`.")
  is_int_indexing = [isinstance(i, (jnp.ndarray, int)) for i in idx]
  other_indexers, int_indexers = partition_list(is_int_indexing, idx)
  int_indexers = [np.array(i, np.int32) if isinstance(i, int) else i for i in
                  int_indexers]
  indexer_shapes = [jnp.shape(i) for i in int_indexers]
  bcast_shape = tuple(s for i in indexer_shapes for s in i)
  idx_iter = iter(range(len(bcast_shape)))
  int_indexers = [
      lax.broadcast_in_dim(i, bcast_shape, tuple(next(idx_iter) for _ in
                                                 range(len(i.shape))))
      for i in int_indexers
  ]
  return merge_lists(is_int_indexing, other_indexers, int_indexers)

program_id_p = jax_core.Primitive("program_id")

def program_id(axis):
@@ -213,94 +196,6 @@ def _multiple_of_abstract_eval(aval, **_):
  return aval
multiple_of_p.def_abstract_eval(_multiple_of_abstract_eval)

@tree_util.register_pytree_node_class
@dataclasses.dataclass
class Slice:
  start: Any
  size: int

  def tree_flatten(self):
    if isinstance(self.start, int):
      return (), (True, self.start, self.size)
    return (self.start,), (False, self.size)

  @classmethod
  def tree_unflatten(cls, data, xs):
    if data[0]:
      return Slice(data[1], data[2])
    return Slice(xs[0], data[1])

  @classmethod
  def from_slice(cls, slc: slice, size: int) -> Slice:
    start, stop = slc.start, slc.stop
    start = 0 if start is None else start
    stop = size if stop is None else stop
    return Slice(start, stop - start)

def dslice(start: Optional[Union[int, jax.Array]], stop: Optional[int] = None):
  if start is None:
    return slice(None)
  if stop is None:
    if not isinstance(start, int):
      raise ValueError("Non-static `dslice`")
    return Slice(0, start)
  return Slice(start, stop)
ds = dslice  # Handy alias

@tree_util.register_pytree_node_class
@dataclasses.dataclass
class NDIndexer:
  indices: Tuple[Union[int, Slice, jax.Array]]
  shape: Tuple[int, ...]
  int_indexer_shape: Tuple[int, ...]

  def __post_init__(self):
    if len(self.indices) != len(self.shape):
      raise ValueError("`indices` must be the same length as `Ref` shape.")

  def tree_flatten(self):
    indexed_dims = [not isinstance(idx, slice) for idx in self.indices]
    slice_idx, non_slice_idx = partition_list(indexed_dims, self.indices)
    flat_idx, idx_tree = tree_util.tree_flatten(non_slice_idx)
    return flat_idx, (slice_idx, idx_tree, indexed_dims, self.shape,
                      self.int_indexer_shape)

  @classmethod
  def tree_unflatten(cls, data, flat_idx):
    slice_idx, idx_tree, indexed_dims, shape, int_indexer_shape = data
    non_slice_idx = tree_util.tree_unflatten(idx_tree, flat_idx)
    indices = merge_lists(indexed_dims, slice_idx, non_slice_idx)
    return NDIndexer(tuple(indices), shape, int_indexer_shape)

  @classmethod
  def from_indices_shape(cls, indices, shape) -> NDIndexer:
    indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice)
                    else i for i, s in zip(indices, shape))
    if any(isinstance(i, slice) and i != slice(None) for i in indices):
      raise NotImplementedError("Non-`slice(None)` slices not supported yet.")
    if len(indices) != len(shape):
      raise ValueError("Must provide indexer for each dimension of `Ref`.")
    is_int_indexing = [isinstance(i, (jax.Array, int)) for i in indices]
    other_indexers, int_indexers = partition_list(is_int_indexing, indices)
    int_indexers = [np.array(i, np.int32) if isinstance(i, int) else i for i in
                    int_indexers]
    indexer_shapes = [i.shape for i in int_indexers]
    bcast_shape = tuple(s for i in indexer_shapes for s in i)
    idx_iter = iter(range(len(bcast_shape)))
    int_indexers = [
        lax.broadcast_in_dim(i, bcast_shape, tuple(next(idx_iter) for _ in
                                                   range(len(i.shape))))
        for i in int_indexers
    ]
    indices = merge_lists(is_int_indexing, other_indexers, int_indexers)
    return NDIndexer(tuple(indices), shape, bcast_shape)

  def get_indexer_shape(self) -> Tuple[int, ...]:
    is_int_indexing = [not isinstance(i, Slice) for i in self.indices]
    other_indexers, _ = partition_list(is_int_indexing, self.indices)
    other_shape = [s.size for s in other_indexers]
    return tuple((*self.int_indexer_shape, *other_shape))

load_p = jax_core.Primitive('masked_load')

def _load_abstract_eval(ref_aval, *all_avals, args_tree,
@@ -373,7 +268,7 @@ def _load_jvp(primals, tangents, *, args_tree, masked, **params: Any):
def _load_discharge_rule(in_avals, out_avals, ref, *args, args_tree,
                         masked, eviction_policy, cache_modifier, is_volatile):
  idx, *masked_other = tree_util.tree_unflatten(args_tree, args)
  if all(isinstance(s, Slice) or s.shape == () for s in idx.indices):
  if all(isinstance(s, Slice) or not s.shape for s in idx.indices):
    indices = idx.indices
    scalar_dims = [not isinstance(s, Slice) and s.shape == () for s in indices]
    slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
@@ -509,4 +404,4 @@ def dot(a, b, trans_a: bool = False, trans_b: bool = False,
  return jax.lax.dot_general(
      a, b, dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())),
      precision=precision,
      preferred_element_type=None).astype(jnp.float32)
      preferred_element_type=None).astype(jnp.float32)
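
The semantic difference between the removed code above and the new `indexing.py` is concentrated in how `from_indices_shape` combines integer-indexer shapes: the old version concatenated them, implicitly giving each indexer its own dummy axes via `broadcast_in_dim`, while the new version broadcasts them with `np.broadcast_shapes`. A small contrast, using hypothetical shapes for illustration:

import numpy as np

indexer_shapes = [(4,), (8,)]

# Old rule: concatenate the shapes, so two 1-D indexers of sizes 4 and 8
# silently produce a (4, 8) result.
old_bcast_shape = tuple(s for shp in indexer_shapes for s in shp)  # (4, 8)

# New rule: the shapes must broadcast; (4,) vs. (8,) now raises, and the
# user writes rows[:, None] and cols[None] to get (4, 1) and (1, 8).
new_bcast_shape = np.broadcast_shapes((4, 1), (1, 8))  # (4, 8)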