From ff0fa78845b8b5f10d949030a570a0db7d615c8f Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Fri, 27 Sep 2024 06:46:13 -0700
Subject: [PATCH] [Pallas/MGPU] Undo transforms on refs before giving them back to the users

This change makes it so that the refs users receive inside their kernels have
shapes matching their block specs. However, the refs are not actually plain
refs, but transformed references that begin with the fully transformed abstract
ref and then stack the inverse of the transformation stack on top of it. This
means that all primitives that take in refs can also see the sequence of
transforms the user applied in the block spec, which lets us verify e.g. that
the inputs to WGMMA are correctly tiled, even though their user-visible shape
remains 2D. We should be able to use the same trick in the future to propagate
tiling and better infer the layouts for loads and stores.

PiperOrigin-RevId: 679570949
---
 jax/_src/pallas/core.py                  |  16 ++--
 jax/_src/pallas/mosaic_gpu/core.py       | 103 ++++++++++++++++-------
 jax/_src/pallas/mosaic_gpu/lowering.py   |   8 +-
 jax/_src/pallas/mosaic_gpu/primitives.py |  45 +++++++---
 jax/_src/state/indexing.py               |   7 ++
 jax/_src/state/types.py                  |  80 +++++++++++++++---
 tests/pallas/mosaic_gpu_test.py          |   4 +
 7 files changed, 199 insertions(+), 64 deletions(-)

diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index e817369a50c5..27536463b1b6 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -40,6 +40,7 @@
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.state import discharge as state_discharge
+from jax._src.state.types import TransformedRef
 import jax.numpy as jnp
 
 
@@ -523,7 +524,7 @@ def __repr__(self):
 
 
 class MemoryRefTransform(Protocol):
   """Transforms a memory reference on load or store."""
 
-  def __call__(self, block_aval: AbstractMemoryRef) -> AbstractMemoryRef:
+  def undo(self, ref: TransformedRef) -> TransformedRef:
     raise NotImplementedError("Abstract evaluation not implemented.")
 
 
@@ -533,6 +534,7 @@ class BlockMapping:
 
   See the `check_invariants` method for precise specification.
   """
+  # TODO(apaszke): Why do we have block_shape and block_aval? Can't mapping be a transform (it's just slicing after all)?
   block_shape: tuple[Mapped | int, ...]
   block_aval: AbstractMemoryRef  # The block ref aval
   index_map_jaxpr: jax_core.ClosedJaxpr
@@ -546,8 +548,8 @@ def check_invariants(self) -> None:
     if not config.enable_checks.value:
       return
     unmapped_block_shape = tuple(s for s in self.block_shape if s is not mapped)
-    assert unmapped_block_shape == self.block_aval.shape, (
-        self.block_shape, self.block_aval)
+    assert unmapped_block_shape == self.ref_aval.shape, (
+        self.block_shape, self.ref_aval.shape)
     assert len(self.block_shape) == len(self.array_shape_dtype.shape), (
         self.block_shape, self.array_shape_dtype
     )
@@ -568,12 +570,12 @@ def replace(self, **kwargs):
     return new_self
 
   @property
-  def ref_aval(self) -> AbstractMemoryRef:
+  def ref_aval(self) -> AbstractMemoryRef | TransformedRef:
    """Returns the abstract value of the Ref after transformations."""
-    block_aval = self.block_aval
+    ref = TransformedRef(self.block_aval, ())
     for transform in self.transforms:
-      block_aval = transform(block_aval)
-    return block_aval
+      ref = transform.undo(ref)
+    return ref if ref.transforms else ref.ref
 
   def compute_start_indices_interpret(self, loop_idx, *args):
     discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index fe8daf43e995..dbd974375ace 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -22,6 +22,7 @@
 from jax._src import core as jax_core
 from jax._src import dtypes
 from jax._src import tree_util
+from jax._src.state.types import Transform
 from jax._src.pallas import core as pallas_core
 import jax.experimental.mosaic.gpu as mgpu
 import jax.numpy as jnp
@@ -67,6 +68,11 @@ class MemoryRefTransform(pallas_core.MemoryRefTransform, Protocol):
   def to_gpu_transform(self) -> mgpu.MemRefTransform:
     ...
 
+  def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray:
+    return aval.update(
+        shape=self.to_gpu_transform().transform_shape(aval.shape)
+    )
+
 
 @dataclasses.dataclass(frozen=True)
 class TilingTransform(MemoryRefTransform):
@@ -79,52 +85,84 @@ class TilingTransform(MemoryRefTransform):
 
   tiling: tuple[int, ...]
 
-  def __call__(
-      self, block_aval: pallas_core.AbstractMemoryRef
-  ) -> pallas_core.AbstractMemoryRef:
-    block_shape = block_aval.shape
-    old_tiled_dims = block_shape[-len(self.tiling) :]
-    num_tiles = tuple(
-        block_dim // tiling_dim
-        for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling)
-    )
-    rem = (
-        block_dim % tiling_dim
-        for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling)
-    )
-    if any(rem):
-      raise ValueError(
-          f"Block shape {block_shape} is not divisible by tiling {self.tiling}"
-      )
-    new_block_shape = block_shape[: -len(self.tiling)] + num_tiles + self.tiling
-    return block_aval.update(
-        inner_aval=block_aval.inner_aval.update(shape=new_block_shape)
-    )
+  def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
+    return dataclasses.replace(ref, transforms=ref.transforms + (UntileRef(self.tiling),))
 
   def to_gpu_transform(self) -> mgpu.MemRefTransform:
     return mgpu.TileTransform(self.tiling)
 
 
+@tree_util.register_pytree_node_class
+@dataclasses.dataclass(frozen=True)
+class UntileRef(Transform):
+
+  tiling: tuple[int, ...]
+
+  def transform_shape(self, shape):
+    if shape is None:
+      return None
+    assert shape[-len(self.tiling) :] == self.tiling
+    shape = shape[: -len(self.tiling)]  # Drop tiling
+    return shape[: -len(self.tiling)] + tuple(
+        block_dim * tiling_dim
+        for block_dim, tiling_dim in zip(shape[-len(self.tiling) :], self.tiling)
+    )
+
+  def transform_dtype(self, dtype):
+    return dtype
+
+  def tree_flatten(self):
+    return (), (self.tiling,)
+
+  @classmethod
+  def tree_unflatten(cls, metadata, arrays):
+    assert not arrays
+    return cls(*metadata)
+
+
 @dataclasses.dataclass(frozen=True)
 class TransposeTransform(MemoryRefTransform):
   """Transpose a tiled memref."""
-
   permutation: tuple[int, ...]
 
-  def __call__(
-      self, block_aval: pallas_core.AbstractMemoryRef
-  ) -> pallas_core.AbstractMemoryRef:
-    shape = block_aval.shape  # pytype: disable=attribute-error
-    return block_aval.update(
-        inner_aval=block_aval.inner_aval.update(
-            shape=self.to_gpu_transform().transform_shape(shape)
-        )
+  def __post_init__(self):
+    if set(self.permutation) != set(range(len(self.permutation))):
+      raise ValueError(f"Permutation {self.permutation} is not a permutation.")
+
+  def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
+    inverse = [None] * len(self.permutation)
+    for i, p in enumerate(self.permutation):
+      inverse[p] = i
+    return dataclasses.replace(
+        ref, transforms=ref.transforms + (TransposeRef(tuple(inverse)),)
     )
 
   def to_gpu_transform(self) -> mgpu.MemRefTransform:
     return mgpu.TransposeTransform(self.permutation)
 
 
+@tree_util.register_pytree_node_class
+@dataclasses.dataclass(frozen=True)
+class TransposeRef(Transform):
+  permutation: tuple[int, ...]
+
+  def transform_shape(self, shape):
+    if shape is None:
+      return None
+    return tuple(shape[i] for i in self.permutation)
+
+  def transform_dtype(self, dtype):
+    return dtype
+
+  def tree_flatten(self):
+    return (), (self.permutation,)
+
+  @classmethod
+  def tree_unflatten(cls, metadata, arrays):
+    assert not arrays
+    return cls(*metadata)
+
+
 @dataclasses.dataclass(frozen=True)
 class GPUBlockMapping(pallas_core.BlockMapping):
   swizzle: int | None = None
@@ -156,9 +194,12 @@ def to_block_mapping(
     transforms = self.transforms
     if not isinstance(transforms, tuple):
       transforms = (transforms,)
+    block_inner_aval = bm.block_aval.inner_aval
+    for t in reversed(transforms):
+      block_inner_aval = t(block_inner_aval)
     return GPUBlockMapping(
         block_shape=bm.block_shape,
-        block_aval=bm.block_aval,
+        block_aval=bm.block_aval.update(inner_aval=block_inner_aval),
         origin=bm.origin,
         index_map_jaxpr=bm.index_map_jaxpr,
         index_map_src_info=bm.index_map_src_info,
diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 0d0ac41d11e3..c81ede844846 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -295,9 +295,12 @@ def lower_jaxpr_to_module(
   in_block_mappings, out_block_mappings = util.split_list(
       block_mappings, [grid_mapping.num_inputs]
   )
+  # TODO(apaszke): We can shrink allocation if max_concurrent_steps is more than the actual number of steps.
+  # We allocate the fully transformed shapes here. All primitives have seen the
+  # inverse transformation stack and will understand how to handle it.
   in_structs_smem = [
       jax.ShapeDtypeStruct(
-          [max_concurrent_steps, *bm.ref_aval.shape], bm.ref_aval.dtype
+          [max_concurrent_steps, *bm.block_aval.shape], bm.block_aval.dtype
       )
       if in_smem
       else None
@@ -317,6 +320,9 @@ def lower_jaxpr_to_module(
   )
   out_structs_gmem = [*grid_mapping.out_shapes]
   # TODO(justinfu): Implement output Memref transforms
+  for bm in block_mappings[grid_mapping.num_inputs :]:
+    if bm.transforms:
+      raise NotImplementedError("Output transforms are not supported")
   out_structs_smem = [
       jax.ShapeDtypeStruct([max_concurrent_steps, *bm.block_shape], s.dtype)
       if in_smem
diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index dcec631e389b..bd821af8271c 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -136,17 +136,40 @@ def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128):
   if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef):
     raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc}")
 
-  ma, ka, tma, tka = a.shape
-  kb, nb, tkb, tnb = b.shape
-  mc, nc = acc.shape
-
-  if rhs_transpose:
-    kb, nb, tkb, tnb = nb, kb, tnb, tkb
-
-  if tma * ma != mc or nb * tnb != nc or ka != kb or tka != tkb:
-    raise ValueError(f"Incompatible shapes: {a.shape=}, {b.shape=}, {acc.shape=}, {rhs_transpose=}")
-
-  return wgmma_ref_p.bind(acc, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose)
+  # TODO(apaszke): Make swizzling another transform and read it from the refs.
+  if not isinstance(a, pallas_core.TransformedRef):
+    raise ValueError("WGMMA inputs must be tiled references.")
+
+  m, n = acc.shape
+  m2, k = a.shape
+  k2, n2 = b.shape
+
+  if m != m2 or n != n2 or k != k2:
+    raise ValueError(
+        f"Incompatible shapes for matrix multiplication: lhs={a.shape},"
+        f" rhs={b.shape}, acc={acc.shape}"
+    )
+
+  if (dtype := a.dtype) != b.dtype:
+    raise ValueError(f"Mixed input dtypes for matrix multiplication unsupported: lhs={a.dtype}, rhs={b.dtype}")
+  if not isinstance(a, pallas_core.TransformedRef):
+    raise ValueError("WGMMA lhs must be a tiled reference.")
+  if not isinstance(b, pallas_core.TransformedRef):
+    raise ValueError("WGMMA rhs must be a tiled reference.")
+
+  elems_128b = swizzle // dtype.itemsize
+  if a.transforms != (gpu_core.UntileRef((64, elems_128b)),):
+    raise ValueError(
+        f"WGMMA lhs must be tiled with 64x{elems_128b} tiles for element type"
+        f" {dtype}."
+    )
+  if b.transforms != (gpu_core.UntileRef((elems_128b, elems_128b)),):
+    raise ValueError(
+        f"WGMMA rhs must be tiled with {elems_128b}x{elems_128b} tiles for"
+        f" element type {dtype}."
+    )
+
+  return wgmma_ref_p.bind(acc, a.ref, b.ref, swizzle=swizzle, rhs_transpose=rhs_transpose)
 
 
 @wgmma_ref_p.def_effectful_abstract_eval
diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py
index acf1c7216240..cb653547baff 100644
--- a/jax/_src/state/indexing.py
+++ b/jax/_src/state/indexing.py
@@ -256,3 +256,10 @@ def get_indexer_shape(self) -> tuple[int | Array, ...]:
     # In NDIndexers, the int_indexer_shape is *always* at the front of the
     # result.
     return (*self.int_indexer_shape, *slice_shape)
+
+  def transform_shape(self, shape: None | tuple[int | Array, ...]) -> None | tuple[int | Array, ...]:
+    del shape  # Unused
+    return self.get_indexer_shape()
+
+  def transform_dtype(self, dtype):
+    return dtype
diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py
index e64d6258a808..fbc438a2d134 100644
--- a/jax/_src/state/types.py
+++ b/jax/_src/state/types.py
@@ -18,7 +18,7 @@
 from collections.abc import Sequence
 import dataclasses
 import math
-from typing import Any, Union
+from typing import Any, Union, Protocol
 
 from jax._src import core
 from jax._src import dtypes
@@ -105,8 +105,34 @@ def tree_unflatten(cls, metadata, arrays):
     assert not arrays
     return cls(*metadata)
 
+  def transform_shape(
+      self, shape: None | tuple[int | Array, ...]
+  ) -> None | tuple[int | Array, ...]:
+    del shape  # Unused
+    return self.shape
+
+  def transform_dtype(self, dtype):
+    del dtype  # Unused
+    return self.dtype
+
+
+class Transform(Protocol):
+  def transform_shape(self, shape: None | tuple[int | Array, ...]) -> None | tuple[int | Array, ...]:
+    """Transform the shape.
+
+    Can return None if the input shape is not known, but must return a concrete
+    result when the input shape is known.
+    """
+    return shape
+
+  def transform_dtype(self, dtype) -> ...:
+    """Transform the dtype.
+
+    Can return None if the input dtype is not known, but must return a concrete
+    result when the input dtype is known.
+    """
+    return dtype
 
-Transform = indexing.NDIndexer | RefBitcaster
 
 @dataclasses.dataclass
 class RefIndexer:
@@ -122,30 +148,47 @@ def __getitem__(self, slc):
     return TransformedRef(self.ref_or_view, (indexer,))
 
 
-@dataclasses.dataclass
+@tree_util.register_pytree_node_class
+@dataclasses.dataclass(frozen=True)
 class TransformedRef:
   ref: Any
   transforms: tuple[Transform, ...]
 
   @property
   def is_dynamic_size(self):
-    return self.transforms[-1].is_dynamic_size
+    return any(not isinstance(i, int) for i in self.shape)
 
   @property
   def shape(self) -> tuple[int | Array, ...]:
-    assert (
-        len(self.transforms) > 0
-    ), "Should not be able to create a trivial TransformedRef"
-    if isinstance(self.transforms[-1], indexing.NDIndexer):
-      return self.transforms[-1].get_indexer_shape()
-    return self.transforms[-1].shape
+    unprocessed, shape = 0, None
+    for unprocessed, t in enumerate(reversed(self.transforms), 1):
+      if (shape := t.transform_shape(None)) is not None:
+        unprocessed -= 1
+        break
+    if shape is None:
+      shape = self.ref.shape
+    if not unprocessed:
+      return shape
+    for t in self.transforms[-unprocessed:]:
+      shape = t.transform_shape(shape)
+    assert shape is not None
+    return shape
 
   @property
   def dtype(self):
-    for transform in reversed(self.transforms):
-      if isinstance(transform, RefBitcaster):
-        return transform.dtype
-    return self.ref.dtype
+    unprocessed, dtype = 0, None
+    for unprocessed, t in enumerate(reversed(self.transforms), 1):
+      if (dtype := t.transform_dtype(None)) is not None:
+        unprocessed -= 1
+        break
+    if dtype is None:
+      dtype = self.ref.dtype
+    if not unprocessed:
+      return dtype
+    for t in self.transforms[-unprocessed:]:
+      dtype = t.transform_dtype(dtype)
+    assert dtype is not None
+    return dtype
 
   @property
   def at(self) -> RefIndexer:
@@ -168,6 +211,15 @@ def __setitem__(self, slc, value):
     from jax._src.state.primitives import ref_set  # pytype: disable=import-error
     return ref_set(self, slc, value)
 
+  def tree_flatten(self):
+    return (self.ref, self.transforms), ()
+
+  @classmethod
+  def tree_unflatten(cls, metadata, arrays):
+    assert not metadata
+    return cls(*arrays)
+
+
 # We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
 class AbstractRef(core.AbstractValue):
 
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index b35658ed4845..89a3d3c0c7c6 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -467,6 +467,10 @@ def test_realistic_matmul(self):
     tile_k = elems_128b
     m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n
     def kernel(a_ref, b_ref, o_ref, acc_ref):
+      # Make sure tiling does not alter the shape of references
+      assert a_ref.shape == (tile_m, tile_k)
+      assert b_ref.shape == (tile_k, tile_n)
+      assert o_ref.shape == acc_ref.shape == (tile_m, tile_n)
       plgpu.wgmma(acc_ref, a_ref, b_ref)
       plgpu.wgmma_wait(0)  # TODO(apaszke): Delay the pipeline to avoid memory races
      # TODO(apaszke): Only store in the last step. It doesn't work because we
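
The shape bookkeeping described in the commit message can be illustrated with a short sketch (not part of the patch). It uses only the UntileRef and TransposeRef classes added in jax/_src/pallas/mosaic_gpu/core.py above; importing them from the internal jax._src path is an assumption of the sketch, not a public API.

# Sketch: how the inverse transforms recover the user-visible block shape.
from jax._src.pallas.mosaic_gpu import core as gpu_core

# A block tiled with (64, 32) tiles is stored as (rows // 64, cols // 32, 64, 32)
# in SMEM. Undoing the tiling maps the stored shape back to the block-spec shape,
# so the kernel sees a plain (128, 64) ref.
untile = gpu_core.UntileRef(tiling=(64, 32))
assert untile.transform_shape((2, 2, 64, 32)) == (128, 64)

# TransposeTransform.undo records a TransposeRef holding the inverse permutation;
# applying it to the stored shape recovers the logical shape.
transpose = gpu_core.TransposeRef(permutation=(1, 0))
assert transpose.transform_shape((64, 128)) == (128, 64)

This is why, in the kernel of test_realistic_matmul above, a_ref.shape is (tile_m, tile_k) even though the SMEM buffer is tiled: the ref carries the tiled buffer plus an UntileRef, and wgmma() checks that recorded transform (UntileRef((64, elems_128b)) for the lhs, with elems_128b = swizzle // dtype.itemsize) rather than the user-visible 2D shape.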