From 46f7109594b5a005809b3851df148848450bc418 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Tue, 17 Sep 2024 22:10:12 -0700
Subject: [PATCH] Rewrite mosaic concat to support operand shapes that do not
 align with native shapes. Expand tests to cover multi operand, batch dim
 concat, etc.

PiperOrigin-RevId: 675835891
---
 jaxlib/mosaic/dialect/tpu/tpu.td              |   1 +
 jaxlib/mosaic/dialect/tpu/tpu_ops.cc          |  19 ++
 .../tpu/transforms/apply_vector_layout.cc     | 170 ++++++++++++++----
 .../tpu/transforms/infer_vector_layout.cc     |  66 +++++--
 tests/pallas/tpu_pallas_test.py               |   2 -
 5 files changed, 204 insertions(+), 54 deletions(-)

diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td
index ffcc8d52cd05..c3f8598de253 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu.td
+++ b/jaxlib/mosaic/dialect/tpu/tpu.td
@@ -401,6 +401,7 @@ def TPU_ConcatenateOp : TPU_Op<"concatenate", [Pure]> {
   let assemblyFormat = [{
     $sources `in` $dimension attr-dict `:` type($sources) `->` type($output)
   }];
+  let hasVerifier = 1;
 }
 
 def TPU_BitcastOp : TPU_Op<"bitcast", [Pure]> {
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
index d80db4e1394e..233f8e99c1ce 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
+++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
@@ -659,6 +659,25 @@ LogicalResult ShuffledStoreOp::canonicalize(ShuffledStoreOp op,
   }
   return success();
 }
+
+LogicalResult ConcatenateOp::verify() {
+  auto dimension = getDimension();
+  auto first_shape = getOperand(0).getType().cast<VectorType>().getShape();
+  for (int i = 0; i < getNumOperands(); ++i) {
+    auto operand = getOperand(i);
+    auto vty = cast<VectorType>(operand.getType());
+    auto shape = vty.getShape();
+    for (int dim = 0; dim < shape.size(); ++dim) {
+      if (dim != dimension && shape[dim] != first_shape[dim]) {
+        return emitOpError(
+            "Not implemented: Expected all operands to have "
+            "the same shape outside of the concat dim");
+      }
+    }
+  }
+  return success();
+}
+
 }  // namespace tpu
 }  // namespace mlir
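Note on the verifier above: the rule it enforces can be restated in a few
lines of Python. This sketch is ours, for illustration only; the helper name
check_concat_shapes does not appear in the patch.

# Python restatement of ConcatenateOp::verify: every operand must match the
# first operand's shape on all dimensions except the concat dimension.
def check_concat_shapes(shapes, dimension):
  first = shapes[0]
  for shape in shapes:
    for dim, size in enumerate(shape):
      if dim != dimension and size != first[dim]:
        return False  # the op verifier emits "Not implemented: ..." here
  return True

assert check_concat_shapes([(4, 3, 128), (4, 5, 128)], dimension=1)
assert not check_concat_shapes([(4, 3, 128), (2, 5, 128)], dimension=1)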
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index a1714fc8090b..b905c7c3a025 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -52,6 +52,7 @@
 #include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/types/span.h"
+#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
 #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
@@ -2509,54 +2510,149 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
   TPU_ASSERT_OP(
       llvm::all_of(layouts_in, [](const Layout &l) { return l.has_value(); }));
   TPU_ASSERT_OP(layouts_out.front().has_value());
-  const VectorLayout &layout = *layouts_out.front();
-  for (const Layout &l : layouts_in) {
-    if (l != layout) {
-      return op.emitOpError("Not implemented: Inconsistent layouts");
-    }
-  }
   OpBuilder builder(&op);
   auto concatenate_op = cast<tpu::ConcatenateOp>(op);
   const VectorType res_ty = concatenate_op.getResult().getType();
   const uint32_t dimension = concatenate_op.getDimension();
-  if (dimension - res_ty.getRank() >= -2) {
-    if (!layout.hasNativeTiling(ctx.target_shape) ||
-        layout.offsets() != LayoutOffsets{0, 0}) {
-      return op.emitOpError(
-          "Not implemented: Only native tiling with offset (0, 0) is supported "
-          "when concatenation along tiling dims.");
-    }
-    // Check if the concat dim size of src and res is aligned to native tiling.
-    auto check_aligned = [&](const VectorType &vty) {
-      auto i = dimension - res_ty.getRank();
-      return vty.getRank() >= 2 &&
-             *(vty.getShape().end() + i) % *(layout.tiling().end() + i) == 0;
-    };
-    bool is_aligned = check_aligned(res_ty);
-    int op_idx = 0;
-    while (is_aligned && op_idx < op.getNumOperands()) {
-      auto vty = dyn_cast<VectorType>(op.getOperand(op_idx++).getType());
-      is_aligned = check_aligned(vty);
+  SmallVector<xla::Array<Value>> vregs;
+  vregs.reserve(op.getNumOperands());
+
+  std::optional<int64_t> tiling_dim;
+  auto res_layout = layouts_out.front();
+
+  if (!res_layout.has_value()) {
+    return op.emitOpError("Not implemented: expected result layout");
+  }
+
+  auto has_implicit_dim =
+      (res_layout->implicit_dim() != VectorLayout::ImplicitDim::kNone);
+  auto minor_most_dims =
+      has_implicit_dim ? res_ty.getRank() - 1 : res_ty.getRank() - 2;
+
+  if (dimension >= minor_most_dims) {
+    tiling_dim = dimension - minor_most_dims;
+  }
+
+  // Op level invariants on layouts; other op level invariants are checked in
+  // the verifier.
+  if (tiling_dim.has_value() &&
+      !res_layout->hasNativeTiling(ctx.target_shape)) {
+    return op.emitOpError("Not implemented: result non-native-tiling.");
+  }
+  if (tiling_dim.has_value() && res_layout->offsets() != LayoutOffsets{0, 0}) {
+    return op.emitOpError("Not implemented: result non-zero offset.");
+  }
+  if (tiling_dim.has_value() &&
+      !res_layout->hasNativeTiling(ctx.target_shape)) {
+    return op.emitOpError("Not implemented: Non native tiling in concat.");
+  }
+  if (tiling_dim.has_value() &&
+      res_layout->implicit_dim() != VectorLayout::ImplicitDim::kNone) {
+    return op.emitOpError("Not implemented: implicit dim");
+  }
+
+  int64_t offset_at_dim = 0;
+  for (int i = 0; i < op.getNumOperands(); ++i) {
+    auto operand = op.getOperand(i);
+    if (!layouts_in[i].has_value()) {
+      return op.emitOpError("Not implemented: Expected input layout");
     }
-    if (!is_aligned) {
-      return op.emitOpError(
-          "Not implemented: Only aligned shapes are supported when "
-          "concatenation along tiling dims");
+    auto const &layout = *layouts_in[i];
+    auto vty = cast<VectorType>(operand.getType());
+    auto shape = vty.getShape();
+
+    if (tiling_dim.has_value()) {
+      auto starting_point = offset_at_dim;
+      auto offset_amount = starting_point % layout.tiling()[tiling_dim.value()];
+      if (offset_amount != layout.offsets()[tiling_dim.value()]) {
+        return op.emitOpError(
+            "Not implemented: Relayout not called, unaligned dims concatenated "
+            "without proper offsets. Ensure that infer_vector_layout pass was "
+            "called.");
+      }
     }
-  }
+    offset_at_dim += shape[dimension];
 
-  SmallVector<xla::Array<Value>> tiles;
-  tiles.reserve(concatenate_op->getNumOperands());
-  for (Value operand : concatenate_op.getOperands()) {
     FAILUREOR_ASSIGN_OR_RETURN(
-        xla::Array<Value> t,
+        xla::Array<Value> vreg_array,
         disassemble(builder, layout, cast<TypedValue<VectorType>>(operand),
                     ctx.target_shape));
-    tiles.emplace_back(std::move(t));
+    vregs.push_back(std::move(vreg_array));
+  }
+
+  CHECK_EQ(vregs.size(), op.getNumOperands());
+  const SmallVector<int64_t> vreg_array_shape =
+      res_layout->tileArrayShape(res_ty.getShape(), ctx.target_shape);
+  // Fill out out_vregs with nullptrs, to avoid a problem where we have to
+  // blend with a vreg that has not been written to yet.
+  xla::Array<Value> out_vregs(vreg_array_shape, nullptr);
+
+  auto boundIdxConst =
+      std::bind(IdxConst, std::placeholders::_1, builder, op.getLoc());
+
+  // Handle the untiled concatenation case.
+  if (!tiling_dim.has_value()) {
+    int64_t offset = 0;
+    for (const xla::Array<Value> &arr : vregs) {
+      arr.Each([&](const absl::Span<const int64_t> idx, const Value v) {
+        SmallVector<int64_t> res_idx(toArrayRef(idx));
+        res_idx[dimension] += offset;
+        out_vregs(res_idx) = v;
+      });
+      offset += arr.dim(dimension);
+    }
+  } else {
+    // Tiled concatenation logic.
+    int64_t offset = 0;
+    for (size_t i = 0; i < vregs.size(); ++i) {
+      auto &vreg = vregs[i];
+      const auto &layout = layouts_in[i];
+      const int64_t operand_offset = *layout->offsets()[tiling_dim.value()];
+      if (operand_offset != 0) {
+        // We are offset, so we must blend with the previous vreg. Or, to
+        // frame it another way, the prior vreg stored its entire dim size
+        // in the offset, but only wrote the last dim partially.
+        offset -= 1;
+      }
+
+      const auto bitwidth = res_ty.getElementTypeBitWidth();
+      const int packing = res_layout->packing();
+
+      SmallVector<int64_t> out_idx;
+      vreg.Each([&](absl::Span<const int64_t> idx, Value *v) {
+        out_idx.assign(idx.begin(), idx.end());
+        out_idx[dimension] += offset;
+        if (idx[dimension] == 0 && operand_offset != 0) {
+          Value mask;
+          if (tiling_dim.value() == 0) {  // sublane
+            const VectorType vmask_ty = getNativeVregOrVmaskType(
+                builder.getI1Type(), bitwidth, ctx.target_shape);
+            mask = builder.create<tpu::CreateMaskOp>(
+                op.getLoc(), vmask_ty,
+                ArrayRef<Value>{boundIdxConst(0), boundIdxConst(0)},
+                ArrayRef<Value>{boundIdxConst(operand_offset * packing),
+                                boundIdxConst(layout->tiling()[1])});
+          } else {  // lane
+            mask = builder.create<tpu::CreateMaskOp>(
+                op.getLoc(),
+                VectorType::get(ctx.target_shape, builder.getI1Type()),
+                ArrayRef<Value>{boundIdxConst(0), boundIdxConst(0)},
+                ArrayRef<Value>{boundIdxConst(layout->tiling()[0]),
+                                boundIdxConst(operand_offset * packing)});
+          }
+          // Blend the current value with the existing value in the output.
+          *v = builder.create<arith::SelectOp>(op.getLoc(), mask,
+                                               out_vregs(out_idx), *v);
+        }
+        out_vregs(out_idx) = *v;
+      });
+      offset += vreg.dim(dimension);
+    }
   }
-  const xla::Array<Value> res_tiles = concatenate(tiles, dimension);
-  op.replaceAllUsesWith(
-      assemble(builder, res_ty, layout, res_tiles, ctx.target_shape));
+  auto assembled =
+      assemble(builder, res_ty, *res_layout, out_vregs, ctx.target_shape);
+  op.replaceAllUsesWith(assembled);
   op.erase();
   return success();
 }
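The core of the tiled branch above is the blend: an operand whose layout
offset along the tiling dim is nonzero writes its first vreg through a mask,
so the sublanes (or lanes) still owned by the previous operand are preserved.
A minimal NumPy sketch of the sublane case, assuming (8, 128) vregs and
packing 1; the helper name is ours, not code from the patch.

import numpy as np

# select(mask, prev, new): the mask covers rows [0, operand_offset * packing),
# i.e. the sublanes written by the previous operand, mirroring CreateMaskOp
# with from=(0, 0) and to=(operand_offset * packing, lanes).
def sublane_blend(prev_vreg, new_vreg, operand_offset, packing=1):
  mask = np.zeros(prev_vreg.shape, dtype=bool)
  mask[:operand_offset * packing, :] = True
  return np.where(mask, prev_vreg, new_vreg)

prev = np.full((8, 128), 1.0)
new = np.full((8, 128), 2.0)
blended = sublane_blend(prev, new, operand_offset=2)
assert (blended[:2] == 1.0).all() and (blended[2:] == 2.0).all()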
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
index 2894b0797e7b..e667c6ee6cf6 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -774,21 +774,57 @@ class VectorLayoutInferer {
     }
     auto res_ty = op.getResult().getType();
     int8_t bitwidth = res_ty.getElementTypeBitWidth();
-    auto layout = getLayout(op.getSources().front());
-    // When concatenating vectors with replicated offsets, we want to reset the
-    // replicated offset to zero. Because we are not sure if the replicated
-    // value from each vector are same.
-    layout = VectorLayout(
-        layout->bitwidth(),
-        {layout->offsets()[0].value_or(0), layout->offsets()[1].value_or(0)},
-        layout->tiling(), layout->implicit_dim());
-    if (dimension >= res_rank - 2) {
-      layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
-                            ImplicitDim::kNone);
-    }
-    SmallVector<Layout> in_layouts(op->getNumOperands(), layout);
-    setLayout(op, in_layouts, layout);
-    return success();
+
+    std::optional<int64_t> tiling_dim;
+    if (dimension == res_ty.getRank() - 1) {
+      tiling_dim = 1;
+    } else if (dimension == res_ty.getRank() - 2) {
+      tiling_dim = 0;
+    }
+
+    if (tiling_dim.has_value()) {
+      int64_t dim_offset = 0;
+
+      auto op_layouts = getLayoutFromOperands(op);
+      SmallVector<Layout> in_layouts;
+      in_layouts.reserve(op.getSources().size());
+      for (int i = 0; i < op.getSources().size(); ++i) {
+        // Compute the offset per source.
+        // Ex: for a cat of (10, 128), (10, 128) on dim 0, where the
+        // vreg_slice for that dim is 8, the first source starts at
+        // offset 0 and overflows the vreg by 2, so the offset for the
+        // second input is 2.
+        auto op_shape =
+            cast<VectorType>(op.getSources()[i].getType()).getShape();
+        auto starting_point = dim_offset;
+        auto offset_amount =
+            starting_point % nativeTiling(bitwidth)[tiling_dim.value()];
+        auto op_layout = op_layouts[i];
+        SmallVector<int64_t> in_idx{op_layout->offsets()[0].value_or(0),
+                                    op_layout->offsets()[1].value_or(0)};
+        in_idx[tiling_dim.value()] = offset_amount;
+        dim_offset += op_shape[dimension];
+        in_layouts.push_back(VectorLayout(bitwidth, {in_idx[0], in_idx[1]},
+                                          nativeTiling(bitwidth),
+                                          ImplicitDim::kNone));
+      }
+      auto res_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
+                                     ImplicitDim::kNone);
+      setLayout(op, in_layouts, res_layout);
+      return success();
+    } else {
+      auto layout = getLayout(op.getSources().front());
+      // When concatenating vectors with replicated offsets, we want to reset
+      // the replicated offset to zero, because we are not sure if the
+      // replicated values from each vector are the same.
+      layout = VectorLayout(
+          layout->bitwidth(),
+          {layout->offsets()[0].value_or(0), layout->offsets()[1].value_or(0)},
+          layout->tiling(), layout->implicit_dim());
+      SmallVector<Layout> in_layouts(op->getNumOperands(), layout);
+      setLayout(op, in_layouts, layout);
+      return success();
+    }
   }
 
   LogicalResult infer(tpu::LoadOp op) {
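The offsets this pass infers follow directly from the running size of the
concatenated dim modulo the native tile. A hypothetical Python mirror of the
loop above; the function name and the (8, 128) native tile are assumptions.

# Each source gets a native-tiled layout whose offset along the tiling dim
# is that source's starting position modulo the native tile; the result
# layout always has offset (0, 0).
def infer_concat_offsets(shapes, dimension, native_tile=(8, 128)):
  tiling_dim = 1 if dimension == len(shapes[0]) - 1 else 0
  offsets, start = [], 0
  for shape in shapes:
    offset = [0, 0]
    offset[tiling_dim] = start % native_tile[tiling_dim]
    offsets.append(tuple(offset))
    start += shape[dimension]
  return offsets

# The example from the comment above: two (10, 128) sources on dim 0 with a
# sublane tile of 8; the second source is inferred at sublane offset 2.
assert infer_concat_offsets([(10, 128), (10, 128)], 0) == [(0, 0), (2, 0)]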
diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py
index 9a81f3196ba2..9bb6e16bffe9 100644
--- a/tests/pallas/tpu_pallas_test.py
+++ b/tests/pallas/tpu_pallas_test.py
@@ -2279,9 +2279,7 @@ def wrapper(self):
 class MiscellaneousTest(PallasBaseTest):
   """Tests for reported bugs. Only pass in interpret mode unless fixed."""
 
-  @only_passes_in_interpret()
   def test_float32_stack(self):
-    """b/347761105"""
     x = np.arange(128, dtype=jnp.float32).reshape(1, 128)
     y = x + 128
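The test body is truncated in this excerpt. As a usage illustration only (our
reconstruction, not code from the patch), the kind of kernel that
test_float32_stack exercises stacks two (1, 128) float32 blocks, a concat
whose sizes do not align with the native (8, 128) tiling and which previously
only passed in interpret mode:

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

def kernel(x_ref, y_ref, o_ref):
  # Stack along a new middle dimension; lowers to tpu.concatenate.
  o_ref[...] = jnp.stack([x_ref[...], y_ref[...]], axis=1)

x = np.arange(128, dtype=jnp.float32).reshape(1, 128)
y = x + 128
out = pl.pallas_call(
    kernel, out_shape=jax.ShapeDtypeStruct((1, 2, 128), jnp.float32)
)(x, y)
np.testing.assert_array_equal(out, np.stack([x, y], axis=1))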