Skip to content

Commit

Permalink
[Pallas/Mosaic GPU] Replace tiling/transpose fields of GPUBlockSpec w…
Browse files Browse the repository at this point in the history
…ith a transform list

PiperOrigin-RevId: 679045307
  • Loading branch information
apaszke authored and Google-ML-Automation committed Sep 26, 2024
1 parent f6fdfb4 commit 8a4c0b1
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 13 deletions.
2 changes: 2 additions & 0 deletions jax/_src/pallas/mosaic_gpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec
from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams
from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
from jax._src.pallas.mosaic_gpu.core import TilingTransform
from jax._src.pallas.mosaic_gpu.core import TransposeTransform
from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem
from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import wait_barrier
Expand Down
14 changes: 5 additions & 9 deletions jax/_src/pallas/mosaic_gpu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,8 @@ class GPUBlockMapping(pallas_core.BlockMapping):

@dataclasses.dataclass
class GPUBlockSpec(pallas_core.BlockSpec):
# TODO(justinfu): Replace tiling a list of transforms.
tiling: tuple[int, ...] | None = None
transpose_permutation: tuple[int, ...] | None = None
swizzle: int | None = None
transforms: MemoryRefTransform | tuple[MemoryRefTransform, ...] = ()
swizzle: int | None = None # TODO: apaszke - Swizzle is also a transform.

def to_block_mapping(
self,
Expand All @@ -155,11 +153,9 @@ def to_block_mapping(
grid=grid,
mapped_dims=mapped_dims,
)
transforms: tuple[pallas_core.MemoryRefTransform, ...] = ()
if self.tiling is not None:
transforms += (TilingTransform(self.tiling),)
if self.transpose_permutation is not None:
transforms += (TransposeTransform(self.transpose_permutation),)
transforms = self.transforms
if isinstance(transforms, MemoryRefTransform):
transforms = (transforms,)
return GPUBlockMapping(
block_shape=bm.block_shape,
block_aval=bm.block_aval,
Expand Down
14 changes: 10 additions & 4 deletions tests/pallas/mosaic_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,11 +326,15 @@ def kernel(o_ref):
)

def test_swizzled_blockspec_shapes(self):

@functools.partial(
pl.pallas_call,
in_specs=[
plgpu.GPUBlockSpec(
(128, 64), lambda *i: i, tiling=(64, 64), swizzle=128
(128, 64),
lambda *i: i,
transforms=plgpu.TilingTransform((64, 64)),
swizzle=128,
),
],
out_specs=pl.BlockSpec((2, 1, 64, 64), lambda i, j: (i, j, 64, 64)),
Expand Down Expand Up @@ -390,20 +394,22 @@ def kernel(a_ref, b_ref, o_ref):
a = jax.random.uniform(key1, shape=(64, 128), dtype=dtype)
b = jax.random.uniform(key2, shape=(128, 128), dtype=dtype)

rhs_transforms = (plgpu.TilingTransform((elems_128b, elems_128b)),)
if rhs_transpose:
rhs_transforms += (plgpu.TransposeTransform((1, 0, 2, 3)),)
res = pl.pallas_call(
kernel,
in_specs=[
plgpu.GPUBlockSpec(
(64, 128),
lambda i, j: (i, j),
tiling=(64, elems_128b),
transforms=plgpu.TilingTransform((64, elems_128b)),
swizzle=128,
),
plgpu.GPUBlockSpec(
(128, 128),
lambda *i: i,
transpose_permutation=(1, 0, 2, 3) if rhs_transpose else None,
tiling=(elems_128b, elems_128b),
transforms=rhs_transforms,
swizzle=128,
),
],
Expand Down

0 comments on commit 8a4c0b1

Please sign in to comment.