From 3c25da2c599aa9f4dc47ea95d0739a535d4e2374 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Thu, 26 Sep 2024 03:40:24 -0700
Subject: [PATCH] [Pallas/Mosaic GPU] Replace tiling/transpose fields of GPUBlockSpec with a transform list

PiperOrigin-RevId: 679079269
---
 jax/_src/pallas/mosaic_gpu/__init__.py |  2 ++
 jax/_src/pallas/mosaic_gpu/core.py     | 14 +++++---------
 tests/pallas/mosaic_gpu_test.py        | 14 ++++++++++----
 3 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/__init__.py b/jax/_src/pallas/mosaic_gpu/__init__.py
index bbada82ace82..187a84478c65 100644
--- a/jax/_src/pallas/mosaic_gpu/__init__.py
+++ b/jax/_src/pallas/mosaic_gpu/__init__.py
@@ -18,6 +18,8 @@
 from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec
 from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams
 from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
+from jax._src.pallas.mosaic_gpu.core import TilingTransform
+from jax._src.pallas.mosaic_gpu.core import TransposeTransform
 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC
 from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem
 from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem
diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index df633619e6ee..0b79121c34c1 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -132,10 +132,8 @@ class GPUBlockMapping(pallas_core.BlockMapping):
 
 @dataclasses.dataclass
 class GPUBlockSpec(pallas_core.BlockSpec):
-  # TODO(justinfu): Replace tiling a list of transforms.
-  tiling: tuple[int, ...] | None = None
-  transpose_permutation: tuple[int, ...] | None = None
-  swizzle: int | None = None
+  transforms: MemoryRefTransform | tuple[MemoryRefTransform, ...] = ()
+  swizzle: int | None = None  # TODO: apaszke - Swizzle is also a transform.
 
   def to_block_mapping(
       self,
@@ -155,11 +153,9 @@ def to_block_mapping(
         grid=grid,
         mapped_dims=mapped_dims,
     )
-    transforms: tuple[pallas_core.MemoryRefTransform, ...] = ()
-    if self.tiling is not None:
-      transforms += (TilingTransform(self.tiling),)
-    if self.transpose_permutation is not None:
-      transforms += (TransposeTransform(self.transpose_permutation),)
+    transforms = self.transforms
+    if not isinstance(transforms, tuple):
+      transforms = (transforms,)
     return GPUBlockMapping(
         block_shape=bm.block_shape,
         block_aval=bm.block_aval,
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 431dba133ce3..5abf4e2e363e 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -326,11 +326,15 @@ def kernel(o_ref):
     )
 
   def test_swizzled_blockspec_shapes(self):
+
     @functools.partial(
         pl.pallas_call,
         in_specs=[
             plgpu.GPUBlockSpec(
-                (128, 64), lambda *i: i, tiling=(64, 64), swizzle=128
+                (128, 64),
+                lambda *i: i,
+                transforms=plgpu.TilingTransform((64, 64)),
+                swizzle=128,
             ),
         ],
         out_specs=pl.BlockSpec((2, 1, 64, 64), lambda i, j: (i, j, 64, 64)),
@@ -390,20 +394,22 @@ def scope(acc_ref):
     a = jax.random.uniform(key1, shape=(64, 128), dtype=dtype)
     b = jax.random.uniform(key2, shape=(128, 128), dtype=dtype)
 
+    rhs_transforms = (plgpu.TilingTransform((elems_128b, elems_128b)),)
+    if rhs_transpose:
+      rhs_transforms += (plgpu.TransposeTransform((1, 0, 2, 3)),)
     res = pl.pallas_call(
         kernel,
         in_specs=[
             plgpu.GPUBlockSpec(
                 (64, 128),
                 lambda i, j: (i, j),
-                tiling=(64, elems_128b),
+                transforms=plgpu.TilingTransform((64, elems_128b)),
                 swizzle=128,
             ),
             plgpu.GPUBlockSpec(
                 (128, 128),
                 lambda *i: i,
-                transpose_permutation=(1, 0, 2, 3) if rhs_transpose else None,
-                tiling=(elems_128b, elems_128b),
+                transforms=rhs_transforms,
                 swizzle=128,
             ),
         ],