[pallas::mosaic_gpu] Turn the accumulator into a reference
* Changes the accumulator into a reference (see the usage sketch after the changed-file summary below).
* Creates a discharged flavor of the wgmma op.
* The run_scoped lowering discharges the input jaxpr.
* Dereferencing the accumulator ref is done by a new primitive that behaves as expected when discharged.
* The deref primitive implies flushing the wgmma pipeline.
* run_scoped does not allow references to be leaked.

PiperOrigin-RevId: 676801466
cperivol authored and Google-ML-Automation committed Sep 26, 2024
1 parent f6fdfb4 commit e979541
Showing 5 changed files with 181 additions and 95 deletions.
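
Taken together, the changes below replace the explicit accumulator value with a reference that is allocated by pl.run_scoped and read back with plain indexing. The following is a minimal kernel sketch of the new API, distilled from the updated test at the end of this diff (plgpu, pl, and jnp are the aliases used there; the shape and dtype are illustrative, not part of the commit):

def kernel(a_ref, b_ref, o_ref):
  def scope(acc_ref):
    # Issue the asynchronous matmul against the accumulator reference.
    plgpu.wgmma(acc_ref, a_ref, b_ref)
    # Plain indexing dereferences the accumulator; it lowers to the new
    # deref primitive, which also flushes the wgmma pipeline.
    return acc_ref[...]

  # run_scoped allocates the accumulator ref and rejects any attempt to
  # leak it out of the scope.
  o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 128), jnp.float32))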
3 changes: 1 addition & 2 deletions jax/_src/pallas/mosaic_gpu/__init__.py
@@ -18,14 +18,13 @@
from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec
from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams
from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC
from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem
from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import wait_barrier
from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import zero_accumulator
from jax._src.pallas.mosaic_gpu.primitives import wgmma
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait

GMEM = GPUMemorySpace.GMEM
SMEM = GPUMemorySpace.SMEM
REGS = GPUMemorySpace.REGS
50 changes: 50 additions & 0 deletions jax/_src/pallas/mosaic_gpu/core.py
@@ -203,3 +203,53 @@ def get_ref_aval(self) -> AbstractMemoryRef:
[self.num_barriers], BarrierType(self.num_arrivals)
)
return AbstractMemoryRef(aval, SMEM)


@dataclasses.dataclass(frozen=True)
class WGMMAAccumulatorRef:
shape: tuple[int, int]
dtype: jnp.dtype = jnp.float32

def get_ref_aval(self) -> AbstractMemoryRef:
return WGMMAAbstractAccumulatorRef(
jax_core.ShapedArray(shape=self.shape, dtype=self.dtype), GPUMemorySpace.REGS
)


def _is_trivial_index(idx):
_is_deref1 = lambda i: i is Ellipsis or i == slice(None)
if isinstance(idx, tuple):
return all(_is_deref1(i) for i in idx)

return _is_deref1(idx)

class WGMMAAbstractAccumulatorRef(AbstractMemoryRef):
__slots__ = ["inner_aval", "memory_space"]

def __repr__(self) -> str:
return f'Accumulator{{{self.inner_aval.str_short()}}}'

def join(self, other):
return _as_accum(super().join(other))

def update(self, inner_aval=None, memory_space=None):
return _as_accum(super().update(inner_aval=inner_aval, memory_space=memory_space))

def at_least_vspace(self):
return _as_accum(super().at_least_vspace())

def _getitem(self, tracer, idx):
if not _is_trivial_index(idx):
raise NotImplementedError(f"Can only dereference accumulators, not slice ({idx=}).")
from jax._src.pallas.mosaic_gpu.primitives import wgmma_accumulator_deref # pytype: disable=import-error
return wgmma_accumulator_deref(tracer)

def _as_accum(ref) -> WGMMAAbstractAccumulatorRef:
return WGMMAAbstractAccumulatorRef(
inner_aval=ref.inner_aval,
memory_space=ref.memory_space, # pytype: disable=attribute-error
)

def _ref_raise_to_shaped(ref_aval, weak_type):
return _as_accum(jax_core.raise_to_shaped_mappings[AbstractMemoryRef](ref_aval, weak_type))
jax_core.raise_to_shaped_mappings[WGMMAAbstractAccumulatorRef] = _ref_raise_to_shaped
66 changes: 57 additions & 9 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -28,6 +28,7 @@
from jax._src import core as jax_core
from jax._src import pjit
from jax._src import util
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith as arith_dialect
@@ -37,6 +38,7 @@
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.mosaic_gpu import core as gpu_core
from jax._src.state import discharge
from jax._src.state import primitives as sp
import jax.experimental.mosaic.gpu as mgpu
from jax.experimental.mosaic.gpu import core as mgpu_core
@@ -615,6 +617,10 @@ def _swap_lowering_rule(
del tree # Unused.
if indexers:
raise NotImplementedError("No support for indexers yet")
if not isinstance(value, mgpu.FragmentedArray):
raise TypeError(f"Can only store arrays (got {value}).")
if not isinstance(x_smem, ir.Value) or not ir.MemRefType.isinstance(x_smem.type):
raise TypeError(f"Can only store to references (got {x_smem}).")
x_aval, _ = ctx.avals_in
old_value = mgpu.FragmentedArray.load_strided(
x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype)
@@ -735,15 +741,57 @@ def _(val, idx):
def _run_scoped_lowering_rule(
ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr
):
in_avals = [v.aval.inner_aval for v in jaxpr.invars]
bytes_allocated, input_refs = ctx.module_ctx.scratch_view([
jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype)
for aval in in_avals
])
outs = lower_jaxpr_to_mosaic_gpu(
ctx.module_ctx, ctx.launch_ctx, jaxpr, input_refs, consts
)
ctx.module_ctx.stack_free_smem(bytes_allocated)
input_refs = []
bytes_allocated = 0
should_discharge = []
for a in jaxpr.invars:
a = a.aval
if isinstance(a, gpu_core.WGMMAAbstractAccumulatorRef):
mlir_dtype = mlir.dtype_to_ir_type(a.dtype)
input_refs.append(mgpu.WGMMAAccumulator.zero(*a.shape, mlir_dtype))
should_discharge.append(True)
elif a.memory_space == gpu_core.SMEM:
ref_bytes, [input_ref] = ctx.module_ctx.scratch_view(
[jax.ShapeDtypeStruct(shape=a.shape, dtype=a.dtype)]
)
bytes_allocated += ref_bytes
input_refs.append(input_ref)
should_discharge.append(False)
else:
raise ValueError(f"Can't convert to ref: {a}")

if any(should_discharge):
# We convert consts to args, because we only have ir.Values and
# not JAX values during lowering. discharge_state() produces JAX
# values for the arguments but expects them to be provided for the
# consts. We also don't want to wrap the values in refs.
no_const_jaxpr = pe.convert_constvars_jaxpr(jaxpr)
should_discharge = [False] * len(consts) + should_discharge
discharged_jaxpr, _ = discharge.discharge_state(no_const_jaxpr, (), should_discharge=should_discharge)
new_input_vals = consts + tuple(input_refs)
outs = lower_jaxpr_to_mosaic_gpu(
ctx.module_ctx, ctx.launch_ctx, discharged_jaxpr, new_input_vals, ()
)
# Discharge appends to the output the refs that got discharged.
outs = outs[:-sum(should_discharge)]
else:
outs = lower_jaxpr_to_mosaic_gpu(
ctx.module_ctx, ctx.launch_ctx, jaxpr, input_refs, consts
)

for o in outs:
# This is definitely one of the accumulators we produced. Each
# run_scoped call is responsible for dereferencing its own
# accumulators.
if isinstance(o, mgpu.WGMMAAccumulator) or (
isinstance(o, ir.Value) and ir.MemRefType.isinstance(o.type)
):
raise ValueError(f"No references are allowed to escape a scope. (got {o})")

assert len(outs) == len(jaxpr.outvars), (jaxpr, outs)
if bytes_allocated:
ctx.module_ctx.stack_free_smem(bytes_allocated)

return outs


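For reference, here is the discharge step used by the new run_scoped lowering, shown in isolation. This is a sketch that relies only on the calls already present in the hunk above (pe.convert_constvars_jaxpr and discharge.discharge_state); it is not code from the commit:

from jax._src.interpreters import partial_eval as pe
from jax._src.state import discharge

def discharge_scoped_body(jaxpr, consts, should_discharge):
  # Constants become leading arguments, because during lowering we only
  # hold ir.Values and do not want to wrap them in refs.
  no_const_jaxpr = pe.convert_constvars_jaxpr(jaxpr)
  # Consts are plain values, never refs, so they are never discharged.
  flags = [False] * len(consts) + list(should_discharge)
  discharged_jaxpr, _ = discharge.discharge_state(
      no_const_jaxpr, (), should_discharge=flags)
  # The discharged jaxpr appends one output per discharged ref (its final
  # value); the lowering drops those extra outputs afterwards.
  return discharged_jaxpr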
147 changes: 68 additions & 79 deletions jax/_src/pallas/mosaic_gpu/primitives.py
@@ -16,12 +16,10 @@

from __future__ import annotations

import dataclasses

from jax._src import core as jax_core
from jax._src import effects
from jax._src import state
from jax._src.interpreters import mlir
from jax._src.state import discharge
from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic_gpu import core as gpu_core
@@ -118,62 +116,9 @@ class _WGMMAPipelineEffect(effects.Effect):
_wgmma_pipeline_effect = _WGMMAPipelineEffect()
effects.control_flow_allowed_effects.add_type(_WGMMAPipelineEffect)


# Not a shaped array to avoid unexpected operations.
class WGMMAAbstractAccumulator(jax_core.AbstractValue):
__slots__ = ['shape', 'dtype']

def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype

def __eq__(self, other):
return (type(self) is type(other)
and self.dtype == other.dtype and self.shape == other.shape)

def __hash__(self):
return hash((self.shape, self.dtype))

def update(self, shape=None, dtype=None):
if shape is None:
shape = self.shape
if dtype is None:
dtype = self.dtype
return WGMMAAbstractAccumulator(shape, dtype)

def str_short(self, short_dtypes=False) -> str:
del short_dtypes
shapestr = ",".join(map(str, self.shape))
return f"Accumulator{{{self.dtype.name}}}[{shapestr}]"

@dataclasses.dataclass(frozen=True)
class WGMMAAccumulator:
inner_aval: WGMMAAbstractAccumulator

shape = property(lambda self: self.inner_aval.shape)
dtype = property(lambda self: self.inner_aval.dtype)

def as_array(self) -> jax_core.ShapedArray:
return acc_to_shaped_array_p.bind(self.inner_aval)


jax_core.raise_to_shaped_mappings[WGMMAAbstractAccumulator] = lambda aval, _: aval

acc_to_shaped_array_p = jax_core.Primitive("acc_to_shaped_array")

@acc_to_shaped_array_p.def_abstract_eval
def _acc_to_shaped_array_abstract_eval(acc) -> jax_core.ShapedArray:
return jax_core.ShapedArray(shape=acc.shape, dtype=acc.dtype)


@lowering.register_lowering_rule(acc_to_shaped_array_p)
def _acc_to_shaped_array_lowering_rule(
ctx: lowering.LoweringRuleContext, acc
):
del ctx
return acc.value

wgmma_p = jax_core.Primitive("wgmma")
# WGMMA on an accumulator reference
wgmma_ref_p = jax_core.Primitive("wgmma_ref")
wgmma_ref_p.multiple_results = True

def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128):
"""Asynchronous warp group matmul.
@@ -189,8 +134,8 @@ def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128):
n_tile: The number of tiles to use.
swizzle: The swizzle pattern.
"""
if not isinstance(acc, WGMMAAccumulator):
raise TypeError(acc)
if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef):
raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc}")

rhs_transpose = (
(jnp.dtype(b.dtype).itemsize == 2)
@@ -208,18 +153,43 @@ def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128):
if tma * ma != mc or nb * tnb != nc or ka != kb or tka != tkb:
raise ValueError(f"Incompatible shapes: {a.shape=}, {b.shape=}, {acc.shape=}, {rhs_transpose=}")

outval = wgmma_p.bind(acc.inner_aval, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose)
return WGMMAAccumulator(outval)
return wgmma_ref_p.bind(acc, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose)

@wgmma_p.def_effectful_abstract_eval
def _wgmma_effectful_abstract_eval(acc, *args, **kwargs):
del args, kwargs
return acc, {

@wgmma_ref_p.def_effectful_abstract_eval
def _wgmma_ref_effectful_abstract_eval(acc, *args, **kwargs):
del acc, args, kwargs
return [], {
_wgmma_pipeline_effect,
state.WriteEffect(0),
state.ReadEffect(0),
state.ReadEffect(1),
state.ReadEffect(2),
}


@discharge.register_discharge_rule(wgmma_ref_p)
def _wgmma_ref_discharge_rule(
in_avals, out_avals,
acc,
a,
b,
swizzle,
rhs_transpose,
):
del in_avals, out_avals
return (
wgmma_p.bind(
acc, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose
),
None,
None,
), []


# Functional WGMMA, returns a shaped array. Internal.
wgmma_p = jax_core.Primitive("wgmma")

@lowering.register_lowering_rule(wgmma_p)
def _wgmma_lowering_rule(
ctx: lowering.LoweringRuleContext,
@@ -242,6 +212,15 @@ def _wgmma_lowering_rule(
nvvm_dialect.wgmma_commit_group_sync_aligned()
return new_acc

@wgmma_p.def_effectful_abstract_eval
def _wgmma_effectful_abstract_eval(acc, *args, **kwargs):
del args, kwargs
return acc, {
_wgmma_pipeline_effect,
state.ReadEffect(1),
state.ReadEffect(2),
}

wgmma_wait_p = jax_core.Primitive("wgmma_wait")
wgmma_wait_p.multiple_results = True

@@ -260,19 +239,29 @@ def _wgmma_wait_lowering_rule(ctx: lowering.LoweringRuleContext, allow_groups):
nvvm_dialect.wgmma_wait_group_sync_aligned(allow_groups)
return ()

zero_accumulator_p = jax_core.Primitive("zero_accumulator")
def zero_accumulator(shape, dtype):
return WGMMAAccumulator(zero_accumulator_p.bind(shape=shape, dtype=dtype))
wgmma_accumulator_deref_p = jax_core.Primitive("wgmma_accumulator_deref_p")
def wgmma_accumulator_deref(acc):
"""Dereferences an accumulator register."""

@zero_accumulator_p.def_abstract_eval
def _zero_accumulator_abstract_eval(shape, dtype):
return WGMMAAbstractAccumulator(shape=shape, dtype=dtype)
if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef):
raise TypeError(f"acc must be a WGMMAAccumulatorAbstractRef, got {acc.aval=}")

return wgmma_accumulator_deref_p.bind(acc)

@lowering.register_lowering_rule(zero_accumulator_p)
def _zero_accumulator_lowering_rule(
ctx: lowering.LoweringRuleContext, shape, dtype
):
@wgmma_accumulator_deref_p.def_effectful_abstract_eval
def _wgmma_accumulator_deref_abstract_eval(acc):
# Dereferencing implies flushing so we have a wgmma pipeline effect.
ret = acc.inner_aval if isinstance(acc, gpu_core.WGMMAAbstractAccumulatorRef) else acc
assert isinstance(ret, jax_core.ShapedArray), acc
return ret, {_wgmma_pipeline_effect}

@discharge.register_discharge_rule(wgmma_accumulator_deref_p)
def _wgmma_accumulator_deref_discharge_rule(in_avals, out_avals, acc):
del in_avals, out_avals
return (None,), wgmma_accumulator_deref_p.bind(acc)

@lowering.register_lowering_rule(wgmma_accumulator_deref_p)
def _wgmma_accumulator_deref_lowering_rule(ctx: lowering.LoweringRuleContext, acc):
del ctx
m, n = shape
return mgpu.WGMMAAccumulator.zero(m=m, n=n, dtype=mlir.dtype_to_ir_type(jnp.dtype(dtype)))
nvvm_dialect.wgmma_wait_group_sync_aligned(0)
return acc.value
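
To summarize how the primitives defined above fit together (an informal sketch in jaxpr-like notation, not part of the commit): the user-facing wgmma() binds wgmma_ref_p on the accumulator reference; when run_scoped discharges the body, wgmma_ref_p is rewritten into the functional wgmma_p, which threads the accumulator value through explicitly; and acc_ref[...] binds wgmma_accumulator_deref_p, whose lowering waits on the wgmma pipeline before returning the value.

# Stateful form (what tracing produces inside a run_scoped body):
#   wgmma_ref acc_ref a b                 -- accumulates into acc_ref
#   x = wgmma_accumulator_deref acc_ref   -- flushes the pipeline, reads the value
#
# Discharged form (what the lowering compiles):
#   acc1 = wgmma acc0 a b                 -- pure: value in, value out
#   x = wgmma_accumulator_deref acc1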
10 changes: 5 additions & 5 deletions tests/pallas/mosaic_gpu_test.py
@@ -380,11 +380,11 @@ def test_wgmma(self, dtype):
swizzle = 128
elems_128b = swizzle // jnp.dtype(dtype).itemsize
def kernel(a_ref, b_ref, o_ref):
acc = plgpu.zero_accumulator((64, 128), jnp.float32)
acc = plgpu.wgmma(acc, a_ref, b_ref, rhs_transpose=rhs_transpose)
plgpu.wgmma_wait(0)
# TODO(cperivol): turn acc into a reference so we can reason about effects.
o_ref[...] = acc.as_array()
def scope(acc_ref):
plgpu.wgmma(acc_ref, a_ref, b_ref, rhs_transpose=rhs_transpose)
return acc_ref[...]

o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 128), jnp.float32))

key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(64, 128), dtype=dtype)
