Skip to content

Commit

Permalink
[Pallas/MGPU] Undo transforms on refs before giving them back to the …
Browse files Browse the repository at this point in the history
…users

This changes makes it so that the refs users receive inside their kernels have shapes
matching their block specs. However, the refs are not actually plain refs, but transformed
references that begin with the fully transformed abstract ref and then stack the inverse
of the transformation stack on top of it. This means that all primitives that take in refs
can also see the sequence of transforms the user applied in the block spec, which lets us
verify e.g. that the inputs to WGMMA are correctly tiled, even though their user-visible
shape remains 2D. We should be able to use the same trick in the future to propagate tiling
and better infer the layouts for loads and stores.

PiperOrigin-RevId: 679570949
  • Loading branch information
apaszke authored and Google-ML-Automation committed Sep 27, 2024
1 parent 26632fd commit ff0fa78
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 64 deletions.
16 changes: 9 additions & 7 deletions jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.state import discharge as state_discharge
from jax._src.state.types import TransformedRef
import jax.numpy as jnp


Expand Down Expand Up @@ -523,7 +524,7 @@ def __repr__(self):
class MemoryRefTransform(Protocol):
"""Transforms a memory reference on load or store."""

def __call__(self, block_aval: AbstractMemoryRef) -> AbstractMemoryRef:
def undo(self, ref: TransformedRef) -> TransformedRef:
raise NotImplementedError("Abstract evaluation not implemented.")


Expand All @@ -533,6 +534,7 @@ class BlockMapping:
See the `check_invariants` method for precise specification.
"""
# TODO(apaszke): Why do we have block_shape and block_aval? Can't mappping be a transform (it's just slicing after all)?
block_shape: tuple[Mapped | int, ...]
block_aval: AbstractMemoryRef # The block ref aval
index_map_jaxpr: jax_core.ClosedJaxpr
Expand All @@ -546,8 +548,8 @@ def check_invariants(self) -> None:
if not config.enable_checks.value: return

unmapped_block_shape = tuple(s for s in self.block_shape if s is not mapped)
assert unmapped_block_shape == self.block_aval.shape, (
self.block_shape, self.block_aval)
assert unmapped_block_shape == self.ref_aval.shape, (
self.block_shape, self.ref_aval.shape)
assert len(self.block_shape) == len(self.array_shape_dtype.shape), (
self.block_shape, self.array_shape_dtype
)
Expand All @@ -568,12 +570,12 @@ def replace(self, **kwargs):
return new_self

@property
def ref_aval(self) -> AbstractMemoryRef:
def ref_aval(self) -> AbstractMemoryRef | TransformedRef:
"""Returns the abstract value of the Ref after transformations."""
block_aval = self.block_aval
ref = TransformedRef(self.block_aval, ())
for transform in self.transforms:
block_aval = transform(block_aval)
return block_aval
ref = transform.undo(ref)
return ref if ref.transforms else ref.ref

def compute_start_indices_interpret(self, loop_idx, *args):
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
Expand Down
103 changes: 72 additions & 31 deletions jax/_src/pallas/mosaic_gpu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import tree_util
from jax._src.state.types import Transform
from jax._src.pallas import core as pallas_core
import jax.experimental.mosaic.gpu as mgpu
import jax.numpy as jnp
Expand Down Expand Up @@ -67,6 +68,11 @@ class MemoryRefTransform(pallas_core.MemoryRefTransform, Protocol):
def to_gpu_transform(self) -> mgpu.MemRefTransform:
...

def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray:
return aval.update(
shape=self.to_gpu_transform().transform_shape(aval.shape)
)


@dataclasses.dataclass(frozen=True)
class TilingTransform(MemoryRefTransform):
Expand All @@ -79,52 +85,84 @@ class TilingTransform(MemoryRefTransform):

tiling: tuple[int, ...]

def __call__(
self, block_aval: pallas_core.AbstractMemoryRef
) -> pallas_core.AbstractMemoryRef:
block_shape = block_aval.shape
old_tiled_dims = block_shape[-len(self.tiling) :]
num_tiles = tuple(
block_dim // tiling_dim
for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling)
)
rem = (
block_dim % tiling_dim
for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling)
)
if any(rem):
raise ValueError(
f"Block shape {block_shape} is not divisible by tiling {self.tiling}"
)
new_block_shape = block_shape[: -len(self.tiling)] + num_tiles + self.tiling
return block_aval.update(
inner_aval=block_aval.inner_aval.update(shape=new_block_shape)
)
def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
return dataclasses.replace(ref, transforms=ref.transforms + (UntileRef(self.tiling),))

def to_gpu_transform(self) -> mgpu.MemRefTransform:
return mgpu.TileTransform(self.tiling)


@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class UntileRef(Transform):

tiling: tuple[int, ...]

def transform_shape(self, shape):
if shape is None:
return None
assert shape[-len(self.tiling) :] == self.tiling
shape = shape[: -len(self.tiling)] # Drop tiling
return shape[: -len(self.tiling)] + tuple(
block_dim * tiling_dim
for block_dim, tiling_dim in zip(shape[-len(self.tiling) :], self.tiling)
)

def transform_dtype(self, dtype):
return dtype

def tree_flatten(self):
return (), (self.tiling,)

@classmethod
def tree_unflatten(cls, metadata, arrays):
assert not arrays
return cls(*metadata)


@dataclasses.dataclass(frozen=True)
class TransposeTransform(MemoryRefTransform):
"""Transpose a tiled memref."""

permutation: tuple[int, ...]

def __call__(
self, block_aval: pallas_core.AbstractMemoryRef
) -> pallas_core.AbstractMemoryRef:
shape = block_aval.shape # pytype: disable=attribute-error
return block_aval.update(
inner_aval=block_aval.inner_aval.update(
shape=self.to_gpu_transform().transform_shape(shape)
)
def __post_init__(self):
if set(self.permutation) != set(range(len(self.permutation))):
raise ValueError(f"Permutation {self.permutation} is not a permutation.")

def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
inverse = [None] * len(self.permutation)
for i, p in enumerate(self.permutation):
inverse[p] = i
return dataclasses.replace(
ref, transforms=ref.transforms + (TransposeRef(tuple(inverse)),)
)

def to_gpu_transform(self) -> mgpu.MemRefTransform:
return mgpu.TransposeTransform(self.permutation)


@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class TransposeRef(Transform):
permutation: tuple[int, ...]

def transform_shape(self, shape):
if shape is None:
return None
return tuple(shape[i] for i in self.permutation)

def transform_dtype(self, dtype):
return dtype

def tree_flatten(self):
return (), (self.permutation,)

@classmethod
def tree_unflatten(cls, metadata, arrays):
assert not arrays
return cls(*metadata)


@dataclasses.dataclass(frozen=True)
class GPUBlockMapping(pallas_core.BlockMapping):
swizzle: int | None = None
Expand Down Expand Up @@ -156,9 +194,12 @@ def to_block_mapping(
transforms = self.transforms
if not isinstance(transforms, tuple):
transforms = (transforms,)
block_inner_aval = bm.block_aval.inner_aval
for t in reversed(transforms):
block_inner_aval = t(block_inner_aval)
return GPUBlockMapping(
block_shape=bm.block_shape,
block_aval=bm.block_aval,
block_aval=bm.block_aval.update(inner_aval=block_inner_aval),
origin=bm.origin,
index_map_jaxpr=bm.index_map_jaxpr,
index_map_src_info=bm.index_map_src_info,
Expand Down
8 changes: 7 additions & 1 deletion jax/_src/pallas/mosaic_gpu/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,12 @@ def lower_jaxpr_to_module(
in_block_mappings, out_block_mappings = util.split_list(
block_mappings, [grid_mapping.num_inputs]
)
# TODO(apaszke): We can shrink allocation if max_concurrent_steps is more than the actual number of steps.
# We allocate the fully transformed shapes here. All primitives have seen the
# inverse transformation stack and will understand how to handle it.
in_structs_smem = [
jax.ShapeDtypeStruct(
[max_concurrent_steps, *bm.ref_aval.shape], bm.ref_aval.dtype
[max_concurrent_steps, *bm.block_aval.shape], bm.block_aval.dtype
)
if in_smem
else None
Expand All @@ -317,6 +320,9 @@ def lower_jaxpr_to_module(
)
out_structs_gmem = [*grid_mapping.out_shapes]
# TODO(justinfu): Implement output Memref transforms
for bm in block_mappings[grid_mapping.num_inputs :]:
if bm.transforms:
raise NotImplementedError("Output transforms are not supported")
out_structs_smem = [
jax.ShapeDtypeStruct([max_concurrent_steps, *bm.block_shape], s.dtype)
if in_smem
Expand Down
45 changes: 34 additions & 11 deletions jax/_src/pallas/mosaic_gpu/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,40 @@ def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128):
if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef):
raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc}")

ma, ka, tma, tka = a.shape
kb, nb, tkb, tnb = b.shape
mc, nc = acc.shape

if rhs_transpose:
kb, nb, tkb, tnb = nb, kb, tnb, tkb

if tma * ma != mc or nb * tnb != nc or ka != kb or tka != tkb:
raise ValueError(f"Incompatible shapes: {a.shape=}, {b.shape=}, {acc.shape=}, {rhs_transpose=}")

return wgmma_ref_p.bind(acc, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose)
# TODO(apaszke): Make swizzling another transform and read it from the refs.
if not isinstance(a, pallas_core.TransformedRef):
raise ValueError("WGMMA inputs must be tiled references.")

m, n = acc.shape
m2, k = a.shape
k2, n2 = b.shape

if m != m2 or n != n2 or k != k2:
raise ValueError(
f"Incompatible shapes for matrix multiplication: lhs={a.shape},"
f" rhs={b.shape=}, acc={acc.shape}"
)

if (dtype := a.dtype) != b.dtype:
raise ValueError(f"Mixed input dtypes for matrix multiplication unsupported: lhs={a.dtype}, rhs={b.dtype}")
if not isinstance(a, pallas_core.TransformedRef):
raise ValueError("WGMMA lhs must be a tiled reference.")
if not isinstance(b, pallas_core.TransformedRef):
raise ValueError("WGMMA rhs must be a tiled reference.")

elems_128b = swizzle // dtype.itemsize
if a.transforms != (gpu_core.UntileRef((64, elems_128b)),):
raise ValueError(
f"WGMMA lhs must be tiled with 64x{elems_128b} tiles for element type"
f" {dtype}."
)
if b.transforms != (gpu_core.UntileRef((elems_128b, elems_128b)),):
raise ValueError(
f"WGMMA lhs must be tiled with {elems_128b}x{elems_128b} tiles for"
f" element type {dtype}."
)

return wgmma_ref_p.bind(acc, a.ref, b.ref, swizzle=swizzle, rhs_transpose=rhs_transpose)


@wgmma_ref_p.def_effectful_abstract_eval
Expand Down
7 changes: 7 additions & 0 deletions jax/_src/state/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,10 @@ def get_indexer_shape(self) -> tuple[int | Array, ...]:
# In NDIndexers, the int_indexer_shape is *always* at the front of the
# result.
return (*self.int_indexer_shape, *slice_shape)

def transform_shape(self, shape: None | tuple[int | Array, ...]) -> None | tuple[int | Array, ...]:
del shape # Unused
return self.get_indexer_shape()

def transform_dtype(self, dtype):
return dtype
Loading

0 comments on commit ff0fa78

Please sign in to comment.