Skip to content

Commit

Permalink
Rewrite mosaic concat to support operand shapes that do not align wit…
Browse files Browse the repository at this point in the history
…h native shapes, Expand tests to cover multi operand, batch dim concat, etc.

PiperOrigin-RevId: 675835891
  • Loading branch information
Google-ML-Automation committed Sep 26, 2024
1 parent 6f7ad64 commit f67281e
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 54 deletions.
1 change: 1 addition & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def TPU_ConcatenateOp : TPU_Op<"concatenate", [Pure]> {
let assemblyFormat = [{
$sources `in` $dimension attr-dict `:` type($sources) `->` type($output)
}];
let hasVerifier = 1;
}

def TPU_BitcastOp : TPU_Op<"bitcast", [Pure]> {
Expand Down
19 changes: 19 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,25 @@ LogicalResult ShuffledStoreOp::canonicalize(ShuffledStoreOp op,
}
return success();
}

LogicalResult ConcatenateOp::verify() {
auto dimension = getDimension();
auto first_shape = getOperand(0).getType().cast<VectorType>().getShape();
for (int i = 0; i < getNumOperands(); ++i) {
auto operand = getOperand(i);
auto vty = cast<VectorType>(operand.getType());
auto shape = vty.getShape();
for (int dim = 0; dim < shape.size(); ++dim) {
if (dim != dimension && shape[dim] != first_shape[dim]) {
return emitOpError(
"Not implemented: Expected all operands to have "
"the same shape outside of the concat dim");
}
}
}
return success();
}

} // namespace tpu
} // namespace mlir

Expand Down
163 changes: 126 additions & 37 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/types/span.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
Expand Down Expand Up @@ -2509,54 +2510,142 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
TPU_ASSERT_OP(
llvm::all_of(layouts_in, [](const Layout &l) { return l.has_value(); }));
TPU_ASSERT_OP(layouts_out.front().has_value());
const VectorLayout &layout = *layouts_out.front();
for (const Layout &l : layouts_in) {
if (l != layout) {
return op.emitOpError("Not implemented: Inconsistent layouts");
}
}
OpBuilder builder(&op);
auto concatenate_op = cast<tpu::ConcatenateOp>(op);
const VectorType res_ty = concatenate_op.getResult().getType();
const uint32_t dimension = concatenate_op.getDimension();
if (dimension - res_ty.getRank() >= -2) {
if (!layout.hasNativeTiling(ctx.target_shape) ||
layout.offsets() != LayoutOffsets{0, 0}) {
return op.emitOpError(
"Not implemented: Only native tiling with offset (0, 0) is supported "
"when concatenation along tiling dims.");
}
// Check if the concat dim size of src and res is aligned to native tiling.
auto check_aligned = [&](const VectorType &vty) {
auto i = dimension - res_ty.getRank();
return vty.getRank() >= 2 &&
*(vty.getShape().end() + i) % *(layout.tiling().end() + i) == 0;
};
bool is_aligned = check_aligned(res_ty);
int op_idx = 0;
while (is_aligned && op_idx < op.getNumOperands()) {
auto vty = dyn_cast<VectorType>(op.getOperand(op_idx++).getType());
is_aligned = check_aligned(vty);
SmallVector<xla::Array<Value>> vregs;
vregs.reserve(op.getNumOperands());

std::optional<int64_t> tiling_dim;
if (dimension >= res_ty.getRank() - 2) {
tiling_dim = dimension - (res_ty.getRank() - 2);
}

// Op level invariants on layouts, other op level invariants are checked in
// the verifier.
auto res_layout = layouts_out.front();
if (!res_layout.has_value()) {
return op.emitOpError("Not implemented: expected result layout");
}
if (tiling_dim.has_value() &&
!res_layout->hasNativeTiling(ctx.target_shape)) {
return op.emitOpError("Not implemented: result non-native-tiling.");
}
if (tiling_dim.has_value() && res_layout->offsets() != LayoutOffsets{0, 0}) {
return op.emitOpError("Not implemented: result non-zero offset.");
}
if (tiling_dim.has_value() &&
!res_layout->hasNativeTiling(ctx.target_shape)) {
return op.emitOpError("Not implemented: Non native tiling in concat.");
}
if (tiling_dim.has_value() &&
res_layout->implicit_dim() != VectorLayout::ImplicitDim::kNone) {
return op.emitOpError("Not implemented: implicit dim");
}

int64_t offset_at_dim = 0;
for (int i = 0; i < op.getNumOperands(); ++i) {
auto operand = op.getOperand(i);
if (!layouts_in[i].has_value()) {
return op.emitOpError("Not implemented: Expected input layout");
}
if (!is_aligned) {
return op.emitOpError(
"Not implemented: Only aligned shapes are supported when "
"concatenation along tiling dims");
auto const &layout = *layouts_in[i];
auto vty = cast<VectorType>(operand.getType());
auto shape = vty.getShape();

if (tiling_dim.has_value()) {
auto starting_point = offset_at_dim;
auto offset_amount = starting_point % layout.tiling()[tiling_dim.value()];
if (offset_amount != layout.offsets()[tiling_dim.value()]) {
return op.emitOpError(
"Not implemented: Relayout not called, unaligned dims concatenated "
"without proper offsets. Ensure that infer_vector_layout pass was "
"called.");
}
}
}
offset_at_dim += shape[dimension];

SmallVector<xla::Array<Value>> tiles;
tiles.reserve(concatenate_op->getNumOperands());
for (Value operand : concatenate_op.getOperands()) {
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> t,
xla::Array<Value> vreg_array,
disassemble(builder, layout, cast<TypedValue<VectorType>>(operand),
ctx.target_shape));
tiles.emplace_back(std::move(t));
vregs.push_back(std::move(vreg_array));
}

CHECK_EQ(vregs.size(), op.getNumOperands());
const SmallVector<int64_t> vreg_array_shape =
res_layout->tileArrayShape(res_ty.getShape(), ctx.target_shape);
// Fill out out_vregs with 0s, to avoid a problem with where we have to
// blend with a vreg that has not been written to yet.
xla::Array<Value> out_vregs(vreg_array_shape, nullptr);

auto boundIdxConst =
std::bind(IdxConst, std::placeholders::_1, builder, op.getLoc());

// Handle the untiled concatenation case.
if (!tiling_dim.has_value()) {
int64_t offset = 0;
for (const xla::Array<Value> &arr : vregs) {
arr.Each([&](const absl::Span<const int64_t> idx, const Value v) {
SmallVector<int64_t> res_idx(toArrayRef(idx));
res_idx[dimension] += offset;
out_vregs(res_idx) = v;
});
offset += arr.dim(dimension);
}
} else {
// Tiled concatenation logic.
int64_t offset = 0;
for (size_t i = 0; i < vregs.size(); ++i) {
auto &vreg = vregs[i];
const auto &layout = layouts_in[i];
const int64_t operand_offset = *layout->offsets()[tiling_dim.value()];
if (operand_offset != 0) {
// We are offset, so we must blend with the previous vreg.
// Or, to frame it in an another way, the prior vreg
// stored its entire dim size in the offset, but only wrote the
// last dime partially.
offset -= 1;
}

const auto bitwidth = res_ty.getElementTypeBitWidth();
const int packing = res_layout->packing();

SmallVector<int64_t> out_idx;
vreg.Each([&](absl::Span<const int64_t> idx, Value *v) {
out_idx.assign(idx.begin(), idx.end());
out_idx[dimension] += offset;
if (idx[dimension] == 0 && operand_offset != 0) {
Value mask;
if (tiling_dim.value() == 0) { // sublane
const VectorType vmask_ty = getNativeVregOrVmaskType(
builder.getI1Type(), bitwidth, ctx.target_shape);
mask = builder.create<tpu::CreateMaskOp>(
op.getLoc(), vmask_ty,
ArrayRef<Value>{boundIdxConst(0), boundIdxConst(0)},
ArrayRef<Value>{boundIdxConst(operand_offset * packing),
boundIdxConst(layout->tiling()[1])});
} else { // lane
mask = builder.create<tpu::CreateMaskOp>(
op.getLoc(),
VectorType::get(ctx.target_shape, builder.getI1Type()),
ArrayRef<Value>{boundIdxConst(0), boundIdxConst(0)},
ArrayRef<Value>{boundIdxConst(layout->tiling()[0]),
boundIdxConst(operand_offset * packing)});
}
// Blend the current value with the existing value in the output.
*v = builder.create<arith::SelectOp>(op.getLoc(), mask,
out_vregs(out_idx), *v);
}
out_vregs(out_idx) = *v;
});
offset += vreg.dim(dimension);
}
}
const xla::Array<Value> res_tiles = concatenate(tiles, dimension);
op.replaceAllUsesWith(
assemble(builder, res_ty, layout, res_tiles, ctx.target_shape));
auto assembled =
assemble(builder, res_ty, *res_layout, out_vregs, ctx.target_shape);
op.replaceAllUsesWith(assembled);
op.erase();
return success();
}
Expand Down
66 changes: 51 additions & 15 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -774,21 +774,57 @@ class VectorLayoutInferer {
}
auto res_ty = op.getResult().getType();
int8_t bitwidth = res_ty.getElementTypeBitWidth();
auto layout = getLayout(op.getSources().front());
// When concatenating vectors with replicated offsets, we want to reset the
// replicated offset to zero. Because we are not sure if the replicated
// value from each vector are same.
layout = VectorLayout(
layout->bitwidth(),
{layout->offsets()[0].value_or(0), layout->offsets()[1].value_or(0)},
layout->tiling(), layout->implicit_dim());
if (dimension >= res_rank - 2) {
layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
ImplicitDim::kNone);
}
SmallVector<Layout> in_layouts(op->getNumOperands(), layout);
setLayout(op, in_layouts, layout);
return success();

std::optional<int64_t> tiling_dim;
if (dimension == res_ty.getRank() - 1) {
tiling_dim = 1;
} else if (dimension == res_ty.getRank() - 2) {
tiling_dim = 0;
}

if (tiling_dim.has_value()) {
int64_t dim_offset = 0;

auto op_layouts = getLayoutFromOperands(op);
SmallVector<Layout> in_layouts;
in_layouts.reserve(op.getSources().size());
for (int i = 0; i < op.getSources().size(); ++i) {
// Compute the offset per source.
// Ex: for a cat of (10, 128), (10, 128) on dim 0, where the
// vreg_sice for that dim is 8, the first source starts at
// offset 0, and overflows the vreg
// by 2, so the offset for the second input is 2.
auto op_shape =
cast<VectorType>(op.getSources()[i].getType()).getShape();
auto starting_point = dim_offset;
auto offset_amount =
starting_point % nativeTiling(bitwidth)[tiling_dim.value()];
auto op_layout = op_layouts[i];
SmallVector<int64_t> in_idx{op_layout->offsets()[0].value_or(0),
op_layout->offsets()[1].value_or(0)};
in_idx[tiling_dim.value()] = offset_amount;
dim_offset += op_shape[dimension];
in_layouts.push_back(VectorLayout(bitwidth, {in_idx[0], in_idx[1]},
nativeTiling(bitwidth),
ImplicitDim::kNone));
}
auto res_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
ImplicitDim::kNone);
setLayout(op, in_layouts, res_layout);
return success();
} else {
auto layout = getLayout(op.getSources().front());
// When concatenating vectors with replicated offsets, we want to reset
// the replicated offset to zero. Because we are not sure if the
// replicated value from each vector are same.
layout = VectorLayout(
layout->bitwidth(),
{layout->offsets()[0].value_or(0), layout->offsets()[1].value_or(0)},
layout->tiling(), layout->implicit_dim());
SmallVector<Layout> in_layouts(op->getNumOperands(), layout);
setLayout(op, in_layouts, layout);
return success();
}
}

LogicalResult infer(tpu::LoadOp op) {
Expand Down
2 changes: 0 additions & 2 deletions tests/pallas/tpu_pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2279,9 +2279,7 @@ def wrapper(self):
class MiscellaneousTest(PallasBaseTest):
"""Tests for reported bugs. Only pass in interpret mode unless fixed."""

@only_passes_in_interpret()
def test_float32_stack(self):
"""b/347761105"""
x = np.arange(128, dtype=jnp.float32).reshape(1, 128)
y = x + 128

Expand Down

0 comments on commit f67281e

Please sign in to comment.