Skip to content

Commit

Permalink
Convert ONNX.Where to TOSA.select
Browse files Browse the repository at this point in the history
  • Loading branch information
p-lanza committed Dec 16, 2024
1 parent f021e4c commit e636262
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/Conversion/ONNXToTOSA/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ add_onnx_mlir_library(OMONNXToTOSA
Tensor/Squeeze.cpp
Tensor/Tile.cpp
Tensor/Transpose.cpp
Tensor/Where.cpp
Flow/EntryPoint.cpp


Expand Down
2 changes: 2 additions & 0 deletions src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ void populateONNXToTOSAConversionPattern(ConversionTarget &target,
target, patterns, typeConverter, ctx);
populateLoweringONNXTransposeOpToTOSAPattern(
target, patterns, typeConverter, ctx);
populateLoweringONNXWhereOpToTOSAPattern(
target, patterns, typeConverter, ctx);
// NN
populateLoweringONNXMaxPoolSingleOutOpToTOSAPattern(
target, patterns, typeConverter, ctx);
Expand Down
2 changes: 2 additions & 0 deletions src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ void populateLoweringONNXExpandOpToTOSAPattern(mlir::ConversionTarget &,
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
void populateLoweringONNXTransposeOpToTOSAPattern(mlir::ConversionTarget &,
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
void populateLoweringONNXWhereOpToTOSAPattern(mlir::ConversionTarget &,
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
// 'Flow' directory methods:
void populateLoweringONNXEntryPointOpToTOSAPattern(mlir::ConversionTarget &,
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
Expand Down
74 changes: 74 additions & 0 deletions src/Conversion/ONNXToTOSA/Tensor/Where.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// (c) Copyright 2022 - 2024 Advanced Micro Devices, Inc. All Rights Reserved.

#include "DialectBuilder.hpp"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

class ONNXWhereLoweringToTOSA : public OpConversionPattern<ONNXWhereOp> {
public:
using OpConversionPattern::OpConversionPattern;
using OpAdaptor = typename ONNXWhereOp::Adaptor;

LogicalResult matchAndRewrite(ONNXWhereOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto loc = op.getLoc();
Value pred = adaptor.getOperands()[0];
Value lhs = adaptor.getOperands()[1];
Value rhs = adaptor.getOperands()[2];

// Check types are compatible
auto predType = dyn_cast<TensorType>(pred.getType());
auto lhsType = dyn_cast<TensorType>(lhs.getType());
auto rhsType = dyn_cast<TensorType>(rhs.getType());
auto resultType = dyn_cast<TensorType>(op->getResultTypes()[0]);

if (!predType || !lhsType || !rhsType || !resultType) {
return rewriter.notifyMatchFailure(op, "Tosa only supports TensorTypes");
}
if (!isTOSABool(predType.getElementType())) {
return rewriter.notifyMatchFailure(
op, "Expected bool type for condition to onnx.Where");
}
if (lhsType.getElementType() != rhsType.getElementType() ||
lhsType.getElementType() != resultType.getElementType()) {
return rewriter.notifyMatchFailure(op,
"Expected element type for X, Y and output to be the same in "
"onnx.Where");
}

// Broadcast dimensions
IndexExprBuilderForTosa createTosaIE(rewriter, op->getLoc());
ONNXBroadcastOpShapeHelper shapeHelper(op, {}, &createTosaIE);
if (shapeHelper.computeShape().succeeded() &&
shapeHelper.hasRankBroadcast()) {
TosaBuilder tosaBuilder(rewriter, loc);
llvm::SmallVector<Value, 4> newValues =
tosaBuilder.equalizeRanks({pred, lhs, rhs});
pred = newValues[0];
lhs = newValues[1];
rhs = newValues[2];
}

rewriter.replaceOpWithNewOp<mlir::tosa::SelectOp>(
op, op.getType(), pred, lhs, rhs);
return success();
}
};

} // namespace

void populateLoweringONNXWhereOpToTOSAPattern(ConversionTarget &target,
RewritePatternSet &patterns, TypeConverter &typeConverter,
MLIRContext *ctx) {
patterns.insert<ONNXWhereLoweringToTOSA>(ctx);
}
} // namespace onnx_mlir
44 changes: 44 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Tensor/Where.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s


func.func @test_where(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x1xf32>, %arg2: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
%0 = "onnx.Where"(%arg0, %arg1, %arg2) : (tensor<13x21x1xi1>, tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
"func.return"(%0) : (tensor<13x21x1xf32>) -> ()
// CHECK-LABEL: func @test_where
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi1>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>, [[PARAM_2_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.select [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] : (tensor<13x21x1xi1>, tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
}

// -----

func.func @test_where_broadcast(%arg0: tensor<21x1xi1>, %arg1: tensor<13x21x1xf32>, %arg2: tensor<1xf32>) -> tensor<13x21x1xf32> {
%0 = "onnx.Where"(%arg0, %arg1, %arg2) : (tensor<21x1xi1>, tensor<13x21x1xf32>, tensor<1xf32>) -> tensor<13x21x1xf32>
"func.return"(%0) : (tensor<13x21x1xf32>) -> ()
// CHECK-LABEL: func.func @test_where_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<21x1xi1>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>, [[PARAM_2_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array<i64: 1, 21, 1>} : (tensor<21x1xi1>) -> tensor<1x21x1xi1>
// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array<i64: 1, 1, 1>} : (tensor<1xf32>) -> tensor<1x1x1xf32>
// CHECK: [[VAR_2_:%.+]] = tosa.select [[VAR_0_]], [[PARAM_1_]], [[VAR_1_]] : (tensor<1x21x1xi1>, tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
// CHECK: return [[VAR_2_]] : tensor<13x21x1xf32>
}

// -----

func.func @test_where_ui32(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x1xui32>, %arg2: tensor<13x21x1xui32>) -> tensor<13x21x1xui32> {
%0 = "onnx.Where"(%arg0, %arg1, %arg2) : (tensor<13x21x1xi1>, tensor<13x21x1xui32>, tensor<13x21x1xui32>) -> tensor<13x21x1xui32>
"func.return"(%0) : (tensor<13x21x1xui32>) -> ()
// CHECK-LABEL: func.func @test_where_ui32
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi1>, [[PARAM_1_:%.+]]: tensor<13x21x1xui32>, [[PARAM_2_:%.+]]: tensor<13x21x1xui32>) -> tensor<13x21x1xui32> {
// CHECK: [[VAR_0_:%.+]] = tosa.select [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] : (tensor<13x21x1xi1>, tensor<13x21x1xui32>, tensor<13x21x1xui32>) -> tensor<13x21x1xui32>
// CHECK: return [[VAR_0_]] : tensor<13x21x1xui32>
}

// -----

func.func @test_where_f64(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x1xf64>, %arg2: tensor<13x21x1xf64>) -> tensor<13x21x1xf64> {
%0 = "onnx.Where"(%arg0, %arg1, %arg2) : (tensor<13x21x1xi1>, tensor<13x21x1xf64>, tensor<13x21x1xf64>) -> tensor<13x21x1xf64>
"func.return"(%0) : (tensor<13x21x1xf64>) -> ()
// CHECK-LABEL: func.func @test_where_f64
// CHECK-NOT: onnx.Where
// CHECK: return {{.*}}: tensor<13x21x1xf64>
}

0 comments on commit e636262

Please sign in to comment.