From 658538f8d0fff65c2d04d9f171d9ef0a7ef6555e Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 13 Nov 2024 12:14:38 -0800 Subject: [PATCH] [Mosaic][TPU] Omit short circuiting of relayout (we should always relayout!) and implement product mismatch case for where we relayout from replicated to offset, and the number of vregs changes. PiperOrigin-RevId: 696226548 --- .../tpu/transforms/apply_vector_layout.cc | 59 +++++++++++++++---- 1 file changed, 49 insertions(+), 10 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index c9c4a81e668d..8792503f4636 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -4723,6 +4723,11 @@ FailureOr> disassemble( TPU_ASSERT_LOC(val.getLoc(), def_layout.has_value()); TPU_ASSERT_LOC(val.getLoc(), def_layout->generalizes(layout, vty.getShape(), target_shape)); + auto layout_product = + xla::Product(layout.tileArrayShape(vty.getShape(), target_shape)); + auto def_layout_product = + xla::Product(def_layout->tileArrayShape(vty.getShape(), target_shape)); + TPU_ASSERT_LOC(val.getLoc(), layout_product == def_layout_product); // TODO(tlongeri): Maybe just add a parameter to tileArrayShape instead of // having `tileArrayShape` and `tileArrayImplicitShape`. SmallVector layout_shape = @@ -6324,11 +6329,50 @@ FailureOr> relayout(RewriteContext &ctx, if (src.generalizes(dst, vty.getShape(), target_shape)) { // A value with a replicated offset might use fewer vregs than a value with // a non-zero offset. - if (xla::Product(src.tileArrayShape(vty.getShape(), target_shape)) != - xla::Product(dst.tileArrayShape(vty.getShape(), target_shape))) { - return emitError(v.getLoc(), - "Not implemented: source layout is more general, but " - "vreg count changes"); + auto src_product = + xla::Product(src.tileArrayShape(vty.getShape(), target_shape)); + auto dst_product = + xla::Product(dst.tileArrayShape(vty.getShape(), target_shape)); + if (src_product != dst_product) { + TPU_ASSERT_LOC(v.getLoc(), dst_product > src_product); + auto src_offsets = src.offsets(); + + TPU_ASSERT_LOC(v.getLoc(), src_offsets != dst.offsets()); + TPU_ASSERT_LOC(v.getLoc(), src.bitwidth() == dst.bitwidth()); + + if (src.implicit_dim() != dst.implicit_dim()) { + return emitError(v.getLoc(), + "Not implemented: Source layout is more general, but " + "vreg count changes and implicit dims are mismatched"); + } + + if (src.tiling() != dst.tiling()) { + return emitError(v.getLoc(), + "Not implemented: Source layout is more general, but " + "vreg count changes and tiling are mismatched"); + } + + // This case is moving from a replicated to a non replicated layout. + // As such, we need to make a new destination shape that is the + // materialization of the src shape with replication. + FAILUREOR_ASSIGN_OR_RETURN(auto src_vregs, + disassemble(builder, src, v, target_shape, + /*use_implicit_shape=*/true)); + auto dst_vregs_shape = dst.tileArrayShape(vty.getShape(), target_shape); + xla::Array dst_vregs(dst_vregs_shape); + dst_vregs.Each([&](const absl::Span idx, Value *vreg) { + SmallVector local_idx(idx.begin(), idx.end()); + if (!src_offsets[0].has_value()) { + local_idx[local_idx.size() - 2] = 0; + } + if (!src_offsets[1].has_value()) { + local_idx[local_idx.size() - 1] = 0; + } + *vreg = src_vregs(local_idx); + }); + return assemble(builder, vty, dst, std::move(dst_vregs), target_shape, + /*use_implicit_shape=*/true) + .getResult(); } src_tiles.Reshape(dst.tileArrayImplicitShape(vty.getShape(), target_shape)); return assemble(builder, vty, dst, std::move(src_tiles), target_shape, @@ -6411,8 +6455,6 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { if (vector_operand == nullptr) { continue; } - auto vty = vector_operand.getType(); - // The operand should always be an Operation (and not a BlockArgument) // since we expect the FuncOp to have only memrefs and semaphores as // arguments. @@ -6427,9 +6469,6 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { getOutLayouts(*def_op, ctx.target_shape)); const Layout lo = def_layouts[res_idx]; TPU_ASSERT_OP(lo.has_value()); - if (lo->generalizes(*li, vty.getShape(), ctx.target_shape)) { - continue; - } OpBuilder builder(&op); FAILUREOR_ASSIGN_OR_RETURN( Value new_v, relayout(ctx, builder, vector_operand, /*src=*/*lo,