Skip to content

Commit

Permalink
Fixed Pallas Mosaic GPU test following recent changes
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 679261897
  • Loading branch information
superbobry authored and Google-ML-Automation committed Sep 27, 2024
1 parent ea6ee4d commit 290aae0
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions jax/_src/pallas/mosaic_gpu/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,10 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
*aval.shape, dtype=mgpu_utils.dtype_to_ir_type(aval.dtype)
)
)
case gpu_core.AbstractMemoryRef() if isinstance(
aval.dtype, gpu_core.BarrierType
):
pass
case gpu_core.AbstractMemoryRef() if aval.memory_space == SMEM:
scratch_buffers_template.append(next(smem_scratch_it))
should_discharge.append(False)
Expand Down Expand Up @@ -397,7 +401,7 @@ def gmem_slice(
step: ir.Value,
block_mapping: pallas_core.BlockMapping,
) -> Sequence[mgpu.DynamicSlice]:
assert len(sequential_axes) == 1
assert len(sequential_axes) <= 1
program_ids = [step if i is None else i for i in program_ids_template]
idxs = _eval_index_map(module_ctx, launch_ctx, program_ids, block_mapping)
return tuple(
Expand Down Expand Up @@ -428,7 +432,7 @@ def store(
idx: int, step: ir.Value, slot: ir.Value, prev_base_offset: ir.Value
) -> ir.Value:
if not out_in_smem[idx]:
return
return _as_index(-1)

# We have to do some work to make sure that consecutive stores are not
# going to be writing to the same location, or else we'll end up with
Expand Down

0 comments on commit 290aae0

Please sign in to comment.