diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index a1714fc8090b..f6e1c7918646 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -3555,24 +3555,53 @@ LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op, op.erase(); return success(); } else { - for (int64_t i : extract_op.getStaticPosition()) { - if (i != 0) { - return op.emitOpError( - "Not implemented: Only 0 indices supported for scalar results"); - } - } + // TODO(b/367459476): Support non-zero offsets. if (layout_in.offsets() != LayoutOffsets{0, 0}) { return op.emitOpError("Not implemented: Unsupported layout"); } + auto [sub_tile, lane_tile] = layout_in.tiling(); FAILUREOR_ASSIGN_OR_RETURN( const xla::Array vregs, disassemble(builder, layout_in, extract_op.getVector(), ctx.target_shape)); TPU_ASSERT_GT_OP(vregs.num_elements(), 0); + + SmallVector indices(extract_op.getStaticPosition()); + auto vreg_slice = layout_in.vregSlice(ctx.target_shape); + std::array position = {0, 0}; + SmallVector vreg_index(indices); + // TODO(b/367459476): Support non-VREG-aligned tiling. + CHECK_EQ(lane_tile, ctx.target_shape[1]); + layout_in.insertImplicit(indices, static_cast(0)); + layout_in.insertImplicit(vreg_index, static_cast(0)); + int i = *(indices.end()-2); + int j = *(indices.end()-1); + *(vreg_index.end() -2) = i / vreg_slice[0]; + *(vreg_index.end() -1) = j / vreg_slice[1]; + layout_in.eraseImplicit(vreg_index); + position[0] = ((j % vreg_slice[1]) / lane_tile * sub_tile + ) + i % sub_tile; + position[1] = j % lane_tile; + + TPU_ASSERT_LT_OP(vreg_index, vregs.dimensions()); + Value extracted_vreg = vregs(vreg_index); + + // Invert the offsets to get the rotation amount. + position[0] = (ctx.target_shape[0] - position[0]) % ctx.target_shape[0]; + position[1] = (ctx.target_shape[1] - position[1]) % ctx.target_shape[1]; + auto res_vreg_ty = extracted_vreg.getType(); + Value shift = builder.create( + builder.getIntegerAttr(builder.getI32Type(), position[0])); + Value rotated_vreg = builder.create( + res_vreg_ty, extracted_vreg, shift, 0, /*stride*/nullptr, nullptr); + shift = builder.create( + builder.getIntegerAttr(builder.getI32Type(), position[1])); + rotated_vreg = builder.create( + res_vreg_ty, rotated_vreg, shift, 1, /*stride*/nullptr, nullptr); extract_op.replaceAllUsesWith( - builder - .create(op.getLoc(), *vregs.data(), - ArrayRef{0, 0}) + builder.create( + op.getLoc(), rotated_vreg, + ArrayRef{0, 0}) .getResult()); } extract_op.erase(); diff --git a/tests/pallas/pallas_error_handling_test.py b/tests/pallas/pallas_error_handling_test.py index 06b4bd3e3a4f..34b0ff1492a4 100644 --- a/tests/pallas/pallas_error_handling_test.py +++ b/tests/pallas/pallas_error_handling_test.py @@ -44,22 +44,22 @@ def setUp(self): if not jtu.test_device_matches(["tpu"]): self.skipTest("Test only works on TPU.") - def test_vector_extract_nonzero(self): - input_arr = jax.random.uniform(jax.random.key(0), (2, 2), dtype=jnp.float32) - out_shape = jax.ShapeDtypeStruct((1, 1), jnp.float32) + def test_non_singular_stride(self): + input_arr = jax.random.uniform( + jax.random.key(0), (8, 128), dtype=jnp.float32) + out_shape = jax.ShapeDtypeStruct((8, 16), jnp.float32) grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), ) @functools.partial(pl.pallas_call, out_shape=out_shape, grid_spec=grid_spec) def test_kernel(input_ref, output_ref): - val = input_ref[...] - x = val[0, 0] + val[0, 1] - output_ref[0, 0] = x + x = input_ref[:, ::8] + output_ref[...] = x # Test that a Mosaic error is raised. This assert is a guard against # underlying changes in Mosaic. @@ -67,7 +67,7 @@ def test_kernel(input_ref, output_ref): # the test example to force a different error. with self.assertRaisesRegex( error_handling.MosaicError, - "Not implemented: Only 0 indices supported for scalar results", + "Not Implemented: Stride on last dim is not 1", ): test_kernel(input_arr) @@ -78,7 +78,7 @@ def test_kernel(input_ref, output_ref): except error_handling.MosaicError as e: tb_string = traceback.format_tb(e.__traceback__) tb_string = "".join(tb_string) - self.assertEndsWith(tb_string, "x = val[0, 0] + val[0, 1]\n") + self.assertEndsWith(tb_string, "x = input_ref[:, ::8]\n") @jax.jit def kernel_in_jitted_fn(x): @@ -91,7 +91,7 @@ def kernel_in_jitted_fn(x): except error_handling.MosaicError as e: tb_string = traceback.format_tb(e.__traceback__) tb_string = "".join(tb_string) - self.assertEndsWith(tb_string, "x = val[0, 0] + val[0, 1]\n") + self.assertEndsWith(tb_string, "x = input_ref[:, ::8]\n") def test_invalid_smem_vmem_verification_error(self): input_arr = jax.random.uniform(jax.random.key(0), (2, 2), dtype=jnp.float32)