Skip to content

Commit

Permalink
[XLA:Mosaic][Pallas] Enable vector.ExtractOp for non-zero indices.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 665053223
  • Loading branch information
justinjfu authored and Google-ML-Automation committed Sep 26, 2024
1 parent 6f7ad64 commit b19980a
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 19 deletions.
47 changes: 38 additions & 9 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3555,24 +3555,53 @@ LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op,
op.erase();
return success();
} else {
for (int64_t i : extract_op.getStaticPosition()) {
if (i != 0) {
return op.emitOpError(
"Not implemented: Only 0 indices supported for scalar results");
}
}
// TODO(b/367459476): Support non-zero offsets.
if (layout_in.offsets() != LayoutOffsets{0, 0}) {
return op.emitOpError("Not implemented: Unsupported layout");
}
auto [sub_tile, lane_tile] = layout_in.tiling();
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> vregs,
disassemble(builder, layout_in, extract_op.getVector(),
ctx.target_shape));
TPU_ASSERT_GT_OP(vregs.num_elements(), 0);

SmallVector<int64_t> indices(extract_op.getStaticPosition());
auto vreg_slice = layout_in.vregSlice(ctx.target_shape);
std::array<int64_t, 2> position = {0, 0};
SmallVector<int64_t> vreg_index(indices);
// TODO(b/367459476): Support non-VREG-aligned tiling.
CHECK_EQ(lane_tile, ctx.target_shape[1]);
layout_in.insertImplicit(indices, static_cast<int64_t>(0));
layout_in.insertImplicit(vreg_index, static_cast<int64_t>(0));
int i = *(indices.end()-2);
int j = *(indices.end()-1);
*(vreg_index.end() -2) = i / vreg_slice[0];
*(vreg_index.end() -1) = j / vreg_slice[1];
layout_in.eraseImplicit(vreg_index);
position[0] = ((j % vreg_slice[1]) / lane_tile * sub_tile
) + i % sub_tile;
position[1] = j % lane_tile;

TPU_ASSERT_LT_OP(vreg_index, vregs.dimensions());
Value extracted_vreg = vregs(vreg_index);

// Invert the offsets to get the rotation amount.
position[0] = (ctx.target_shape[0] - position[0]) % ctx.target_shape[0];
position[1] = (ctx.target_shape[1] - position[1]) % ctx.target_shape[1];
auto res_vreg_ty = extracted_vreg.getType();
Value shift = builder.create<arith::ConstantOp>(
builder.getIntegerAttr(builder.getI32Type(), position[0]));
Value rotated_vreg = builder.create<tpu::DynamicRotateOp>(
res_vreg_ty, extracted_vreg, shift, 0, /*stride*/nullptr, nullptr);
shift = builder.create<arith::ConstantOp>(
builder.getIntegerAttr(builder.getI32Type(), position[1]));
rotated_vreg = builder.create<tpu::DynamicRotateOp>(
res_vreg_ty, rotated_vreg, shift, 1, /*stride*/nullptr, nullptr);
extract_op.replaceAllUsesWith(
builder
.create<vector::ExtractOp>(op.getLoc(), *vregs.data(),
ArrayRef<int64_t>{0, 0})
builder.create<vector::ExtractOp>(
op.getLoc(), rotated_vreg,
ArrayRef<int64_t>{0, 0})
.getResult());
}
extract_op.erase();
Expand Down
20 changes: 10 additions & 10 deletions tests/pallas/pallas_error_handling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,30 +44,30 @@ def setUp(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Test only works on TPU.")

def test_vector_extract_nonzero(self):
input_arr = jax.random.uniform(jax.random.key(0), (2, 2), dtype=jnp.float32)
out_shape = jax.ShapeDtypeStruct((1, 1), jnp.float32)
def test_non_singular_stride(self):
input_arr = jax.random.uniform(
jax.random.key(0), (8, 128), dtype=jnp.float32)
out_shape = jax.ShapeDtypeStruct((8, 16), jnp.float32)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
],
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
)

@functools.partial(pl.pallas_call, out_shape=out_shape, grid_spec=grid_spec)
def test_kernel(input_ref, output_ref):
val = input_ref[...]
x = val[0, 0] + val[0, 1]
output_ref[0, 0] = x
x = input_ref[:, ::8]
output_ref[...] = x

# Test that a Mosaic error is raised. This assert is a guard against
# underlying changes in Mosaic.
# If this is fixed in future Mosaic releases we will need to change
# the test example to force a different error.
with self.assertRaisesRegex(
error_handling.MosaicError,
"Not implemented: Only 0 indices supported for scalar results",
"Not Implemented: Stride on last dim is not 1",
):
test_kernel(input_arr)

Expand All @@ -78,7 +78,7 @@ def test_kernel(input_ref, output_ref):
except error_handling.MosaicError as e:
tb_string = traceback.format_tb(e.__traceback__)
tb_string = "".join(tb_string)
self.assertEndsWith(tb_string, "x = val[0, 0] + val[0, 1]\n")
self.assertEndsWith(tb_string, "x = input_ref[:, ::8]\n")

@jax.jit
def kernel_in_jitted_fn(x):
Expand All @@ -91,7 +91,7 @@ def kernel_in_jitted_fn(x):
except error_handling.MosaicError as e:
tb_string = traceback.format_tb(e.__traceback__)
tb_string = "".join(tb_string)
self.assertEndsWith(tb_string, "x = val[0, 0] + val[0, 1]\n")
self.assertEndsWith(tb_string, "x = input_ref[:, ::8]\n")

def test_invalid_smem_vmem_verification_error(self):
input_arr = jax.random.uniform(jax.random.key(0), (2, 2), dtype=jnp.float32)
Expand Down

0 comments on commit b19980a

Please sign in to comment.