From 3ae48621dd89bf01d4e165cff99e8160133fef60 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 27 Sep 2024 02:27:44 -0700 Subject: [PATCH] Fixed Pallas Mosaic GPU test following recent changes PiperOrigin-RevId: 679504036 --- jax/_src/pallas/mosaic_gpu/lowering.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index a64d4cc9f9b1..bdb22442c6b7 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -357,6 +357,10 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): *aval.shape, dtype=mgpu_utils.dtype_to_ir_type(aval.dtype) ) ) + case gpu_core.AbstractMemoryRef() if isinstance( + aval.dtype, gpu_core.BarrierType + ): + pass case gpu_core.AbstractMemoryRef() if aval.memory_space == SMEM: scratch_buffers_template.append(next(smem_scratch_it)) should_discharge.append(False) @@ -397,7 +401,7 @@ def gmem_slice( step: ir.Value, block_mapping: pallas_core.BlockMapping, ) -> Sequence[mgpu.DynamicSlice]: - assert len(sequential_axes) == 1 + assert len(sequential_axes) <= 1 program_ids = [step if i is None else i for i in program_ids_template] idxs = _eval_index_map(module_ctx, launch_ctx, program_ids, block_mapping) return tuple( @@ -428,7 +432,7 @@ def store( idx: int, step: ir.Value, slot: ir.Value, prev_base_offset: ir.Value ) -> ir.Value: if not out_in_smem[idx]: - return + return _as_index(-1) # We have to do some work to make sure that consecutive stores are not # going to be writing to the same location, or else we'll end up with