Skip to content

Commit

Permalink
[Mosaic][TPU] Omit short circuiting of relayout (we should always rel…
Browse files Browse the repository at this point in the history
…ayout!) and implement product mismatch case for where we relayout from replicated to offset, and the number of vregs changes.

PiperOrigin-RevId: 696226548
  • Loading branch information
Google-ML-Automation committed Nov 14, 2024
1 parent 8370082 commit 658538f
Showing 1 changed file with 49 additions and 10 deletions.
59 changes: 49 additions & 10 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4723,6 +4723,11 @@ FailureOr<xla::Array<Value>> disassemble(
TPU_ASSERT_LOC(val.getLoc(), def_layout.has_value());
TPU_ASSERT_LOC(val.getLoc(),
def_layout->generalizes(layout, vty.getShape(), target_shape));
auto layout_product =
xla::Product(layout.tileArrayShape(vty.getShape(), target_shape));
auto def_layout_product =
xla::Product(def_layout->tileArrayShape(vty.getShape(), target_shape));
TPU_ASSERT_LOC(val.getLoc(), layout_product == def_layout_product);
// TODO(tlongeri): Maybe just add a parameter to tileArrayShape instead of
// having `tileArrayShape` and `tileArrayImplicitShape`.
SmallVector<int64_t> layout_shape =
Expand Down Expand Up @@ -6324,11 +6329,50 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
if (src.generalizes(dst, vty.getShape(), target_shape)) {
// A value with a replicated offset might use fewer vregs than a value with
// a non-zero offset.
if (xla::Product(src.tileArrayShape(vty.getShape(), target_shape)) !=
xla::Product(dst.tileArrayShape(vty.getShape(), target_shape))) {
return emitError(v.getLoc(),
"Not implemented: source layout is more general, but "
"vreg count changes");
auto src_product =
xla::Product(src.tileArrayShape(vty.getShape(), target_shape));
auto dst_product =
xla::Product(dst.tileArrayShape(vty.getShape(), target_shape));
if (src_product != dst_product) {
TPU_ASSERT_LOC(v.getLoc(), dst_product > src_product);
auto src_offsets = src.offsets();

TPU_ASSERT_LOC(v.getLoc(), src_offsets != dst.offsets());
TPU_ASSERT_LOC(v.getLoc(), src.bitwidth() == dst.bitwidth());

if (src.implicit_dim() != dst.implicit_dim()) {
return emitError(v.getLoc(),
"Not implemented: Source layout is more general, but "
"vreg count changes and implicit dims are mismatched");
}

if (src.tiling() != dst.tiling()) {
return emitError(v.getLoc(),
"Not implemented: Source layout is more general, but "
"vreg count changes and tiling are mismatched");
}

// This case is moving from a replicated to a non replicated layout.
// As such, we need to make a new destination shape that is the
// materialization of the src shape with replication.
FAILUREOR_ASSIGN_OR_RETURN(auto src_vregs,
disassemble(builder, src, v, target_shape,
/*use_implicit_shape=*/true));
auto dst_vregs_shape = dst.tileArrayShape(vty.getShape(), target_shape);
xla::Array<Value> dst_vregs(dst_vregs_shape);
dst_vregs.Each([&](const absl::Span<const int64_t> idx, Value *vreg) {
SmallVector<int64_t> local_idx(idx.begin(), idx.end());
if (!src_offsets[0].has_value()) {
local_idx[local_idx.size() - 2] = 0;
}
if (!src_offsets[1].has_value()) {
local_idx[local_idx.size() - 1] = 0;
}
*vreg = src_vregs(local_idx);
});
return assemble(builder, vty, dst, std::move(dst_vregs), target_shape,
/*use_implicit_shape=*/true)
.getResult();
}
src_tiles.Reshape(dst.tileArrayImplicitShape(vty.getShape(), target_shape));
return assemble(builder, vty, dst, std::move(src_tiles), target_shape,
Expand Down Expand Up @@ -6411,8 +6455,6 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
if (vector_operand == nullptr) {
continue;
}
auto vty = vector_operand.getType();

// The operand should always be an Operation (and not a BlockArgument)
// since we expect the FuncOp to have only memrefs and semaphores as
// arguments.
Expand All @@ -6427,9 +6469,6 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
getOutLayouts(*def_op, ctx.target_shape));
const Layout lo = def_layouts[res_idx];
TPU_ASSERT_OP(lo.has_value());
if (lo->generalizes(*li, vty.getShape(), ctx.target_shape)) {
continue;
}
OpBuilder builder(&op);
FAILUREOR_ASSIGN_OR_RETURN(
Value new_v, relayout(ctx, builder, vector_operand, /*src=*/*lo,
Expand Down

0 comments on commit 658538f

Please sign in to comment.