Skip to content

Commit

Permalink
Avoid converting to TOSA if the invariants are not respected (#101)
Browse files Browse the repository at this point in the history
* Avoid converting to TOSA if the invariants are not respected

* Address comments

Address review comments
  • Loading branch information
josel-amd authored Jun 12, 2024
1 parent 55d5d4b commit 4818661
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
18 changes: 14 additions & 4 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ struct IsBool {
}
};

template <typename OpAdaptorT, typename TypeChecker>
template <typename OpAdaptorT, typename TypeChecker, typename TosaOpT>
LogicalResult checkBasicTosaRequirementsForBinaryOps(
ConversionPatternRewriter &rewriter, Operation *op, OpAdaptorT adaptor,
Type resultType) {
Expand All @@ -92,6 +92,15 @@ LogicalResult checkBasicTosaRequirementsForBinaryOps(

Type resultElementType = resultTensorType.getElementType();

if (TosaOpT::template hasTrait<
::mlir::OpTrait::SameOperandsAndResultElementType>()) {
if (lhsType.getElementType() != rhsType.getElementType() ||
lhsType.getElementType() != resultElementType) {
return rewriter.notifyMatchFailure(
op, "lhs, rhs and result must have the same type");
}
}

if (failed(TypeChecker::checkType(rewriter, resultElementType, op))) {
return failure();
}
Expand Down Expand Up @@ -144,8 +153,8 @@ class ONNXBinaryElementwiseOpLoweringToTOSA
LogicalResult matchAndRewrite(ONNXOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

if (failed(checkBasicTosaRequirementsForBinaryOps<OpAdaptor, TypeChecker>(
rewriter, op, adaptor, op.getResult().getType())))
if (failed(checkBasicTosaRequirementsForBinaryOps<OpAdaptor, TypeChecker,
TosaOpT>(rewriter, op, adaptor, op.getResult().getType())))
return failure();

auto loc = op.getLoc();
Expand Down Expand Up @@ -179,7 +188,8 @@ class ONNXMulOpLoweringToTosa : public OpConversionPattern<ONNXMulOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXMulOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(checkBasicTosaRequirementsForBinaryOps<OpAdaptor, IsIntOrFloat>(
if (failed(checkBasicTosaRequirementsForBinaryOps<OpAdaptor, IsIntOrFloat,
mlir::tosa::MulOp>(
rewriter, op, adaptor, op.getResult().getType())))
return failure();

Expand Down
10 changes: 10 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,16 @@ func.func @test_pow_f64(%arg0: tensor<13x21x1xf64>, %arg1: tensor<13x21x1xf64>)

// -----

func.func @test_pow_mixed_types(%arg0: tensor<3xf32>, %arg1: tensor<3xi32>) -> (tensor<3xf32>) {
// CHECK-LABEL: func @test_pow_mixed_types
// CHECK-SAME: ([[PARAM_0:%.*]]: tensor<3xf32>, [[PARAM_1:%.*]]: tensor<3xi32>) -> tensor<3xf32>
// CHECK: "onnx.Pow"([[PARAM_0]], [[PARAM_1]]) {onnx_node_name = "onnx.Pow_0"} : (tensor<3xf32>, tensor<3xi32>) -> tensor<3xf32>
%0 = "onnx.Pow"(%arg0, %arg1) {onnx_node_name = "onnx.Pow_0"} : (tensor<3xf32>, tensor<3xi32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
}

// -----

func.func @test_sqrt(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.Sqrt"(%arg0) : (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
Expand Down

0 comments on commit 4818661

Please sign in to comment.