support shape_propagate for gather, reduce, scatter, transpose
eedalong committed Jun 14, 2024
1 parent 037e6ba commit 9cd90c4
Showing 1 changed file with 159 additions and 1 deletion.
160 changes: 159 additions & 1 deletion tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc
@@ -75,7 +75,9 @@ bool isBinaryOp(Operation* op) {
isa<mhlo::SelectOp>(*op) || isa<mhlo::ConvertOp>(*op);
}

-bool isUnaryOp(Operation* op) { return isa<mhlo::ConvertOp>(op); }
+bool isUnaryOp(Operation* op) {
+  return isa<mhlo::ConvertOp, mhlo::ScatterOp>(op);
+}
bool isConcreteShape(ShapeContext& ctx) {
for (auto dim : ctx.shape) {
if (dim == ShapedType::kDynamic) return false;
@@ -152,6 +154,157 @@ std::optional<ShapeContext> propagateHelper<mhlo::DotOp>(
return ShapeContext(op->getResult(0), new_shape);
}

template <>
std::optional<ShapeContext> propagateHelper<mhlo::ConcatenateOp>(
    OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
  auto concat_op = cast<mhlo::ConcatenateOp>(op);
  SmallVector<int64_t> new_shape(
      op->getResult(0).getType().cast<RankedTensorType>().getShape());
  // Concatenation preserves rank: a dynamic dimension in the operand makes
  // the corresponding result dimension dynamic.
  for (int dim = 0; dim < inputCtx.shape.size(); dim++) {
    if (inputCtx.shape[dim] == ShapedType::kDynamic) {
      new_shape[dim] = ShapedType::kDynamic;
    }
  }
  return ShapeContext(op->getResult(0), new_shape);
}
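
For intuition, a minimal standalone sketch of the concatenate rule (plain C++; kDynamic, propagateConcat, and the container types are illustrative stand-ins, not names from the pass):

#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1;  // stand-in for ShapedType::kDynamic

// Concatenate preserves rank: non-concat dims mirror the operands'
// (possibly dynamic) sizes, while the concat dim is the sum of the
// operands' extents and becomes dynamic if any of them is dynamic.
std::vector<int64_t> propagateConcat(
    const std::vector<std::vector<int64_t>>& operand_shapes,
    int64_t concat_dim) {
  std::vector<int64_t> out = operand_shapes.front();
  int64_t total = 0;
  for (const auto& shape : operand_shapes) {
    for (size_t d = 0; d < shape.size(); ++d)
      if (shape[d] == kDynamic) out[d] = kDynamic;
    total = (total == kDynamic || shape[concat_dim] == kDynamic)
                ? kDynamic
                : total + shape[concat_dim];
  }
  out[concat_dim] = total;
  return out;
}
// e.g. propagateConcat({{8, 4}, {kDynamic, 4}}, /*concat_dim=*/0)
//      yields {kDynamic, 4}.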

template <>
std::optional<ShapeContext> propagateHelper<mhlo::TransposeOp>(
OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
auto transpose_op = cast<mhlo::TransposeOp>(op);
SmallVector<int64_t> new_shape;

  for (const auto& dim : transpose_op.getPermutation()) {
    new_shape.push_back(inputCtx.shape[dim.getSExtValue()]);
  }

return ShapeContext(op->getResult(0), new_shape);
}
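
The same pattern in miniature for transpose (illustrative names; result dim i takes whatever extent, static or dynamic, input dim perm[i] has):

#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1;  // stand-in for ShapedType::kDynamic

// Transpose: dynamic dimensions simply follow the permutation.
std::vector<int64_t> propagateTranspose(const std::vector<int64_t>& in_shape,
                                        const std::vector<int64_t>& perm) {
  std::vector<int64_t> out;
  out.reserve(perm.size());
  for (int64_t src_dim : perm) out.push_back(in_shape[src_dim]);
  return out;
}
// e.g. propagateTranspose({kDynamic, 8, 16}, {2, 0, 1})
//      yields {16, kDynamic, 8}.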

template <>
std::optional<ShapeContext> propagateHelper<mhlo::ReduceOp>(
OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
auto reduce_op = cast<mhlo::ReduceOp>(op);
SmallVector<int64_t> new_shape;

  for (int dim = 0; dim < inputCtx.shape.size(); dim++) {
    bool is_reduced_dim = false;
    for (auto reduced : reduce_op.getDimensions()) {
      is_reduced_dim = is_reduced_dim || (dim == reduced.getSExtValue());
    }
    if (!is_reduced_dim) {
      new_shape.push_back(inputCtx.shape[dim]);
    }
  }

return ShapeContext(op->getResult(0), new_shape);
}
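
And a sketch of the reduce rule (illustrative names): dimensions listed in the reduction attribute are dropped, the survivors keep their extents:

#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1;  // stand-in for ShapedType::kDynamic

// Reduce: every dimension in `dims` is removed from the shape.
std::vector<int64_t> propagateReduce(const std::vector<int64_t>& in_shape,
                                     const std::vector<int64_t>& dims) {
  std::vector<int64_t> out;
  for (int64_t d = 0; d < static_cast<int64_t>(in_shape.size()); ++d) {
    bool reduced = false;
    for (int64_t r : dims) reduced = reduced || (d == r);
    if (!reduced) out.push_back(in_shape[d]);
  }
  return out;
}
// e.g. propagateReduce({8, kDynamic, 16}, {1}) yields {8, 16}.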

template <>
std::optional<ShapeContext> propagateHelper<mhlo::DynamicGatherOp>(
OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
auto dynamic_gather_op = cast<mhlo::DynamicGatherOp>(op);
SmallVector<int64_t> new_shape;

auto attr = dynamic_gather_op.getDimensionNumbers();
auto slice_sizes =
op->getOperand(2).getType().cast<RankedTensorType>().getShape();

auto offset_dims = attr.getOffsetDims();
auto index_vector_dim = attr.getIndexVectorDim();
auto collapsed_slice_dims = attr.getCollapsedSliceDims();

  if (inputCtx.value == op->getOperand(1)) {
    // start_indices: every dimension except index_vector_dim contributes a
    // batch dimension to the result.
    for (int dim_idx = 0; dim_idx < inputCtx.shape.size(); dim_idx++) {
      if (dim_idx != index_vector_dim) {
        new_shape.push_back(inputCtx.shape[dim_idx]);
      }
    }
  } else if (inputCtx.value == op->getOperand(2)) {
    // slice_sizes: dimensions that are not collapsed contribute the offset
    // dimensions of the result.
    for (int dim_idx = 0; dim_idx < inputCtx.shape.size(); dim_idx++) {
      bool include_this_dim = true;
      for (auto collapsed_slice_dim : collapsed_slice_dims) {
        if (dim_idx == collapsed_slice_dim) {
          include_this_dim = false;
        }
      }
      if (include_this_dim) {
        new_shape.push_back(inputCtx.shape[dim_idx]);
      }
    }
} else {
new_shape = SmallVector<int64_t>(
op->getResult(0).getType().cast<RankedTensorType>().getShape());
}

return ShapeContext(op->getResult(0), new_shape);
}
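
The two branches above follow the HLO gather layout: batch dims come from start_indices with index_vector_dim removed, offset dims from the slice sizes with the collapsed dims removed. A rough standalone sketch (illustrative names; it assumes the offset dims trail the batch dims and ignores the reordering the offset_dims attribute allows):

#include <cstdint>
#include <vector>

// Gather result shape = batch dims ++ offset dims.
std::vector<int64_t> gatherResultShape(
    const std::vector<int64_t>& indices_shape, int64_t index_vector_dim,
    const std::vector<int64_t>& slice_sizes,
    const std::vector<int64_t>& collapsed_slice_dims) {
  std::vector<int64_t> out;
  // Batch dims: start_indices shape minus the index vector dimension.
  for (int64_t d = 0; d < static_cast<int64_t>(indices_shape.size()); ++d)
    if (d != index_vector_dim) out.push_back(indices_shape[d]);
  // Offset dims: slice sizes minus the collapsed dimensions.
  for (int64_t d = 0; d < static_cast<int64_t>(slice_sizes.size()); ++d) {
    bool collapsed = false;
    for (int64_t c : collapsed_slice_dims) collapsed = collapsed || (d == c);
    if (!collapsed) out.push_back(slice_sizes[d]);
  }
  return out;
}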

template <>
std::optional<ShapeContext> propagateHelper<mhlo::GatherOp>(
OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
auto gather_op = cast<mhlo::GatherOp>(op);

// batch_dims = [d for d in axes(result) and d not in offset_dims].
auto attr = gather_op.getDimensionNumbers();
auto offset_dims = attr.getOffsetDims();
auto index_vector_dim = attr.getIndexVectorDim();
auto slice_sizes = gather_op.getSliceSizes();
auto collapsed_slice_dims = attr.getCollapsedSliceDims();
auto src_shape =
op->getOperand(0).getType().cast<RankedTensorType>().getShape();
SmallVector<Value> slice_sizes_vec;

// process offset_dim_sizes, offset dims
int dim_idx = 0;
for (auto dim_size : slice_sizes) {
bool include_this_dim = true;
for (auto collapsed_slice_dim : collapsed_slice_dims) {
if (dim_idx == collapsed_slice_dim) {
include_this_dim = false;
}
}
    if (include_this_dim) {
      // A slice that spans the whole source dimension must follow that
      // dimension at runtime (it may have become dynamic); a narrower
      // slice stays a compile-time constant.
      if (src_shape[dim_idx] == dim_size.getSExtValue()) {
        slice_sizes_vec.push_back(
            b.create<tensor::DimOp>(op->getLoc(), op->getOperand(0), dim_idx)
                .getResult());
      } else {
        slice_sizes_vec.push_back(b.create<arith::ConstantIndexOp>(
            op->getLoc(), dim_size.getSExtValue()));
      }
    }
dim_idx += 1;
}

// create a dynamic gather op
auto dynamic_gather_op = b.create<mhlo::DynamicGatherOp>(
op->getLoc(), op->getResult(0).getType(), op->getOperand(0),
op->getOperand(1),
b.create<tensor::FromElementsOp>(op->getLoc(), slice_sizes_vec),
mhlo::GatherDimensionNumbersAttr::get(
attr.getContext(), attr.getOffsetDims(), attr.getCollapsedSliceDims(),
attr.getStartIndexMap(), attr.getIndexVectorDim()),
gather_op.getIndicesAreSorted());
gather_op.getResult().replaceAllUsesWith(dynamic_gather_op.getResult());
gather_op.erase();

// Update DynamicGatherOp result shape information
return propagateHelper<mhlo::DynamicGatherOp>(b, dynamic_gather_op, inputCtx);
}
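
The heart of this rewrite is the decision of how each kept slice size is materialized into the runtime slice_sizes tensor. A small sketch of that choice (illustrative helper; FromDimOp/FromConstant are hypothetical tags standing in for the dim and constant ops the pass creates):

#include <cstdint>
#include <vector>

enum class SliceSource { FromDimOp, FromConstant };

// A slice that spans the whole source dimension must track it at runtime,
// since that dimension may have become dynamic; narrower slices stay
// compile-time constants. Collapsed dimensions are skipped, mirroring the
// loop above.
std::vector<SliceSource> classifySliceSizes(
    const std::vector<int64_t>& src_shape,
    const std::vector<int64_t>& slice_sizes,
    const std::vector<int64_t>& collapsed_slice_dims) {
  std::vector<SliceSource> out;
  for (int64_t d = 0; d < static_cast<int64_t>(slice_sizes.size()); ++d) {
    bool collapsed = false;
    for (int64_t c : collapsed_slice_dims) collapsed = collapsed || (d == c);
    if (collapsed) continue;
    out.push_back(src_shape[d] == slice_sizes[d] ? SliceSource::FromDimOp
                                                 : SliceSource::FromConstant);
  }
  return out;
}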

LogicalResult parseInputDynamicDims(
func::FuncOp main,
std::vector<std::pair<int, std::vector<int>>>& input_dynamic_dims) {
@@ -204,6 +357,11 @@ std::optional<ShapeContext> propagateOpShape(OpBuilder& rewriter, Operation* op,
std::optional<ShapeContext> (*)(OpBuilder&, Operation*, ShapeContext&);
const std::vector<PropagationFunc> propagationFunctions = {
propagateHelper<mhlo::DotOp>,
propagateHelper<mhlo::ConcatenateOp>,
propagateHelper<mhlo::TransposeOp>,
propagateHelper<mhlo::ReduceOp>,
propagateHelper<mhlo::GatherOp>,
propagateHelper<mhlo::DynamicGatherOp>,
};
// Iterate over the propagation functions and apply each one
for (const auto& propagate : propagationFunctions) {
