[Pallas TPU] Add support for passing in and returning semaphores
This change enables writing async ops using Pallas. However, there are *extremely sharp edges* when using this API. Please read the design note here: https://jax.readthedocs.io/en/latest/pallas/async_note.html.

Followup CLs will investigate safer APIs for writing async ops.

PiperOrigin-RevId: 676243335
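
To make the new capability concrete, below is a minimal sketch of the kind of async op this enables, loosely following the copy_start pattern in the linked design note. The exact spellings (pltpu.SemaphoreType.DMA(()) as an out_shape entry, an output BlockSpec placed in TPUMemorySpace.SEMAPHORE) are assumptions drawn from that note and from this diff, not a stable API:

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def copy_start_kernel(x_hbm_ref, o_hbm_ref, sem):
  # Start the DMA but do not wait on it; the semaphore escapes the kernel,
  # which is exactly the sharp-edged behavior the design note discusses.
  pltpu.make_async_copy(x_hbm_ref, o_hbm_ref, sem).start()


def copy_start(x: jax.Array):
  # Returns the (not yet ready) output buffer together with its DMA semaphore.
  return pl.pallas_call(
      copy_start_kernel,
      in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)],
      out_specs=(
          pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
          pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SEMAPHORE),
      ),
      out_shape=(
          jax.ShapeDtypeStruct(x.shape, x.dtype),
          pltpu.SemaphoreType.DMA(()),  # hypothetical: semaphore as an output
      ),
  )(x)

A matching copy_done would take the buffer and semaphore back in, wait on the same async copy, and alias its output to the buffer; those aliasing and scheduling caveats are the sharp edges the note documents.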
sharadmv authored and Google-ML-Automation committed Sep 19, 2024
1 parent ba06bd5 commit 9d2e9c6
Showing 13 changed files with 1,016 additions and 60 deletions.
107 changes: 101 additions & 6 deletions jax/_src/pallas/core.py
@@ -31,6 +31,7 @@
from jax._src import config
from jax._src import core as jax_core
from jax._src import deprecations
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import state
@@ -114,21 +115,113 @@ def from_pallas_call(pallas_call_name: str | None,
" ".join(src_info_parts[1:]))


# Pytrees of jax.ShapeDtypeStruct
ShapeDtypeStructTree = tuple[jax.ShapeDtypeStruct, ...]

split_list = util.split_list

map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip


class ShapedArrayWithMemorySpace(jax_core.ShapedArray):
  __slots__ = ["memory_space"]

  def __init__(self, shape, dtype, weak_type=False, sharding=None,
               memory_space=None):
    super().__init__(shape, dtype, weak_type=weak_type, sharding=sharding)
    self.memory_space = memory_space

  def __eq__(self, other):
    return super().__eq__(other) and self.memory_space == other.memory_space

  def __hash__(self):
    return hash((
        self.shape,
        self.dtype,
        self.weak_type,
        getattr(self, "sharding", None),
        self.memory_space,
    ))

  def at_least_vspace(self):
    """Vector space method needed for AD."""
    raise NotImplementedError

  def join(self, other):
    raise NotImplementedError

  def str_short(self, short_dtypes=False):
    dt_str = (
        jax_core._short_dtype_name(self.dtype)
        if short_dtypes
        else self.dtype.name
    )
    dt_str = dt_str.replace("void", "float0")
    shapestr = ",".join(map(str, self.shape))
    if hasattr(self, "sharding"):
      sharding_str = f"{dt_str}[{shapestr}]({self.sharding})"
    else:
      sharding_str = ""
    memoryspace_str = (
        "" if self.memory_space is None else f"{self.memory_space}>"
    )
    return f"{dt_str}{memoryspace_str}[{shapestr}]{sharding_str}"

  def update(
      self,
      shape=None,
      dtype=None,
      weak_type=None,
      sharding=None,
      memory_space=None,
  ):
    if shape is None:
      shape = self.shape
    if dtype is None:
      dtype = self.dtype
    if weak_type is None:
      weak_type = self.weak_type
    if sharding is None:
      sharding = getattr(self, "sharding", None)
    if memory_space is None:
      memory_space = self.memory_space
    return ShapedArrayWithMemorySpace(
        shape, dtype, weak_type, sharding=sharding, memory_space=memory_space
    )
mlir.ir_type_handlers[ShapedArrayWithMemorySpace] = mlir._array_ir_types


@dataclasses.dataclass(frozen=True)
class MemoryRef:
  """Like jax.ShapeDtypeStruct but with memory spaces."""
  shape: tuple[int, ...]
  dtype: jnp.dtype
  # TODO(b/368122763): Unify memory space types across backends
  memory_space: Any

  def get_array_aval(self) -> jax_core.ShapedArray:
    dtype = self.dtype
    if not isinstance(dtype, (jnp.dtype, dtypes.ExtendedDType)):
      dtype = jnp.dtype(dtype)
    return ShapedArrayWithMemorySpace(
        self.shape, dtype, memory_space=self.memory_space
    )

  def get_ref_aval(self) -> AbstractMemoryRef:
    return AbstractMemoryRef(
        ShapedArrayWithMemorySpace(self.shape, self.dtype), self.memory_space)


class AbstractMemoryRef(state.AbstractRef):
  __slots__ = ["inner_aval", "memory_space"]

  inner_aval: jax_core.ShapedArray

  def __init__(self, inner_aval: jax_core.ShapedArray, memory_space: Any):
    if isinstance(inner_aval, ShapedArrayWithMemorySpace):
      if inner_aval.memory_space is not None:
        assert inner_aval.memory_space == memory_space, (
            f"Mismatched memory spaces: {inner_aval.memory_space=},"
            f" {memory_space=}"
        )
    self.inner_aval = inner_aval
    self.memory_space = memory_space

@@ -158,7 +251,7 @@ def __hash__(self):


class MemorySpace(enum.Enum):
""" Logical, device-agnostic memory spaces.
"""Logical, device-agnostic memory spaces.
Each memory space will be translated to a device-specific memory
type during lowering.
@@ -731,7 +824,9 @@ def _convert_block_spec_to_block_mapping(


class ScratchShape(Protocol):
  def get_aval(self) -> jax_core.AbstractValue:
  def get_array_aval(self) -> jax_core.AbstractValue:
    ...
  def get_ref_aval(self) -> state.AbstractRef:
    ...


@@ -833,7 +928,7 @@ def get_grid_mapping(
  if grid_spec.scratch_shapes:
    flat_scratch_shapes, scratch_tree = tree_util.tree_flatten(
        grid_spec.scratch_shapes)
    flat_scratch_avals = map(lambda s: s.get_aval(), flat_scratch_shapes)
    flat_scratch_avals = map(lambda s: s.get_ref_aval(), flat_scratch_shapes)
    num_flat_scratch_operands = len(flat_scratch_avals)
    jaxpr_scratch_avals = tree_util.tree_unflatten(
        scratch_tree, flat_scratch_avals)
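
Taken together, the core.py changes split the old get_aval into two views of a MemoryRef: get_array_aval() describes the value crossing the pallas_call boundary (a ShapedArrayWithMemorySpace), while get_ref_aval() describes the mutable Ref handed to the kernel body (an AbstractMemoryRef). A tiny illustrative sketch; the internal import paths are assumptions about the post-commit layout, not public API:

import jax.numpy as jnp

# Internal modules, used here only to illustrate the two avals.
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic import core as tpu_core

ref = pallas_core.MemoryRef(
    shape=(8, 128), dtype=jnp.float32,
    memory_space=tpu_core.TPUMemorySpace.VMEM)

print(ref.get_array_aval())  # ShapedArrayWithMemorySpace tagged with VMEM
print(ref.get_ref_aval())    # AbstractMemoryRef wrapping the same shape/dtype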
27 changes: 11 additions & 16 deletions jax/_src/pallas/mosaic/core.py
@@ -90,7 +90,7 @@ def __str__(self) -> str:

  def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):
    # A convenience function for constructing MemoryRef types.
    return MemoryRef(shape, dtype, self)
    return pallas_core.MemoryRef(shape, dtype, self)

class semaphore_dtype(dtypes.extended): pass
class semaphore(semaphore_dtype): pass
@@ -102,6 +102,10 @@ class AbstractSemaphoreTyRules:
  def pallas_interpret_element_aval(_) -> jax_core.ShapedArray:
    return jax_core.ShapedArray((), pallas_core.SEMAPHORE_INTERPRET_DTYPE)

  @staticmethod
  def physical_element_aval(_) -> jax_core.ShapedArray:
    return jax_core.ShapedArray((), jnp.int32)

class AbstractSemaphoreTy(dtypes.ExtendedDType):
  name: str
  _rules = AbstractSemaphoreTyRules
@@ -144,10 +148,13 @@ def __call__(self, shape: tuple[int, ...]):
      dtype = SemaphoreTy()
    if pallas_core.is_interpret_mode():
      dtype = pallas_core.SEMAPHORE_INTERPRET_DTYPE
    return MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE)
    return pallas_core.MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE)

  def get_array_aval(self) -> pallas_core.ShapedArrayWithMemorySpace:
    return self(()).get_array_aval()

  def get_aval(self) -> AbstractMemoryRef:
    return self(()).get_aval()
  def get_ref_aval(self) -> AbstractMemoryRef:
    return self(()).get_ref_aval()

@dataclasses.dataclass(frozen=True)
class AbstractSemaphore(jax_core.AbstractValue):
@@ -163,18 +170,6 @@ def join(self, other):
jax_core.raise_to_shaped_mappings[AbstractSemaphore] = lambda aval, _: aval


@dataclasses.dataclass(frozen=True)
class MemoryRef:
  """Like jax.ShapeDtypeStruct but with memory spaces."""
  shape: tuple[int, ...]
  dtype: jnp.dtype
  memory_space: TPUMemorySpace = TPUMemorySpace.ANY

  def get_aval(self) -> AbstractMemoryRef:
    return AbstractMemoryRef(
        jax_core.ShapedArray(self.shape, self.dtype), self.memory_space)


@dataclasses.dataclass(init=False, kw_only=True, unsafe_hash=True)
class PrefetchScalarGridSpec(pallas_core.GridSpec):
  num_scalar_prefetch: int
46 changes: 31 additions & 15 deletions jax/_src/pallas/mosaic/lowering.py
@@ -424,24 +424,23 @@ class MeshInfo:
  axis_names: list[str]
  mesh_strides: tuple[int, ...]

def lower_jaxpr_to_module(

def _check_block_mappings(
    block_mappings: tuple[pallas_core.BlockMapping, ...],
    lowering_context: mlir.LoweringRuleContext,
    ctx: ir.Context,
    grid_mapping: pallas_core.GridMapping,
    jaxpr: jax_core.Jaxpr,
    *,
    dimension_semantics: tuple[str | None, ...] | None,
    name_and_src_info: pallas_core.NameAndSrcInfo,
    mesh: mesh_lib.Mesh | None = None,
    for_verification: bool = False,
) -> tuple[Module, tuple[Any, ...]]:
  for bm in grid_mapping.block_mappings:
) -> None:
  del lowering_context  # originally needed for forward compat
  for bm in block_mappings:
    rank = len(bm.block_shape)
    # TODO(necula): add tests for SMEM blocks with trivial windowing
    # We support scalars too
    if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SMEM and
        bm.has_trivial_window()):
      continue
    if bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SEMAPHORE:
      continue

    def err_details():
      return (f"Block spec for {bm.origin} in pallas_call {name_and_src_info} "
              "has block shape "
@@ -482,11 +481,28 @@ def err_details():

    if not evenly_divisible:
      raise ValueError(
          "The Pallas TPU lowering currently requires that the last two "
          "dimensions of your block shape are divisible by 8 and 128 "
          "respectively, or be equal to the respective dimensions of the "
          "overall array. "
          + err_details())
          "The Pallas TPU lowering currently requires that the last two "
          "dimensions of your block shape are divisible by 8 and 128 "
          "respectively, or be equal to the respective dimensions of the "
          "overall array. "
          + err_details()
      )


def lower_jaxpr_to_module(
    lowering_context: mlir.LoweringRuleContext,
    ctx: ir.Context,
    grid_mapping: pallas_core.GridMapping,
    jaxpr: jax_core.Jaxpr,
    *,
    dimension_semantics: tuple[str | None, ...] | None,
    name_and_src_info: pallas_core.NameAndSrcInfo,
    mesh: mesh_lib.Mesh | None = None,
    for_verification: bool = False,
) -> tuple[Module, tuple[Any, ...]]:
  # Verify that we have legal block mappings to catch errors early.
  _check_block_mappings(grid_mapping.block_mappings, lowering_context,
                        name_and_src_info)

  mosaic_grid_mapping = MosaicGridMapping(
      jaxpr, grid_mapping, dimension_semantics, mesh)
51 changes: 45 additions & 6 deletions jax/_src/pallas/mosaic/pallas_call_registration.py
@@ -30,19 +30,24 @@
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.pallas import core
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import lowering
from jax._src.pallas.mosaic import verification
from jax._src import tpu_custom_call
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
from jax.experimental.pallas import tpu as pltpu

def _maybe_cast_to_int(x: jax.Array | jax_core.ShapedArray):
def _maybe_cast_to_int(x: jax.Array | jax_core.AbstractValue):
"""Casts boolean values to integers.
We perform this cast because Mosaic does not directly support bool values
for Memrefs. Instead, we load bools as integers and cast them to bools
after loading from a memref inside of the kernel.
"""
assert isinstance(
x, (jax.Array, jax_core.ShapedArray, jax_core.DShapedArray)
), type(x)
if isinstance(x, jax.Array):
if dtypes.issubdtype(x.dtype, jax.numpy.bool_):
return x.astype(lowering.BOOL_MEMREF_TYPE)
@@ -63,6 +68,41 @@ def _maybe_cast_to_int(x: jax.Array | jax_core.ShapedArray):
)


def _get_memory_space_from_aval(
    out_aval: jax_core.AbstractValue,
) -> tpu_custom_call.MemorySpace | None:
  if not isinstance(out_aval, jax_core.ShapedArray):
    raise ValueError('Memory spaces not defined for non-ShapedArrays')
  if not isinstance(out_aval, core.ShapedArrayWithMemorySpace):
    # If we are passed a regular old ShapedArray, we don't constrain the
    # memory space
    return None
  # If we are passed an aval with an explicit memory space tag, we use it
  # to constrain the memory space.
  match out_aval.memory_space:
    case None:
      return None
    case tpu_core.TPUMemorySpace.ANY:
      return None
    case tpu_core.TPUMemorySpace.VMEM:
      return tpu_custom_call.MemorySpace.VMEM
    case tpu_core.TPUMemorySpace.SEMAPHORE:
      return tpu_custom_call.MemorySpace.SEMAPHORE_MEM
  return None


def _get_memory_spaces_from_avals(
    out_avals: tuple[jax_core.AbstractValue, ...],
) -> tuple[tpu_custom_call.MemorySpace | None, ...] | None:
  output_memory_spaces = None
  if any(
      isinstance(out_aval, core.ShapedArrayWithMemorySpace)
      for out_aval in out_avals
  ):
    output_memory_spaces = tuple(map(_get_memory_space_from_aval, out_avals))
  return output_memory_spaces


def pallas_call_tpu_lowering_rule(
    ctx: mlir.LoweringRuleContext,
    *in_nodes,
@@ -74,6 +114,7 @@ def pallas_call_tpu_lowering_rule(
    interpret: bool,
    compiler_params: dict[str, Any],
    cost_estimate: core.CostEstimate | None,
    out_avals: tuple[jax_core.AbstractValue, ...],
):
  """Lowers a pallas_call to a Mosaic TPU custom call."""
  del interpret
@@ -129,9 +170,6 @@ def lower_module(for_verification: bool):
      (a[0] + num_dyn_bounds + num_extra_args, a[1])
      for a in input_output_aliases
  )
  out_avals = [jax_core.ShapedArray(bm.array_shape_dtype.shape,
                                    bm.array_shape_dtype.dtype)
               for bm in grid_mapping.block_mappings_output]

  if promela_dump_path := _DUMP_PROMELA_TO.value:
    num_devices = 1 if mesh is None else mesh.devices.size
@@ -174,14 +212,15 @@ def lower_module(for_verification: bool):
  def _maybe_cast_inputs(*args):
    args = [_maybe_cast_to_int(x) for x in args]
    return args
  kernel_in_avals = [_maybe_cast_to_int(x) for x in ctx.avals_in]  # type: ignore
  kernel_in_avals = [_maybe_cast_to_int(x) for x in ctx.avals_in]
  kernel_out_avals = [_maybe_cast_to_int(x) for x in out_avals]
  cast_ctx = ctx.replace(avals_out=kernel_in_avals)
  in_nodes = mlir.lower_fun(_maybe_cast_inputs)(cast_ctx, *in_nodes)

  # Dynamic grid bounds have to go at the front.
  dynamic_grid_args, args = in_nodes[:num_dyn_bounds], in_nodes[num_dyn_bounds:]
  kernel_ctx = ctx.replace(avals_in=kernel_in_avals, avals_out=kernel_out_avals)
  output_memory_spaces = _get_memory_spaces_from_avals(out_avals)
  if cost_estimate is not None:
    mosaic_cost_estimate = pltpu.CostEstimate(
        flops=cost_estimate.flops,
@@ -208,7 +247,7 @@ def _maybe_cast_inputs(*args):
      device_type=mosaic_params.get("device_type"),
      internal_scratch_in_bytes=mosaic_params.get("internal_scratch_in_bytes"),
      collective_id=mosaic_params.get("collective_id", None),
      output_memory_spaces=None,  # TODO(apaszke,sharadmv): Implement this.
      output_memory_spaces=output_memory_spaces,
  )
  _maybe_cast_to_bool = lambda x, aval: x.astype(
      jax.numpy.bool_) if aval.dtype == jax.numpy.bool_ else x
2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic/pipeline.py
@@ -40,7 +40,7 @@
SMEM = tpu_core.TPUMemorySpace.SMEM
VMEM = tpu_core.TPUMemorySpace.VMEM
DMA = tpu_core.SemaphoreType.DMA
REF = tpu_core.MemoryRef
REF = pallas_core.MemoryRef
SemaphoreType = tpu_core.SemaphoreType
SemaphoreTuple = jax.Array
ArrayRef = Union[REF, jax.Array]