[Pallas TPU] Remove the skipTest for OpsExtraTest on TPU
Unify `OpsExtraTest` and `OpsTest`. Add `skipTest` for individual functions instead of skipping the whole test class.

Reverts 21fea5b

PiperOrigin-RevId: 679493740
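
The per-test skip pattern described in the commit message might look like the following minimal sketch (class, method, and message names here are illustrative assumptions, not the actual test code in this change):

import jax
from absl.testing import absltest

class OpsTest(absltest.TestCase):  # unified test class

  def test_op_without_tpu_support(self):
    # Skip only this test on TPU instead of skipping the whole class.
    if jax.default_backend() == "tpu":
      self.skipTest("Not implemented on TPU")
    ...  # the actual op check would go here

if __name__ == "__main__":
  absltest.main()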
ayaka14732 authored and Google-ML-Automation committed Sep 30, 2024
1 parent b3fca90 commit 10fe4c8
Showing 8 changed files with 106 additions and 260 deletions.
32 changes: 10 additions & 22 deletions jax/_src/pallas/core.py
@@ -40,7 +40,6 @@
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.state import discharge as state_discharge
from jax._src.state.types import TransformedRef
import jax.numpy as jnp


@@ -497,7 +496,7 @@ def to_block_mapping(

mapping = BlockMapping(
block_shape=mapped_block_shape,
transformed_block_aval=block_aval, # There are no transforms by default
block_aval=block_aval,
index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts),
index_map_src_info=index_map_src_info,
indexing_mode=self.indexing_mode,
@@ -524,7 +523,7 @@ def __repr__(self):
class MemoryRefTransform(Protocol):
"""Transforms a memory reference on load or store."""

def undo(self, ref: TransformedRef) -> TransformedRef:
def __call__(self, block_aval: AbstractMemoryRef) -> AbstractMemoryRef:
raise NotImplementedError("Abstract evaluation not implemented.")


@@ -534,10 +533,8 @@ class BlockMapping:
See the `check_invariants` method for precise specification.
"""
# TODO(apaszke,sharadmv): Replace mapped dims in block_shape with a transform.
# After all, it's just indexing out singleton dimensions.
block_shape: tuple[Mapped | int, ...]
transformed_block_aval: AbstractMemoryRef
block_aval: AbstractMemoryRef # The block ref aval
index_map_jaxpr: jax_core.ClosedJaxpr
index_map_src_info: NameAndSrcInfo
indexing_mode: IndexingMode
@@ -549,8 +546,8 @@ def check_invariants(self) -> None:
if not config.enable_checks.value: return

unmapped_block_shape = tuple(s for s in self.block_shape if s is not mapped)
assert unmapped_block_shape == self.ref_aval.shape, (
self.block_shape, self.ref_aval.shape)
assert unmapped_block_shape == self.block_aval.shape, (
self.block_shape, self.block_aval)
assert len(self.block_shape) == len(self.array_shape_dtype.shape), (
self.block_shape, self.array_shape_dtype
)
@@ -571,21 +568,12 @@ def replace(self, **kwargs):
return new_self

@property
def block_aval(self) -> AbstractMemoryRef:
# If you hit this, make sure you take transforms into account and use either
# ref_aval or transformed_block_aval.
assert not self.transforms, "Lowering failed to handle transforms"
return self.transformed_block_aval

@property
def ref_aval(self) -> AbstractMemoryRef | TransformedRef:
def ref_aval(self) -> AbstractMemoryRef:
"""Returns the abstract value of the Ref after transformations."""
if not self.transforms:
return self.transformed_block_aval
ref = TransformedRef(self.transformed_block_aval, ())
for transform in reversed(self.transforms):
ref = transform.undo(ref)
return ref
block_aval = self.block_aval
for transform in self.transforms:
block_aval = transform(block_aval)
return block_aval

def compute_start_indices_interpret(self, loop_idx, *args):
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
113 changes: 33 additions & 80 deletions jax/_src/pallas/mosaic_gpu/core.py
@@ -14,16 +14,14 @@

"""Contains GPU-specific Pallas abstractions."""

import abc
from collections.abc import Sequence
import dataclasses
import enum
from typing import Any, ClassVar, Literal
from typing import Any, ClassVar, Literal, Protocol

from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import tree_util
from jax._src.state.types import Transform
from jax._src.pallas import core as pallas_core
import jax.experimental.mosaic.gpu as mgpu
import jax.numpy as jnp
@@ -65,15 +63,9 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):
return pallas_core.MemoryRef(shape, dtype, memory_space=self)


class MemoryRefTransform(pallas_core.MemoryRefTransform, abc.ABC):
@abc.abstractmethod
class MemoryRefTransform(pallas_core.MemoryRefTransform, Protocol):
def to_gpu_transform(self) -> mgpu.MemRefTransform:
pass

def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray:
return aval.update(
shape=self.to_gpu_transform().transform_shape(aval.shape)
)
...


@dataclasses.dataclass(frozen=True)
@@ -87,86 +79,52 @@ class TilingTransform(MemoryRefTransform):

tiling: tuple[int, ...]

def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
return dataclasses.replace(
ref, transforms=(*ref.transforms, UntileRef(self.tiling))
def __call__(
self, block_aval: pallas_core.AbstractMemoryRef
) -> pallas_core.AbstractMemoryRef:
block_shape = block_aval.shape
old_tiled_dims = block_shape[-len(self.tiling) :]
num_tiles = tuple(
block_dim // tiling_dim
for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling)
)
rem = (
block_dim % tiling_dim
for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling)
)
if any(rem):
raise ValueError(
f"Block shape {block_shape} is not divisible by tiling {self.tiling}"
)
new_block_shape = block_shape[: -len(self.tiling)] + num_tiles + self.tiling
return block_aval.update(
inner_aval=block_aval.inner_aval.update(shape=new_block_shape)
)

def to_gpu_transform(self) -> mgpu.MemRefTransform:
return mgpu.TileTransform(self.tiling)
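
As a quick illustration of the tiling arithmetic in `TilingTransform.__call__` above, here is a standalone sketch with an example block shape and tiling (the concrete values are chosen for illustration and are not taken from the diff):

block_shape = (128, 128)
tiling = (64, 32)
old_tiled_dims = block_shape[-len(tiling):]                         # (128, 128)
num_tiles = tuple(d // t for d, t in zip(old_tiled_dims, tiling))   # (2, 4)
new_block_shape = block_shape[:-len(tiling)] + num_tiles + tiling
assert new_block_shape == (2, 4, 64, 32)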


@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class UntileRef(Transform):

tiling: tuple[int, ...]

def transform_shape(self, shape):
if shape is None:
return None
assert shape[-len(self.tiling) :] == self.tiling
shape = shape[: -len(self.tiling)] # Drop tiling
return shape[: -len(self.tiling)] + tuple(
block_dim * tiling_dim
for block_dim, tiling_dim in zip(shape[-len(self.tiling) :], self.tiling)
)

def transform_dtype(self, dtype):
return dtype

def tree_flatten(self):
return (), (self.tiling,)

@classmethod
def tree_unflatten(cls, metadata, arrays):
assert not arrays
return cls(*metadata)


@dataclasses.dataclass(frozen=True)
class TransposeTransform(MemoryRefTransform):
"""Transpose a tiled memref."""
permutation: tuple[int, ...]

def __post_init__(self):
if set(self.permutation) != set(range(len(self.permutation))):
raise ValueError(f"Permutation {self.permutation} is not a permutation.")
permutation: tuple[int, ...]

def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
inverse = [-1] * len(self.permutation)
for i, p in enumerate(self.permutation):
inverse[p] = i
return dataclasses.replace(
ref, transforms=(*ref.transforms, TransposeRef(tuple(inverse)))
def __call__(
self, block_aval: pallas_core.AbstractMemoryRef
) -> pallas_core.AbstractMemoryRef:
shape = block_aval.shape # pytype: disable=attribute-error
return block_aval.update(
inner_aval=block_aval.inner_aval.update(
shape=self.to_gpu_transform().transform_shape(shape)
)
)

def to_gpu_transform(self) -> mgpu.MemRefTransform:
return mgpu.TransposeTransform(self.permutation)


@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class TransposeRef(Transform):
permutation: tuple[int, ...]

def transform_shape(self, shape):
if shape is None:
return None
return tuple(shape[i] for i in self.permutation)

def transform_dtype(self, dtype):
return dtype

def tree_flatten(self):
return (), (self.permutation,)

@classmethod
def tree_unflatten(cls, metadata, arrays):
assert not arrays
return cls(*metadata)


@dataclasses.dataclass(frozen=True)
class GPUBlockMapping(pallas_core.BlockMapping):
swizzle: int | None = None
@@ -198,14 +156,9 @@ def to_block_mapping(
transforms = self.transforms
if not isinstance(transforms, tuple):
transforms = (transforms,)
block_inner_aval = bm.block_aval.inner_aval
for t in transforms:
block_inner_aval = t(block_inner_aval)
return GPUBlockMapping(
block_shape=bm.block_shape,
transformed_block_aval=bm.block_aval.update(
inner_aval=block_inner_aval
),
block_aval=bm.block_aval,
origin=bm.origin,
index_map_jaxpr=bm.index_map_jaxpr,
index_map_src_info=bm.index_map_src_info,
11 changes: 2 additions & 9 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -280,7 +280,7 @@ def lower_jaxpr_to_module(

in_in_smem, out_in_smem = util.split_list(
[
bm.transformed_block_aval.memory_space in (None, gpu_core.SMEM)
bm.block_aval.memory_space in (None, gpu_core.SMEM)
for bm in block_mappings
],
[grid_mapping.num_inputs],
@@ -290,13 +290,9 @@
in_block_mappings, out_block_mappings = util.split_list(
block_mappings, [grid_mapping.num_inputs]
)
# TODO(apaszke): We can shrink allocation if max_concurrent_steps is more than the actual number of steps.
# We allocate the fully transformed shapes here. All primitives have seen the
# inverse transformation stack and will understand how to handle it.
in_structs_smem = [
jax.ShapeDtypeStruct(
[max_concurrent_steps, *bm.transformed_block_aval.shape],
bm.transformed_block_aval.dtype,
[max_concurrent_steps, *bm.ref_aval.shape], bm.ref_aval.dtype
)
if in_smem
else None
@@ -316,9 +312,6 @@
)
out_structs_gmem = [*grid_mapping.out_shapes]
# TODO(justinfu): Implement output Memref transforms
for bm in block_mappings[grid_mapping.num_inputs :]:
if bm.transforms:
raise NotImplementedError("Output transforms are not supported")
out_structs_smem = [
jax.ShapeDtypeStruct([max_concurrent_steps, *bm.block_shape], s.dtype)
if in_smem
53 changes: 14 additions & 39 deletions jax/_src/pallas/mosaic_gpu/primitives.py
@@ -119,7 +119,7 @@ class _WGMMAPipelineEffect(effects.Effect):
wgmma_ref_p = jax_core.Primitive("wgmma_ref")
wgmma_ref_p.multiple_results = True

def wgmma(acc, a, b, *, swizzle: int = 128):
def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128):
"""Asynchronous warp group matmul.
The sm90 wgmma instruction, essentially acc[...] += a @ b. Requires
@@ -129,49 +129,24 @@ def wgmma(acc, a, b, *, swizzle: int = 128):
acc: The accumulator register.
a: The left hand side operand.
b: The right hand side operand.
transpose: Whether to transpose b.
n_tile: The number of tiles to use.
swizzle: The swizzle pattern.
"""
if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef):
raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc}")

# TODO(apaszke): Make swizzling another transform and read it from the refs.
if not isinstance(a, pallas_core.TransformedRef):
raise ValueError("WGMMA inputs must be tiled references.")

m, n = acc.shape
m2, k = a.shape
k2, n2 = b.shape

if m != m2 or n != n2 or k != k2:
raise ValueError(
f"Incompatible shapes for matrix multiplication: lhs={a.shape},"
f" rhs={b.shape=}, acc={acc.shape}"
)

if (dtype := a.dtype) != b.dtype:
raise ValueError(f"Mixed input dtypes for matrix multiplication unsupported: lhs={a.dtype}, rhs={b.dtype}")
if not isinstance(a, pallas_core.TransformedRef):
raise ValueError("WGMMA lhs must be a tiled reference.")
if not isinstance(b, pallas_core.TransformedRef):
raise ValueError("WGMMA rhs must be a tiled reference.")

elems_128b = swizzle // dtype.itemsize
if a.transforms != (gpu_core.UntileRef((64, elems_128b)),):
raise ValueError(
f"WGMMA lhs must be tiled with 64x{elems_128b} tiles for element type"
f" {dtype}."
)
rhs_transpose_transform = gpu_core.TransposeRef((1, 0, 2, 3))
rhs_tiling = gpu_core.UntileRef((elems_128b, elems_128b))
if not (
rhs_transpose := (b.transforms == (rhs_transpose_transform, rhs_tiling))
) and not (b.transforms == (rhs_tiling,)):
raise ValueError(
f"WGMMA rhs must be tiled with {elems_128b}x{elems_128b} tiles for"
f" element type {dtype} (and optionally transposed)."
)

return wgmma_ref_p.bind(acc, a.ref, b.ref, swizzle=swizzle, rhs_transpose=rhs_transpose)
ma, ka, tma, tka = a.shape
kb, nb, tkb, tnb = b.shape
mc, nc = acc.shape

if rhs_transpose:
kb, nb, tkb, tnb = nb, kb, tnb, tkb

if tma * ma != mc or nb * tnb != nc or ka != kb or tka != tkb:
raise ValueError(f"Incompatible shapes: {a.shape=}, {b.shape=}, {acc.shape=}, {rhs_transpose=}")

return wgmma_ref_p.bind(acc, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose)
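
For reference, a standalone sketch of the shape-compatibility check performed by `wgmma` above, using illustrative 64x64 tiles (example values only, not from the diff):

ma, ka, tma, tka = 2, 4, 64, 64      # a.shape: (m tiles, k tiles, tile_m, tile_k)
kb, nb, tkb, tnb = 4, 3, 64, 64      # b.shape: (k tiles, n tiles, tile_k, tile_n)
mc, nc = 2 * 64, 3 * 64              # acc.shape: (m, n)
assert tma * ma == mc and nb * tnb == nc and ka == kb and tka == tkb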


@wgmma_ref_p.def_effectful_abstract_eval
7 changes: 0 additions & 7 deletions jax/_src/state/indexing.py
@@ -256,10 +256,3 @@ def get_indexer_shape(self) -> tuple[int | Array, ...]:
# In NDIndexers, the int_indexer_shape is *always* at the front of the
# result.
return (*self.int_indexer_shape, *slice_shape)

def transform_shape(self, shape: None | tuple[int | Array, ...]) -> None | tuple[int | Array, ...]:
del shape # Unused
return self.get_indexer_shape()

def transform_dtype(self, dtype):
return dtype
