From f8d783bcce5c3724f0dece1b388457e9b9a7ae43 Mon Sep 17 00:00:00 2001
From: Christos Perivolaropoulos
Date: Fri, 20 Sep 2024 05:00:19 -0700
Subject: [PATCH] [pallas::mosaic_gpu] Turn the accumulator into a reference

* Changes the accumulator into a reference.
* Creates a discharged flavor of the wgmma op.
* run_scoped lowering discharges the input jaxpr.
* Dereferencing the accumulator ref is done by a new primitive that
  behaves as expected when discharged.
* The deref primitive implies flushing the wgmma pipeline.
* run_scoped does not allow references to be leaked.

PiperOrigin-RevId: 676801466
---
 jax/_src/pallas/mosaic_gpu/__init__.py   |   3 +-
 jax/_src/pallas/mosaic_gpu/core.py       |  50 ++++++++
 jax/_src/pallas/mosaic_gpu/lowering.py   |  67 +++++++++--
 jax/_src/pallas/mosaic_gpu/primitives.py | 145 +++++++++++------------
 tests/pallas/mosaic_gpu_test.py          |  10 +-
 5 files changed, 182 insertions(+), 93 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/__init__.py b/jax/_src/pallas/mosaic_gpu/__init__.py
index ddf27361493a..bbada82ace82 100644
--- a/jax/_src/pallas/mosaic_gpu/__init__.py
+++ b/jax/_src/pallas/mosaic_gpu/__init__.py
@@ -18,14 +18,13 @@
 from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec
 from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams
 from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
+from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC
 from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem
 from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem
 from jax._src.pallas.mosaic_gpu.primitives import wait_barrier
 from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem
-from jax._src.pallas.mosaic_gpu.primitives import zero_accumulator
 from jax._src.pallas.mosaic_gpu.primitives import wgmma
 from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait
 
 GMEM = GPUMemorySpace.GMEM
 SMEM = GPUMemorySpace.SMEM
-REGS = GPUMemorySpace.REGS
diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index c4489226c860..df633619e6ee 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -203,3 +203,53 @@ def get_ref_aval(self) -> AbstractMemoryRef:
       [self.num_barriers], BarrierType(self.num_arrivals)
   )
   return AbstractMemoryRef(aval, SMEM)
+
+
+@dataclasses.dataclass(frozen=True)
+class WGMMAAccumulatorRef:
+  shape: tuple[int, int]
+  dtype: jnp.dtype = jnp.float32
+
+  def get_ref_aval(self) -> AbstractMemoryRef:
+    return WGMMAAbstractAccumulatorRef(
+        jax_core.ShapedArray(shape=self.shape, dtype=self.dtype), GPUMemorySpace.REGS
+    )
+
+
+def _is_trivial_index(idx):
+  _is_deref1 = lambda i: i is Ellipsis or i == slice(None)
+  if isinstance(idx, tuple):
+    return all(_is_deref1(i) for i in idx)
+
+  return _is_deref1(idx)
+
+class WGMMAAbstractAccumulatorRef(AbstractMemoryRef):
+  __slots__ = ["inner_aval", "memory_space"]
+
+  def __repr__(self) -> str:
+    return f'Accumulator{{{self.inner_aval.str_short()}}}'
+
+  def join(self, other):
+    return _as_accum(super().join(other))
+
+  def update(self, inner_aval=None, memory_space=None):
+    return _as_accum(super().update(inner_aval=inner_aval, memory_space=memory_space))
+
+  def at_least_vspace(self):
+    return _as_accum(super().at_least_vspace())
+
+  def _getitem(self, tracer, idx):
+    if not _is_trivial_index(idx):
+      raise NotImplementedError(f"Can only dereference accumulators, not slice ({idx=}).")
+    from jax._src.pallas.mosaic_gpu.primitives import wgmma_accumulator_deref  # pytype: disable=import-error
+    return wgmma_accumulator_deref(tracer)
+
+def _as_accum(ref) -> WGMMAAbstractAccumulatorRef:
+  return WGMMAAbstractAccumulatorRef(
+      inner_aval=ref.inner_aval,
+      memory_space=ref.memory_space,  # pytype: disable=attribute-error
+  )
+
+def _ref_raise_to_shaped(ref_aval, weak_type):
+  return _as_accum(jax_core.raise_to_shaped_mappings[AbstractMemoryRef](ref_aval, weak_type))
+jax_core.raise_to_shaped_mappings[WGMMAAbstractAccumulatorRef] = _ref_raise_to_shaped
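Note: _getitem above only accepts trivial indices, so an accumulator can be
dereferenced as a whole but never sliced. A quick sketch of what the hook
allows inside a kernel (hypothetical acc_ref allocated by pl.run_scoped, as
in the test at the end of this patch):

    acc_ref[...]   # ok: trivial index, binds wgmma_accumulator_deref
    acc_ref[:]     # ok: slice(None) is also trivial
    acc_ref[0:32]  # raises NotImplementedError: accumulators cannot be sliced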
diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 702f5ca17e61..f7cc0a46dde0 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -28,6 +28,7 @@
 from jax._src import core as jax_core
 from jax._src import pjit
 from jax._src import util
+from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import arith as arith_dialect
@@ -37,7 +38,9 @@
 from jax._src.pallas import primitives
 from jax._src.pallas import utils as pallas_utils
 from jax._src.pallas.mosaic_gpu import core as gpu_core
+from jax._src.state import discharge
 from jax._src.state import primitives as sp
+from jax._src.state import utils as state_utils
 import jax.experimental.mosaic.gpu as mgpu
 from jax.experimental.mosaic.gpu import core as mgpu_core
 from jax.experimental.mosaic.gpu import utils as mgpu_utils
@@ -615,6 +618,10 @@ def _swap_lowering_rule(
   del tree  # Unused.
   if indexers:
     raise NotImplementedError("No support for indexers yet")
+  if not isinstance(value, mgpu.FragmentedArray):
+    raise TypeError(f"Can only store arrays (got {value}).")
+  if not isinstance(x_smem, ir.Value) or not ir.MemRefType.isinstance(x_smem.type):
+    raise TypeError(f"Can only store to references (got {x_smem}).")
   x_aval, _ = ctx.avals_in
   old_value = mgpu.FragmentedArray.load_strided(
       x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype)
@@ -735,15 +742,57 @@ def _(val, idx):
 def _run_scoped_lowering_rule(
     ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr
 ):
-  in_avals = [v.aval.inner_aval for v in jaxpr.invars]
-  bytes_allocated, input_refs = ctx.module_ctx.scratch_view([
-      jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype)
-      for aval in in_avals
-  ])
-  outs = lower_jaxpr_to_mosaic_gpu(
-      ctx.module_ctx, ctx.launch_ctx, jaxpr, input_refs, consts
-  )
-  ctx.module_ctx.stack_free_smem(bytes_allocated)
+  input_refs = []
+  bytes_allocated = 0
+  should_discharge = []
+  for a in jaxpr.invars:
+    a = a.aval
+    if isinstance(a, gpu_core.WGMMAAbstractAccumulatorRef):
+      mlir_dtype = mlir.dtype_to_ir_type(a.dtype)
+      input_refs.append(mgpu.WGMMAAccumulator.zero(*a.shape, mlir_dtype))
+      should_discharge.append(True)
+    elif a.memory_space == gpu_core.SMEM:
+      ref_bytes, [input_ref] = ctx.module_ctx.scratch_view(
+          [jax.ShapeDtypeStruct(shape=a.shape, dtype=a.dtype)]
+      )
+      bytes_allocated += ref_bytes
+      input_refs.append(input_ref)
+      should_discharge.append(False)
+    else:
+      raise ValueError(f"Can't convert to ref: {a}")
+
+  if any(should_discharge):
+    # We convert consts to args, because we only have ir.Values and
+    # not JAX values during lowering. discharge_state() produces JAX
+    # values for the arguments but expects them to be provided for the
+    # consts. We also don't want to wrap the values in refs.
+    no_const_jaxpr = pe.convert_constvars_jaxpr(jaxpr)
+    should_discharge = [False] * len(consts) + should_discharge
+    discharged_jaxpr, _ = discharge.discharge_state(no_const_jaxpr, (), should_discharge=should_discharge)
+    new_input_vals = consts + tuple(input_refs)
+    outs = lower_jaxpr_to_mosaic_gpu(
+        ctx.module_ctx, ctx.launch_ctx, discharged_jaxpr, new_input_vals, ()
+    )
+    # Discharge appends the final values of the discharged refs to the
+    # outputs, so we drop them here.
+    outs = outs[:-sum(should_discharge)]
+  else:
+    outs = lower_jaxpr_to_mosaic_gpu(
+        ctx.module_ctx, ctx.launch_ctx, jaxpr, input_refs, consts
+    )
+
+  for o in outs:
+    # Each run_scoped call is responsible for dereferencing its own
+    # accumulators, so neither accumulators nor scratch refs may
+    # escape the scope.
+    if isinstance(o, mgpu.WGMMAAccumulator) or (
+        isinstance(o, ir.Value) and ir.MemRefType.isinstance(o.type)
+    ):
+      raise ValueError(f"No references are allowed to escape a scope (got {o}).")
+
+  assert len(outs) == len(jaxpr.outvars), (jaxpr, outs)
+  if bytes_allocated:
+    ctx.module_ctx.stack_free_smem(bytes_allocated)
   return outs
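Note: for intuition, discharge_state() rewrites the scope's stateful jaxpr
into a functional one. Roughly, for a single accumulator ref (schematic,
hypothetical jaxprs with shapes abbreviated; here
should_discharge=[False, False, True]):

    { lambda ; a:Ref{f16[...]} b:Ref{f16[...]} acc:Ref{f32[64,128]}. let
        wgmma_ref[swizzle=128 ...] acc a b
      in () }

becomes

    { lambda ; a:Ref{f16[...]} b:Ref{f16[...]} acc:f32[64,128]. let
        acc2:f32[64,128] = wgmma[swizzle=128 ...] acc a b
      in (acc2,) }

The discharged ref turns into a plain value input, and its final value is
appended to the outputs; that is why the lowering above drops the trailing
sum(should_discharge) results.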
diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index ef30dd0956ec..0a99e77c5ec8 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -21,7 +21,7 @@
 from jax._src import core as jax_core
 from jax._src import effects
 from jax._src import state
-from jax._src.interpreters import mlir
+from jax._src.state import discharge
 from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect
 from jax._src.pallas import core as pallas_core
 from jax._src.pallas.mosaic_gpu import core as gpu_core
@@ -118,62 +118,9 @@ class _WGMMAPipelineEffect(effects.Effect):
 
 _wgmma_pipeline_effect = _WGMMAPipelineEffect()
 effects.control_flow_allowed_effects.add_type(_WGMMAPipelineEffect)
-
-# Not a shaped array to avoid unexpected operations.
-class WGMMAAbstractAccumulator(jax_core.AbstractValue):
-  __slots__ = ['shape', 'dtype']
-
-  def __init__(self, shape, dtype):
-    self.shape = shape
-    self.dtype = dtype
-
-  def __eq__(self, other):
-    return (type(self) is type(other)
-            and self.dtype == other.dtype and self.shape == other.shape)
-
-  def __hash__(self):
-    return hash((self.shape, self.dtype))
-
-  def update(self, shape=None, dtype=None):
-    if shape is None:
-      shape = self.shape
-    if dtype is None:
-      dtype = self.dtype
-    return WGMMAAbstractAccumulator(shape, dtype)
-
-  def str_short(self, short_dtypes=False) -> str:
-    del short_dtypes
-    shapestr = ",".join(map(str, self.shape))
-    return f"Accumulator{{{self.dtype.name}}}[{shapestr}]"
-
-@dataclasses.dataclass(frozen=True)
-class WGMMAAccumulator:
-  inner_aval: WGMMAAbstractAccumulator
-
-  shape = property(lambda self: self.inner_aval.shape)
-  dtype = property(lambda self: self.inner_aval.dtype)
-
-  def as_array(self) -> jax_core.ShapedArray:
-    return acc_to_shaped_array_p.bind(self.inner_aval)
-
-
-jax_core.raise_to_shaped_mappings[WGMMAAbstractAccumulator] = lambda aval, _: aval
-
-acc_to_shaped_array_p = jax_core.Primitive("acc_to_shaped_array")
-
-@acc_to_shaped_array_p.def_abstract_eval
-def _acc_to_shaped_array_abstract_eval(acc) -> jax_core.ShapedArray:
-  return jax_core.ShapedArray(shape=acc.shape, dtype=acc.dtype)
-
-
-@lowering.register_lowering_rule(acc_to_shaped_array_p)
-def _acc_to_shaped_array_lowering_rule(
-    ctx: lowering.LoweringRuleContext, acc
-):
-  del ctx
-  return acc.value
-
-wgmma_p = jax_core.Primitive("wgmma")
+# WGMMA on an accumulator reference.
+wgmma_ref_p = jax_core.Primitive("wgmma_ref")
+wgmma_ref_p.multiple_results = True
 
 def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128):
   """Asynchronous warp group matmul.
@@ -189,8 +136,8 @@ def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128):
     n_tile: The number of tiles to use.
     swizzle: The swizzle pattern.
   """
-  if not isinstance(acc, WGMMAAccumulator):
-    raise TypeError(acc)
+  if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef):
+    raise TypeError(f"Expected a WGMMAAbstractAccumulatorRef, got {acc}")
 
   rhs_transpose = (
       (jnp.dtype(b.dtype).itemsize == 2)
@@ -208,18 +155,43 @@ def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128):
   if tma * ma != mc or nb * tnb != nc or ka != kb or tka != tkb:
     raise ValueError(f"Incompatible shapes: {a.shape=}, {b.shape=}, {acc.shape=}, {rhs_transpose=}")
 
-  outval = wgmma_p.bind(acc.inner_aval, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose)
-  return WGMMAAccumulator(outval)
+  return wgmma_ref_p.bind(acc, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose)
 
-@wgmma_p.def_effectful_abstract_eval
-def _wgmma_effectful_abstract_eval(acc, *args, **kwargs):
-  del args, kwargs
-  return acc, {
+
+@wgmma_ref_p.def_effectful_abstract_eval
+def _wgmma_ref_effectful_abstract_eval(acc, *args, **kwargs):
+  del acc, args, kwargs
+  return [], {
       _wgmma_pipeline_effect,
+      state.WriteEffect(0),
+      state.ReadEffect(0),
       state.ReadEffect(1),
       state.ReadEffect(2),
   }
+
+@discharge.register_discharge_rule(wgmma_ref_p)
+def _wgmma_ref_discharge_rule(
+    in_avals, out_avals,
+    acc,
+    a,
+    b,
+    swizzle,
+    rhs_transpose,
+):
+  del in_avals, out_avals
+  return (
+      wgmma_p.bind(
+          acc, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose
+      ),
+      None,
+      None,
+  ), []
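Note: the discharge rule above is the bridge between the two flavors. After
discharge, the in-place call

    wgmma_ref_p.bind(acc_ref, a, b, swizzle=128, rhs_transpose=False)

behaves like the functional form defined next,

    new_acc = wgmma_p.bind(acc_val, a, b, swizzle=128, rhs_transpose=False)

with new_acc reported as the ref's new contents and the two Nones marking a
and b as untouched (the parameter values here are illustrative only).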
+
+
+# Functional WGMMA, returns a shaped array. Internal.
+wgmma_p = jax_core.Primitive("wgmma")
+
 @lowering.register_lowering_rule(wgmma_p)
 def _wgmma_lowering_rule(
     ctx: lowering.LoweringRuleContext,
@@ -242,6 +214,15 @@ def _wgmma_lowering_rule(
   nvvm_dialect.wgmma_commit_group_sync_aligned()
   return new_acc
 
+@wgmma_p.def_effectful_abstract_eval
+def _wgmma_effectful_abstract_eval(acc, *args, **kwargs):
+  del args, kwargs
+  return acc, {
+      _wgmma_pipeline_effect,
+      state.ReadEffect(1),
+      state.ReadEffect(2),
+  }
+
 wgmma_wait_p = jax_core.Primitive("wgmma_wait")
 wgmma_wait_p.multiple_results = True
 
@@ -260,19 +241,29 @@ def _wgmma_wait_lowering_rule(ctx: lowering.LoweringRuleContext, allow_groups):
   nvvm_dialect.wgmma_wait_group_sync_aligned(allow_groups)
   return ()
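Note: wgmma is asynchronous, and wgmma_wait(n) is the synchronization point:
as the lowering above shows, it emits wgmma.wait_group.sync.aligned, which
blocks until at most n WGMMA groups remain in flight. A sketch of the
intended pipelining pattern (hypothetical double-buffered operand refs):

    plgpu.wgmma(acc_ref, a0_ref, b0_ref)  # enqueue group 0
    plgpu.wgmma(acc_ref, a1_ref, b1_ref)  # enqueue group 1
    plgpu.wgmma_wait(1)                   # at most one group still pending
    # a0_ref and b0_ref may now be reused safely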
 
-zero_accumulator_p = jax_core.Primitive("zero_accumulator")
-def zero_accumulator(shape, dtype):
-  return WGMMAAccumulator(zero_accumulator_p.bind(shape=shape, dtype=dtype))
+wgmma_accumulator_deref_p = jax_core.Primitive("wgmma_accumulator_deref")
 
+def wgmma_accumulator_deref(acc):
+  """Dereferences an accumulator register."""
 
-@zero_accumulator_p.def_abstract_eval
-def _zero_accumulator_abstract_eval(shape, dtype):
-  return WGMMAAbstractAccumulator(shape=shape, dtype=dtype)
+  if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef):
+    raise TypeError(f"acc must be a WGMMAAbstractAccumulatorRef, got {acc.aval=}")
+  return wgmma_accumulator_deref_p.bind(acc)
 
-@lowering.register_lowering_rule(zero_accumulator_p)
-def _zero_accumulator_lowering_rule(
-    ctx: lowering.LoweringRuleContext, shape, dtype
-):
+@wgmma_accumulator_deref_p.def_effectful_abstract_eval
+def _wgmma_accumulator_deref_abstract_eval(acc):
+  # Dereferencing implies flushing, so we have a wgmma pipeline effect.
+  ret = acc.inner_aval if isinstance(acc, gpu_core.WGMMAAbstractAccumulatorRef) else acc
+  assert isinstance(ret, jax_core.ShapedArray), acc
+  return ret, {_wgmma_pipeline_effect}
+
+@discharge.register_discharge_rule(wgmma_accumulator_deref_p)
+def _wgmma_accumulator_deref_discharge_rule(in_avals, out_avals, acc):
+  del in_avals, out_avals
+  return (None,), wgmma_accumulator_deref_p.bind(acc)
+
+@lowering.register_lowering_rule(wgmma_accumulator_deref_p)
+def _wgmma_accumulator_deref_lowering_rule(ctx: lowering.LoweringRuleContext, acc):
   del ctx
-  m, n = shape
-  return mgpu.WGMMAAccumulator.zero(m=m, n=n, dtype=mlir.dtype_to_ir_type(jnp.dtype(dtype)))
+  nvvm_dialect.wgmma_wait_group_sync_aligned(0)
+  return acc.value
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 0eb7d91960ff..431dba133ce3 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -380,11 +380,11 @@ def test_wgmma(self, dtype):
     swizzle = 128
     elems_128b = swizzle // jnp.dtype(dtype).itemsize
     def kernel(a_ref, b_ref, o_ref):
-      acc = plgpu.zero_accumulator((64, 128), jnp.float32)
-      acc = plgpu.wgmma(acc, a_ref, b_ref, rhs_transpose=rhs_transpose)
-      plgpu.wgmma_wait(0)
-      # TODO(cperivol): turn acc into a reference so we can reason about effects.
-      o_ref[...] = acc.as_array()
+      def scope(acc_ref):
+        plgpu.wgmma(acc_ref, a_ref, b_ref, rhs_transpose=rhs_transpose)
+        return acc_ref[...]
+
+      o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 128), jnp.float32))
 
     key1, key2 = jax.random.split(jax.random.key(42), 2)
     a = jax.random.uniform(key1, shape=(64, 128), dtype=dtype)
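Note: the old and new kernel idioms side by side (a sketch; the first
snippet's API is what this patch deletes):

    # Before: explicit accumulator values and a manual pipeline flush.
    acc = plgpu.zero_accumulator((64, 128), jnp.float32)
    acc = plgpu.wgmma(acc, a_ref, b_ref)
    plgpu.wgmma_wait(0)
    o_ref[...] = acc.as_array()

    # After: a scoped accumulator reference. Reading acc_ref[...] implies
    # the flush, and the reference cannot leak out of run_scoped.
    def scope(acc_ref):
      plgpu.wgmma(acc_ref, a_ref, b_ref)
      return acc_ref[...]

    o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 128), jnp.float32))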