diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td
index b1a9ac910998..c1fba60f4cc5 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu.td
+++ b/jaxlib/mosaic/dialect/tpu/tpu.td
@@ -526,6 +526,19 @@ def TPU_MemRefReshapeOp : TPU_Op<"memref_reshape", [Pure]> {
   let hasCanonicalizeMethod = 1;
 }
 
+// Bitcasts a memref to a memref whose element bit width may differ. The
+// verifier requires equal ranks and, on the 2nd minormost dim, an equal number
+// of bits in the input and result types.
+def TPU_MemRefBitcastOp : TPU_Op<"memref_bitcast", [Pure]> {
+  let arguments = (ins AnyMemRef:$input);
+  let results = (outs AnyMemRef:$result);
+  let assemblyFormat = [{
+    $input attr-dict `:` type($input) `->` type($result)
+  }];
+  let hasVerifier = 1;
+  let hasCanonicalizeMethod = 1;
+}
+
 def TPU_ReinterpretCastOp : TPU_Op<"reinterpret_cast", [Pure]> {
   let arguments = (ins AnyMemRef:$input);
   let results = (outs AnyMemRef:$result);
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
index 5baec61ad138..ff349160dc50 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
+++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
@@ -303,6 +303,120 @@ LogicalResult MemRefReshapeOp::canonicalize(MemRefReshapeOp op,
   return success();
 }
 
+// Verifies that the result type is a legal bitcast of the input memref type:
+//  - the result memory space, when set, equals the input's;
+//  - both types have the same rank, and that rank is at least 2;
+//  - the 2nd minormost dim holds the same number of bits on both sides and
+//    every other dim has the same size;
+//  - when the result has a tiled layout, the input has one too and the first
+//    dimension of their leading tiles holds the same number of bits.
+LogicalResult MemRefBitcastOp::verify() {
+  auto src_ty = getMemRefType(getInput());
+  auto tgt_ty = getType();
+  if (tgt_ty.getMemorySpace() != nullptr &&
+      tgt_ty.getMemorySpace() != src_ty.getMemorySpace()) {
+    return emitOpError("Memory spaces do not match.");
+  }
+  if (src_ty.getRank() != tgt_ty.getRank()) {
+    return emitOpError("Ranks do not match.");
+  }
+  if (src_ty.getRank() <= 1) {
+    return emitOpError("Not implemented: 1d memref bitcast.");
+  }
+  auto src_bitwidth = src_ty.getElementTypeBitWidth();
+  auto tgt_bitwidth = tgt_ty.getElementTypeBitWidth();
+  for (int i = 0; i < src_ty.getRank(); ++i) {
+    auto src_dim_size = src_ty.getDimSize(i);
+    auto tgt_dim_size = tgt_ty.getDimSize(i);
+    if (i == src_ty.getRank() - 2) {
+      // The element width may differ, so compare bit counts on this dim.
+      if (src_dim_size * src_bitwidth != tgt_dim_size * tgt_bitwidth) {
+        return emitOpError(
+                   "Expected the same number of bits on the 2nd minormost "
+                   "dim: (")
+               << src_dim_size << " * " << src_bitwidth << ") vs ("
+               << tgt_dim_size << " * " << tgt_bitwidth << ")";
+      }
+    } else if (src_dim_size != tgt_dim_size) {
+      return emitOpError("Expected the same dim size on dim ")
+             << i << ": " << src_dim_size << " vs " << tgt_dim_size;
+    }
+  }
+  // Source and target attributes may be different before propagation is done by
+  // the canonicalizer, so we allow this when attributes are "unset" in the
+  // target type.
+  auto tgt_layout = dyn_cast<tpu::TiledLayoutAttr>(tgt_ty.getLayout());
+  if (!tgt_layout) {
+    return success();
+  }
+  auto src_layout = dyn_cast<tpu::TiledLayoutAttr>(src_ty.getLayout());
+  if (!src_layout) {
+    return emitOpError("Expected a tiled layout for the input memref.");
+  }
+  // TODO(jevinjiang): verify memref tiling is valid. Here we just assume the
+  // source and target tilings are valid.
+  // front() must not be called on an empty tile list, so reject that here.
+  if (src_layout.getTiles().empty() || tgt_layout.getTiles().empty()) {
+    return emitOpError("Expected non-empty tiling on both memrefs.");
+  }
+  auto src_tile = src_layout.getTiles().front().dimensions();
+  auto tgt_tile = tgt_layout.getTiles().front().dimensions();
+  if (src_tile[0] * src_bitwidth != tgt_tile[0] * tgt_bitwidth) {
+    return emitOpError("Invalid memref bitcast.");
+  }
+  return success();
+}
+
+// Canonicalization for memref bitcasts:
+//  - a bitcast whose input and result types are identical is replaced by its
+//    input;
+//  - a bitcast of an erase_layout result is rewritten as an erase_layout of a
+//    new bitcast applied to the layout-carrying operand, with the result
+//    tiling derived from the first tile of the operand's layout.
+LogicalResult MemRefBitcastOp::canonicalize(MemRefBitcastOp op,
+                                            PatternRewriter &rewriter) {
+  auto src_ty = op.getInput().getType();
+  auto dst_ty = op.getType();
+  if (src_ty == dst_ty) {
+    rewriter.replaceOp(op, op.getInput());
+    return success();
+  }
+  auto erase_layout_op = op.getInput().getDefiningOp<tpu::EraseLayoutOp>();
+  if (!erase_layout_op) {
+    return failure();
+  }
+  auto src_bitwidth = src_ty.getElementTypeBitWidth();
+  auto tgt_bitwidth = dst_ty.getElementTypeBitWidth();
+  auto layout_ref = erase_layout_op.getOperand();
+  auto layout_ty = layout_ref.getType();
+  // Decline to rewrite instead of aborting when the operand does not carry a
+  // usable tiled layout.
+  auto layout = dyn_cast<tpu::TiledLayoutAttr>(layout_ty.getLayout());
+  if (!layout || layout.getTiles().empty()) {
+    return failure();
+  }
+  auto tile = layout.getTiles().front().dimensions();
+  if (tile[0] * src_bitwidth % tgt_bitwidth != 0) {
+    return failure();
+  }
+  // NOTE(review): the minor tile dimension is fixed at 128 regardless of the
+  // operand's own tile; confirm this is intended for all supported layouts.
+  SmallVector<xla::Tile> new_tiles =
+      {xla::Tile({tile[0] * src_bitwidth / tgt_bitwidth, 128})};
+  if (tgt_bitwidth < 32) {
+    new_tiles.push_back(xla::Tile({32 / tgt_bitwidth, 1}));
+  }
+  auto new_layout = tpu::TiledLayoutAttr::get(src_ty.getContext(), new_tiles,
+                                              layout.getTileStrides());
+  auto new_result_ty =
+      MemRefType::get(dst_ty.getShape(), dst_ty.getElementType(), new_layout,
+                      layout_ty.getMemorySpace());
+  auto bitcast =
+      rewriter.create<MemRefBitcastOp>(op.getLoc(), new_result_ty, layout_ref);
+  rewriter.replaceOpWithNewOp<EraseLayoutOp>(op, op.getType(), bitcast);
+  return success();
+}
+
 template <typename Op>
 LogicalResult verifyStridedOp(Op op, MemRefType memref_ty,
                               VectorType vector_ty) {