diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index e817369a50c5..27536463b1b6 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -40,6 +40,7 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.state import discharge as state_discharge +from jax._src.state.types import TransformedRef import jax.numpy as jnp @@ -523,7 +524,7 @@ def __repr__(self): class MemoryRefTransform(Protocol): """Transforms a memory reference on load or store.""" - def __call__(self, block_aval: AbstractMemoryRef) -> AbstractMemoryRef: + def undo(self, ref: TransformedRef) -> TransformedRef: raise NotImplementedError("Abstract evaluation not implemented.") @@ -533,6 +534,7 @@ class BlockMapping: See the `check_invariants` method for precise specification. """ + # TODO(apaszke): Why do we have block_shape and block_aval? Can't mappping be a transform (it's just slicing after all)? block_shape: tuple[Mapped | int, ...] block_aval: AbstractMemoryRef # The block ref aval index_map_jaxpr: jax_core.ClosedJaxpr @@ -546,8 +548,8 @@ def check_invariants(self) -> None: if not config.enable_checks.value: return unmapped_block_shape = tuple(s for s in self.block_shape if s is not mapped) - assert unmapped_block_shape == self.block_aval.shape, ( - self.block_shape, self.block_aval) + assert unmapped_block_shape == self.ref_aval.shape, ( + self.block_shape, self.ref_aval.shape) assert len(self.block_shape) == len(self.array_shape_dtype.shape), ( self.block_shape, self.array_shape_dtype ) @@ -568,12 +570,12 @@ def replace(self, **kwargs): return new_self @property - def ref_aval(self) -> AbstractMemoryRef: + def ref_aval(self) -> AbstractMemoryRef | TransformedRef: """Returns the abstract value of the Ref after transformations.""" - block_aval = self.block_aval + ref = TransformedRef(self.block_aval, ()) for transform in self.transforms: - block_aval = transform(block_aval) - return block_aval + ref = 
transform.undo(ref) + return ref if ref.transforms else ref.ref def compute_start_indices_interpret(self, loop_idx, *args): discharged_jaxpr, discharged_consts = state_discharge.discharge_state( diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index fe8daf43e995..dbd974375ace 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -22,6 +22,7 @@ from jax._src import core as jax_core from jax._src import dtypes from jax._src import tree_util +from jax._src.state.types import Transform from jax._src.pallas import core as pallas_core import jax.experimental.mosaic.gpu as mgpu import jax.numpy as jnp @@ -67,6 +68,11 @@ class MemoryRefTransform(pallas_core.MemoryRefTransform, Protocol): def to_gpu_transform(self) -> mgpu.MemRefTransform: ... + def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: + return aval.update( + shape=self.to_gpu_transform().transform_shape(aval.shape) + ) + @dataclasses.dataclass(frozen=True) class TilingTransform(MemoryRefTransform): @@ -79,52 +85,84 @@ class TilingTransform(MemoryRefTransform): tiling: tuple[int, ...] 
- def __call__( - self, block_aval: pallas_core.AbstractMemoryRef - ) -> pallas_core.AbstractMemoryRef: - block_shape = block_aval.shape - old_tiled_dims = block_shape[-len(self.tiling) :] - num_tiles = tuple( - block_dim // tiling_dim - for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling) - ) - rem = ( - block_dim % tiling_dim - for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling) - ) - if any(rem): - raise ValueError( - f"Block shape {block_shape} is not divisible by tiling {self.tiling}" - ) - new_block_shape = block_shape[: -len(self.tiling)] + num_tiles + self.tiling - return block_aval.update( - inner_aval=block_aval.inner_aval.update(shape=new_block_shape) - ) + def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef: + return dataclasses.replace(ref, transforms=ref.transforms + (UntileRef(self.tiling),)) def to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TileTransform(self.tiling) +@tree_util.register_pytree_node_class +@dataclasses.dataclass(frozen=True) +class UntileRef(Transform): + + tiling: tuple[int, ...] + + def transform_shape(self, shape): + if shape is None: + return None + assert shape[-len(self.tiling) :] == self.tiling + shape = shape[: -len(self.tiling)] # Drop tiling + return shape[: -len(self.tiling)] + tuple( + block_dim * tiling_dim + for block_dim, tiling_dim in zip(shape[-len(self.tiling) :], self.tiling) + ) + + def transform_dtype(self, dtype): + return dtype + + def tree_flatten(self): + return (), (self.tiling,) + + @classmethod + def tree_unflatten(cls, metadata, arrays): + assert not arrays + return cls(*metadata) + + @dataclasses.dataclass(frozen=True) class TransposeTransform(MemoryRefTransform): """Transpose a tiled memref.""" - permutation: tuple[int, ...] 
- def __call__( - self, block_aval: pallas_core.AbstractMemoryRef - ) -> pallas_core.AbstractMemoryRef: - shape = block_aval.shape # pytype: disable=attribute-error - return block_aval.update( - inner_aval=block_aval.inner_aval.update( - shape=self.to_gpu_transform().transform_shape(shape) - ) + def __post_init__(self): + if set(self.permutation) != set(range(len(self.permutation))): + raise ValueError(f"Permutation {self.permutation} is not a permutation.") + + def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef: + inverse = [None] * len(self.permutation) + for i, p in enumerate(self.permutation): + inverse[p] = i + return dataclasses.replace( + ref, transforms=ref.transforms + (TransposeRef(tuple(inverse)),) ) def to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TransposeTransform(self.permutation) +@tree_util.register_pytree_node_class +@dataclasses.dataclass(frozen=True) +class TransposeRef(Transform): + permutation: tuple[int, ...] + + def transform_shape(self, shape): + if shape is None: + return None + return tuple(shape[i] for i in self.permutation) + + def transform_dtype(self, dtype): + return dtype + + def tree_flatten(self): + return (), (self.permutation,) + + @classmethod + def tree_unflatten(cls, metadata, arrays): + assert not arrays + return cls(*metadata) + + @dataclasses.dataclass(frozen=True) class GPUBlockMapping(pallas_core.BlockMapping): swizzle: int | None = None @@ -156,9 +194,12 @@ def to_block_mapping( transforms = self.transforms if not isinstance(transforms, tuple): transforms = (transforms,) + block_inner_aval = bm.block_aval.inner_aval + for t in reversed(transforms): + block_inner_aval = t(block_inner_aval) return GPUBlockMapping( block_shape=bm.block_shape, - block_aval=bm.block_aval, + block_aval=bm.block_aval.update(inner_aval=block_inner_aval), origin=bm.origin, index_map_jaxpr=bm.index_map_jaxpr, index_map_src_info=bm.index_map_src_info, diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py 
b/jax/_src/pallas/mosaic_gpu/lowering.py index 0d0ac41d11e3..c81ede844846 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -295,9 +295,12 @@ def lower_jaxpr_to_module( in_block_mappings, out_block_mappings = util.split_list( block_mappings, [grid_mapping.num_inputs] ) + # TODO(apaszke): We can shrink allocation if max_concurrent_steps is more than the actual number of steps. + # We allocate the fully transformed shapes here. All primitives have seen the + # inverse transformation stack and will understand how to handle it. in_structs_smem = [ jax.ShapeDtypeStruct( - [max_concurrent_steps, *bm.ref_aval.shape], bm.ref_aval.dtype + [max_concurrent_steps, *bm.block_aval.shape], bm.block_aval.dtype ) if in_smem else None @@ -317,6 +320,9 @@ def lower_jaxpr_to_module( ) out_structs_gmem = [*grid_mapping.out_shapes] # TODO(justinfu): Implement output Memref transforms + for bm in block_mappings[grid_mapping.num_inputs :]: + if bm.transforms: + raise NotImplementedError("Output transforms are not supported") out_structs_smem = [ jax.ShapeDtypeStruct([max_concurrent_steps, *bm.block_shape], s.dtype) if in_smem diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index dcec631e389b..bd821af8271c 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -136,17 +136,40 @@ def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128): if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef): raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc}") - ma, ka, tma, tka = a.shape - kb, nb, tkb, tnb = b.shape - mc, nc = acc.shape - - if rhs_transpose: - kb, nb, tkb, tnb = nb, kb, tnb, tkb - - if tma * ma != mc or nb * tnb != nc or ka != kb or tka != tkb: - raise ValueError(f"Incompatible shapes: {a.shape=}, {b.shape=}, {acc.shape=}, {rhs_transpose=}") - - return wgmma_ref_p.bind(acc, a, b, swizzle=swizzle, 
rhs_transpose=rhs_transpose) + # TODO(apaszke): Make swizzling another transform and read it from the refs. + if not isinstance(a, pallas_core.TransformedRef): + raise ValueError("WGMMA inputs must be tiled references.") + + m, n = acc.shape + m2, k = a.shape + k2, n2 = b.shape + + if m != m2 or n != n2 or k != k2: + raise ValueError( + f"Incompatible shapes for matrix multiplication: lhs={a.shape}," + f" rhs={b.shape}, acc={acc.shape}" + ) + + if (dtype := a.dtype) != b.dtype: + raise ValueError(f"Mixed input dtypes for matrix multiplication unsupported: lhs={a.dtype}, rhs={b.dtype}") + if not isinstance(a, pallas_core.TransformedRef): + raise ValueError("WGMMA lhs must be a tiled reference.") + if not isinstance(b, pallas_core.TransformedRef): + raise ValueError("WGMMA rhs must be a tiled reference.") + + elems_128b = swizzle // dtype.itemsize + if a.transforms != (gpu_core.UntileRef((64, elems_128b)),): + raise ValueError( + f"WGMMA lhs must be tiled with 64x{elems_128b} tiles for element type" + f" {dtype}." + ) + if b.transforms != (gpu_core.UntileRef((elems_128b, elems_128b)),): + raise ValueError( + f"WGMMA rhs must be tiled with {elems_128b}x{elems_128b} tiles for" + f" element type {dtype}." + ) + + return wgmma_ref_p.bind(acc, a.ref, b.ref, swizzle=swizzle, rhs_transpose=rhs_transpose) @wgmma_ref_p.def_effectful_abstract_eval diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index acf1c7216240..cb653547baff 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -256,3 +256,10 @@ def get_indexer_shape(self) -> tuple[int | Array, ...]: # In NDIndexers, the int_indexer_shape is *always* at the front of the # result.
return (*self.int_indexer_shape, *slice_shape) + + def transform_shape(self, shape: None | tuple[int | Array, ...]) -> None | tuple[int | Array, ...]: + del shape # Unused + return self.get_indexer_shape() + + def transform_dtype(self, dtype): + return dtype diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index e64d6258a808..fbc438a2d134 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -18,7 +18,7 @@ from collections.abc import Sequence import dataclasses import math -from typing import Any, Union +from typing import Any, Union, Protocol from jax._src import core from jax._src import dtypes @@ -105,8 +105,34 @@ def tree_unflatten(cls, metadata, arrays): assert not arrays return cls(*metadata) + def transform_shape( + self, shape: None | tuple[int | Array, ...] + ) -> None | tuple[int | Array, ...]: + del shape # Unused + return self.shape + + def transform_dtype(self, dtype): + del dtype # Unused + return self.dtype + + +class Transform(Protocol): + def transform_shape(self, shape: None | tuple[int | Array, ...]) -> None | tuple[int | Array, ...]: + """Transform the shape. + + Can return None if the input shape is not known, but must return a concrete + result when the input shape is known. + """ + return shape + + def transform_dtype(self, dtype) -> ...: + """Transform the dtype. + + Can return None if the input dtype is not known, but must return a concrete + result when the input dtype is known. + """ + return dtype -Transform = indexing.NDIndexer | RefBitcaster @dataclasses.dataclass class RefIndexer: @@ -122,30 +148,47 @@ def __getitem__(self, slc): return TransformedRef(self.ref_or_view, (indexer,)) -@dataclasses.dataclass +@tree_util.register_pytree_node_class +@dataclasses.dataclass(frozen=True) class TransformedRef: ref: Any transforms: tuple[Transform, ...] 
@property def is_dynamic_size(self): - return self.transforms[-1].is_dynamic_size + return any(not isinstance(i, int) for i in self.shape) @property def shape(self) -> tuple[int | Array, ...]: - assert ( - len(self.transforms) > 0 - ), "Should not be able to create a trivial TransformedRef" - if isinstance(self.transforms[-1], indexing.NDIndexer): - return self.transforms[-1].get_indexer_shape() - return self.transforms[-1].shape + unprocessed, shape = 0, None + for unprocessed, t in enumerate(reversed(self.transforms), 1): + if (shape := t.transform_shape(None)) is not None: + unprocessed -= 1 + break + if shape is None: + shape = self.ref.shape + if not unprocessed: + return shape + for t in self.transforms[-unprocessed:]: + shape = t.transform_shape(shape) + assert shape is not None + return shape @property def dtype(self): - for transform in reversed(self.transforms): - if isinstance(transform, RefBitcaster): - return transform.dtype - return self.ref.dtype + unprocessed, dtype = 0, None + for unprocessed, t in enumerate(reversed(self.transforms), 1): + if (dtype := t.transform_dtype(None)) is not None: + unprocessed -= 1 + break + if dtype is None: + dtype = self.ref.dtype + if not unprocessed: + return dtype + for t in self.transforms[-unprocessed:]: + dtype = t.transform_dtype(dtype) + assert dtype is not None + return dtype @property def at(self) -> RefIndexer: @@ -168,6 +211,15 @@ def __setitem__(self, slc, value): from jax._src.state.primitives import ref_set # pytype: disable=import-error return ref_set(self, slc, value) + def tree_flatten(self): + return (self.ref, self.transforms), () + + @classmethod + def tree_unflatten(cls, metadata, arrays): + assert not metadata + return cls(*arrays) + + # We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs. 
class AbstractRef(core.AbstractValue): diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index b35658ed4845..89a3d3c0c7c6 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -467,6 +467,10 @@ def test_realistic_matmul(self): tile_k = elems_128b m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n def kernel(a_ref, b_ref, o_ref, acc_ref): + # Make sure tiling does not alter the shape of references + assert a_ref.shape == (tile_m, tile_k) + assert b_ref.shape == (tile_k, tile_n) + assert o_ref.shape == acc_ref.shape == (tile_m, tile_n) plgpu.wgmma(acc_ref, a_ref, b_ref) plgpu.wgmma_wait(0) # TODO(apaszke): Delay the pipeline to avoid memory races # TODO(apaszke): Only store in the last step. It doesn't work because we