Skip to content

Commit

Permalink
Merge pull request #86 from Xilinx/tiagot.clamp_max_min_tosa_folding
Browse files Browse the repository at this point in the history
feat: add min to clamp and max to clamp canonicalizations.
  • Loading branch information
ttjost authored Jan 5, 2024
2 parents 8330ba1 + 89f16ea commit d0868e8
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 0 deletions.
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,8 @@ def Tosa_MaximumOp : Tosa_ElemWiseBinaryOp<"maximum", [Commutative]> {
let results = (outs
Tosa_Tensor:$output
);

let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
Expand All @@ -754,6 +756,8 @@ def Tosa_MinimumOp : Tosa_ElemWiseBinaryOp<"minimum", [Commutative]> {
let results = (outs
Tosa_Tensor:$output
);

let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
86 changes: 86 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,92 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ConcatSliceOptimization>(context);
}

/// Canonicalizes `tosa.minimum(x, splat-const)` into `tosa.clamp` on `x`,
/// using the splat value as the upper bound and leaving the lower bound
/// effectively unconstrained.
struct MinToClampOptimization : public OpRewritePattern<tosa::MinimumOp> {
  using OpRewritePattern<tosa::MinimumOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MinimumOp op,
                                PatternRewriter &rewriter) const override {
    // Only a splat constant second operand can be expressed as a single
    // scalar clamp bound.
    DenseElementsAttr constant;
    if (!matchPattern(op.getInput2(), m_Constant(&constant)) ||
        !constant.isSplat())
      return failure();

    Value input = op.getInput1();
    auto elementTy = llvm::cast<ShapedType>(input.getType()).getElementType();

    // Lower bounds are pass-through values: the clamp only restricts from
    // above. NOTE(review): this uses int32 min (not int64 min) while the Max
    // pattern uses int64 max; matches the existing tests, but for i64 tensors
    // the clamp would bound values below -2^31 — confirm intended coverage.
    int64_t minInt = std::numeric_limits<int32_t>::min();
    float minFp = std::numeric_limits<float>::lowest();

    int64_t maxInt;
    float maxFp;
    if (isa<FloatType>(elementTy)) {
      // APFloat::convertToFloat is only valid for IEEE-single semantics;
      // bail out on f16/bf16/f64 splats instead of asserting.
      if (!elementTy.isF32())
        return failure();
      auto constMin = constant.getSplatValue<llvm::APFloat>();
      maxFp = constMin.convertToFloat();
      // Integer bound is ignored for float tensors; keep it consistent with
      // the float value anyway.
      maxInt = constMin.convertToFloat();
    } else {
      auto constMin = constant.getSplatValue<llvm::APInt>();
      maxFp = constMin.getSExtValue();
      maxInt = constMin.getSExtValue();
    }

    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, op.getType(), input, rewriter.getI64IntegerAttr(minInt),
        rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
        rewriter.getF32FloatAttr(maxFp));

    return success();
  }
};

/// Registers the minimum->clamp rewrite for tosa.minimum canonicalization.
void MinimumOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<MinToClampOptimization>(context);
}

/// Canonicalizes `tosa.maximum(x, splat-const)` into `tosa.clamp` on `x`,
/// using the splat value as the lower bound and leaving the upper bound
/// effectively unconstrained.
struct MaxToClampOptimization : public OpRewritePattern<tosa::MaximumOp> {
  using OpRewritePattern<tosa::MaximumOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaximumOp op,
                                PatternRewriter &rewriter) const override {
    // Only a splat constant second operand can be expressed as a single
    // scalar clamp bound.
    DenseElementsAttr constant;
    if (!matchPattern(op.getInput2(), m_Constant(&constant)) ||
        !constant.isSplat())
      return failure();

    Value input = op.getInput1();
    auto elementTy = llvm::cast<ShapedType>(input.getType()).getElementType();

    // Upper bounds are pass-through values: the clamp only restricts from
    // below.
    int64_t maxInt = std::numeric_limits<int64_t>::max();
    float maxFp = std::numeric_limits<float>::max();

    int64_t minInt;
    float minFp;
    if (isa<FloatType>(elementTy)) {
      // APFloat::convertToFloat is only valid for IEEE-single semantics;
      // bail out on f16/bf16/f64 splats instead of asserting.
      if (!elementTy.isF32())
        return failure();
      auto constMax = constant.getSplatValue<llvm::APFloat>();
      minFp = constMax.convertToFloat();
      // Integer bound is ignored for float tensors; keep it consistent with
      // the float value anyway.
      minInt = constMax.convertToFloat();
    } else {
      auto constMax = constant.getSplatValue<llvm::APInt>();
      minFp = constMax.getSExtValue();
      minInt = constMax.getSExtValue();
    }

    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, op.getType(), input, rewriter.getI64IntegerAttr(minInt),
        rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
        rewriter.getF32FloatAttr(maxFp));

    return success();
  }
};

/// Registers the maximum->clamp rewrite for tosa.maximum canonicalization.
void MaximumOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<MaxToClampOptimization>(context);
}

//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
Expand Down
28 changes: 28 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,34 @@ func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
return %1 : tensor<4xi8>
}

// CHECK-LABEL: @clamp_minimum_i32
func.func @clamp_minimum_i32(%arg0: tensor<4xi32>) -> tensor<4xi32> {
  // CHECK: "tosa.clamp"(%arg0) <{max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = -3.40282347E+38 : f32, min_int = -2147483648 : i64}
  %0 = "tosa.const"() <{value = dense<6> : tensor<1xi32>}> : () -> tensor<1xi32>
  %1 = "tosa.minimum"(%arg0, %0) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
  return %1 : tensor<4xi32>
}

// CHECK-LABEL: @clamp_minimum_f32
func.func @clamp_minimum_f32(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK: "tosa.clamp"(%arg0) <{max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = -3.40282347E+38 : f32, min_int = -2147483648 : i64}
  %0 = "tosa.const"() <{value = dense<6.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  %1 = "tosa.minimum"(%arg0, %0) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}

// CHECK-LABEL: @clamp_maximum_i32
func.func @clamp_maximum_i32(%arg0: tensor<4xi32>) -> tensor<4xi32> {
  // CHECK: "tosa.clamp"(%arg0) <{max_fp = 3.40282347E+38 : f32, max_int = 9223372036854775807 : i64, min_fp = -6.000000e+00 : f32, min_int = -6 : i64}
  %0 = "tosa.const"() <{value = dense<-6> : tensor<1xi32>}> : () -> tensor<1xi32>
  %1 = "tosa.maximum"(%arg0, %0) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
  return %1 : tensor<4xi32>
}

// CHECK-LABEL: @clamp_maximum_f32
func.func @clamp_maximum_f32(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK: "tosa.clamp"(%arg0) <{max_fp = 3.40282347E+38 : f32, max_int = 9223372036854775807 : i64, min_fp = -6.000000e+00 : f32, min_int = -6 : i64}
  %0 = "tosa.const"() <{value = dense<-6.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  %1 = "tosa.maximum"(%arg0, %0) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}

// CHECK-LABEL: @concat_fold_zero
func.func @concat_fold_zero(%arg0: tensor<?x0xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x2xf32>) -> tensor<?x3xf32> {
// CHECK: "tosa.concat"(%arg1, %arg2) <{axis = 1 : i64}>
Expand Down

0 comments on commit d0868e8

Please sign in to comment.