Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOSA] Add ONNX to TOSA ArgMax conversion pass #45

Open
wants to merge 14 commits into
base: feature/onnx_to_torch
Choose a base branch
from
4 changes: 4 additions & 0 deletions src/Conversion/ONNXToTOSA/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

add_onnx_mlir_library(OMONNXToTOSA
ConvertONNXToTOSA.cpp

Math/Elementwise.cpp
Tensor/ArgMax.cpp

LINK_LIBS PUBLIC
OMONNXOps
MLIRTosa
Expand Down
89 changes: 39 additions & 50 deletions src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,54 +12,31 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOpsHelper.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"

using namespace mlir;

namespace onnx_mlir {

// This defines a template to construct ops whose legalizations are
// specialized.
template <typename OnnxOpT>
class ConvertOnnxOp : public OpConversionPattern<OnnxOpT> {
public:
using OpConversionPattern<OnnxOpT>::OpConversionPattern;
using OpAdaptor = typename OnnxOpT::Adaptor;
LogicalResult matchAndRewrite(OnnxOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

template <>
LogicalResult ConvertOnnxOp<ONNXReluOp>::matchAndRewrite(ONNXReluOp op,
OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const {
Value input = adaptor.X();
auto inputTy = input.getType().dyn_cast<TensorType>();

if (!inputTy)
return op.emitError("Only Tensor types supported in TOSA");

if (!inputTy.getElementType().isa<FloatType>()) {
return op.emitError(
"Only floating-point datatype legalization currently supported");
}
// Returns true when `type` is a signed integer type whose bit width is one of
// the widths TOSA supports (8, 16, 32, 48 or 64).
//
// NOTE(review): consider renaming to isTosaSignedInt (same for isFloat) to
// make clear this restricts to TOSA-supported types. The Tosa_SignedInt
// constraint from mlir/Dialect/Tosa/IR/TosaTypesBase.td cannot be reused
// directly from C++ — it becomes an anonymous constraint function in the
// generated .cpp files.
static bool isSignedInt(Type type) {
  IntegerType intType = type.dyn_cast<IntegerType>();
  if (!intType || !intType.isSigned())
    return false;
  // A switch avoids constructing a std::set on every call just to test
  // membership in a tiny, fixed width list.
  switch (intType.getWidth()) {
  case 8:
  case 16:
  case 32:
  case 48:
  case 64:
    return true;
  default:
    return false;
  }
}

// Rescale the clampIn for quantized types. TBD
// Maps to tosa.clamp which has both int and fp limits.
Value clampIn = input;
// Returns true when `type` is one of the floating-point types supported by
// TOSA: bf16, f16 or f32 (f64 is deliberately not accepted).
static bool isFloat(Type type) {
  return type.isa<BFloat16Type>() || type.isa<Float16Type>() ||
         type.isa<Float32Type>();
}

rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, op.getType(), clampIn,
rewriter.getI64IntegerAttr(0),
rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
rewriter.getF32FloatAttr(0.0f),
rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
return success();
// Registers every ONNX-to-TOSA lowering pattern, grouped by the source
// directory each lowering lives in.
void populateONNXToTOSAConversionPattern(ConversionTarget &target,
    RewritePatternSet &patterns, TypeConverter &typeConverter,
    MLIRContext *ctx) {
  // `Math` directory lowerings.
  populateLoweringONNXElementwiseOpToTOSAPattern(
      target, patterns, typeConverter, ctx);
  // `Tensor` directory lowerings.
  populateLoweringONNXArgMaxOpToTOSAPattern(patterns, typeConverter, ctx);
}

// Performs lowering to TOSA dialect
Expand All @@ -79,24 +56,36 @@ struct FrontendToTosaLoweringPass
};

// Runs the partial conversion from the ONNX dialect to TOSA on the module.
// (The scraped diff interleaved old and new lines here; this is the
// post-change version: the unconditional identity type conversion and the
// macro-based pattern registration are gone.)
void FrontendToTosaLoweringPass::runOnOperation() {
  ModuleOp module = getOperation();
  // Define final conversion target
  MLIRContext *context = &getContext();
  RewritePatternSet patterns(context);
  ConversionTarget target(*context);

  // We use the type converter to legalize types before any conversion patterns
  // are executed. This ensures that we do not need to trigger separate
  // conversion failures. Quantized types are not supported right now.
  TypeConverter typeConverter;
  // Only TOSA-supported element types are legal; everything else converts to
  // llvm::None, which makes ops on those types fail legalization up front.
  typeConverter.addConversion([](Type type) -> Optional<Type> {
    if (isSignedInt(type) || isFloat(type))
      return type;
    return llvm::None;
  });
  // A tensor type is legal iff its element type is legal.
  typeConverter.addConversion([&](TensorType type) -> Optional<Type> {
    if (typeConverter.isLegal(type.getElementType()))
      return type;
    return llvm::None;
  });

  // Define legal dialects and operations
  target.addLegalDialect<tosa::TosaDialect, func::FuncDialect>();

  // Define patterns
  populateONNXToTOSAConversionPattern(target, patterns, typeConverter, context);

  if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
    signalPassFailure();
  }
}

std::unique_ptr<Pass> createConvertONNXToTOSAPass() {
Expand Down
71 changes: 71 additions & 0 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===---------------- Elementwise.cpp - Elementwise Op --------------------===//
//
// Copyright (c) 2022 Advanced Micro Devices, Inc.
//
// =============================================================================
//
// This file lowers ONNX element-wise operators to TOSA dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

// Generic lowering for unary elementwise ops: an ONNX op with a single input
// `X` is replaced one-to-one by the corresponding TOSA op, keeping the
// original result type.
template <typename ONNXOpT, typename TOSAOpT>
class ONNXUnaryOpLoweringToTOSA : public OpConversionPattern<ONNXOpT> {
public:
  using OpConversionPattern<ONNXOpT>::OpConversionPattern;
  using OpAdaptor = typename ONNXOpT::Adaptor;
  LogicalResult matchAndRewrite(ONNXOpT op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.X();
    Type resultType = op.getType();
    rewriter.replaceOpWithNewOp<TOSAOpT>(op, resultType, input);
    return success();
  }
};

// Lowers onnx.Relu to tosa.clamp: lower bound zero, upper bound the maximum
// representable value, for both the integer and floating-point limits.
class ONNXReluOpLoweringToTOSA : public OpConversionPattern<ONNXReluOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(ONNXReluOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {

    Value input = adaptor.X();

    // Quantized types are not supported right now (in type conversion).
    // Once they are, the input should be rescaled for quantized types. (TBD)
    // Maps to `tosa.clamp` which has both int and fp limits.
    IntegerAttr minInt = rewriter.getI64IntegerAttr(0);
    IntegerAttr maxInt =
        rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max());
    FloatAttr minFp = rewriter.getF32FloatAttr(0.0f);
    FloatAttr maxFp =
        rewriter.getF32FloatAttr(std::numeric_limits<float>::max());

    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, op.getType(), input, minInt, maxInt, minFp, maxFp);
    return success();
  }
};

} // namespace

// Registers the elementwise lowerings and marks the lowered ONNX ops illegal
// so the conversion reports a failure when one of them is left behind.
void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target,
    RewritePatternSet &patterns, TypeConverter &typeConverter,
    MLIRContext *ctx) {
  // Mark onnx.Relu illegal like the macro-registered unary ops below (the
  // pre-refactor code did this too); otherwise an unconverted Relu would be
  // silently accepted by the partial conversion.
  target.addIllegalOp<ONNXReluOp>();
  patterns.insert<ONNXReluOpLoweringToTOSA>(typeConverter, ctx);

#define INSERT_UNARY_PATTERN(ONNXOp, TOSAOp)                                   \
  target.addIllegalOp<ONNXOp>();                                               \
  patterns.insert<ONNXUnaryOpLoweringToTOSA<ONNXOp, TOSAOp>>(                  \
      typeConverter, ctx);
  INSERT_UNARY_PATTERN(ONNXNegOp, tosa::NegateOp)
  INSERT_UNARY_PATTERN(ONNXFloorOp, tosa::FloorOp)
#undef INSERT_UNARY_PATTERN
}

} // namespace onnx_mlir
53 changes: 53 additions & 0 deletions src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//====------ ONNXToTOSACommon.hpp - ONNX dialects to TOSA lowering --------===//
//
// Copyright (c) 2022 Advanced Micro Devices, Inc.
//
// =============================================================================
//
// This file contains common code shared by the functions performing the
// lowering to the TOSA dialect.
//
//===----------------------------------------------------------------------===//

#pragma once

#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOpsHelper.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Transform/ONNX/ConstPropHelper.hpp"

//===----------------------------------------------------------------------===//
// Functions to add lowering patterns for frontend operations.
//===----------------------------------------------------------------------===//

namespace onnx_mlir {

//===----------------------------------------------------------------------===//
// This is to get a TOSA operation of a given type for a specific operation.
//===----------------------------------------------------------------------===//
// Primary template of the ONNX-op -> TOSA-op trait. The default `void` means
// "no TOSA counterpart registered"; individual lowerings are expected to
// specialize this for the ops they support.
template <typename ONNXOp>
struct TOSADialectOp {
  using Op = void;
};

// Convenience alias: TOSAOp<ONNXFooOp> is the mapped TOSA op type (or void).
template <typename Op>
using TOSAOp = typename TOSADialectOp<Op>::Op;

// `Math` directory methods:
void populateLoweringONNXElementwiseOpToTOSAPattern(
    ConversionTarget &, RewritePatternSet &, TypeConverter &, MLIRContext *);
// `Tensor` directory methods:
void populateLoweringONNXArgMaxOpToTOSAPattern(
    RewritePatternSet &, TypeConverter &, MLIRContext *);

} // namespace onnx_mlir
50 changes: 50 additions & 0 deletions src/Conversion/ONNXToTOSA/Tensor/ArgMax.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===------------------- ArgMax.cpp - ArgMax Op ---------------------------===//
//
// Copyright (c) 2022 Advanced Micro Devices, Inc.
//
// =============================================================================
//
// This file lowers ONNX ArgMax operator to TOSA dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

// Lowers onnx.ArgMax to tosa.argmax for the expressible subset only:
// keepdims == 1 and select_last_index == 0 (first-occurrence semantics).
class ONNXArgMaxOpLoweringToTOSA : public OpConversionPattern<ONNXArgMaxOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(ONNXArgMaxOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {

    // Bail out on attribute combinations that are not handled.
    const auto keepDims = adaptor.keepdims();
    if (keepDims != 1)
      return rewriter.notifyMatchFailure(op, "keepdims != 1 is not supported");

    const auto selectLastIndex = adaptor.select_last_index();
    if (selectLastIndex != 0)
      return rewriter.notifyMatchFailure(
          op, "select_last_index != 0 is not supported");

    // tosa.argmax carries the reduction axis as an i64 attribute.
    IntegerAttr axisAttr = rewriter.getI64IntegerAttr(adaptor.axis());
    rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(
        op, op.getType(), adaptor.data(), axisAttr);
    return success();
  }
};

} // namespace

// Registers the ArgMax lowering pattern.
void populateLoweringONNXArgMaxOpToTOSAPattern(RewritePatternSet &patterns,
    TypeConverter &typeConverter, MLIRContext *ctx) {
  patterns.add<ONNXArgMaxOpLoweringToTOSA>(typeConverter, ctx);
}

} // namespace onnx_mlir
37 changes: 37 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s

// onnx.Relu on a static shape lowers to tosa.clamp with [0, max] bounds.
func.func @test_relu(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> {
%0 = "onnx.Relu"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32>
"func.return"(%0) : (tensor<10x10xf32>) -> ()
// CHECK-LABEL: func @test_relu
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"([[PARAM_0_]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<10x10xf32>) -> tensor<10x10xf32>
// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32>
// CHECK-NEXT: }
}

// Relu with a dynamic leading dim; shape inference refines the unranked
// result before the TOSA lowering runs (note --shape-inference in RUN line).
func.func @test_relu_dynamic(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Relu"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
"func.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: func @test_relu_dynamic
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x10xf32>) -> tensor<?x10xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"([[PARAM_0_]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<?x10xf32>) -> tensor<?x10xf32>
// CHECK-NEXT: return [[VAR_0_]] : tensor<?x10xf32>
// CHECK-NEXT: }
}

// onnx.Neg lowers 1:1 to tosa.negate via the generic unary pattern.
func.func @test_neg(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
%0 = "onnx.Neg"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32>
"func.return"(%0) : (tensor<10x10xf32>) -> ()
// CHECK-LABEL: func @test_neg
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.negate"([[PARAM_0_]]) : (tensor<10x10xf32>) -> tensor<10x10xf32>
}

// onnx.Floor lowers 1:1 to tosa.floor via the generic unary pattern.
func.func @test_floor(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
%0 = "onnx.Floor"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32>
"func.return"(%0) : (tensor<10x10xf32>) -> ()
// CHECK-LABEL: func @test_floor
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.floor"([[PARAM_0_]]) : (tensor<10x10xf32>) -> tensor<10x10xf32>
}
11 changes: 11 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Tensor/ArgMax.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: onnx-mlir-opt --convert-onnx-to-tosa %s -split-input-file | FileCheck %s

// onnx.ArgMax with keepdims = 1 and default select_last_index lowers to
// tosa.argmax carrying the axis as an i64 attribute.
// NOTE(review): the result type keeps the full 8x16x32 shape on axis 2 —
// confirm this is the intended keepdims shape for the lowering.
func.func @test_argmax(%arg0: tensor<8x16x32xf32>) -> tensor<8x16x32xi64> {
%0 = "onnx.ArgMax"(%arg0) {axis = 2 : si64, keepdims = 1 : si64, onnx_node_name = "ArgMax_0"} : (tensor<8x16x32xf32>) -> tensor<8x16x32xi64>
return %0 : tensor<8x16x32xi64>
// CHECK-LABEL: func @test_argmax
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xi64> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.argmax"([[PARAM_0_]]) {axis = 2 : i64} : (tensor<8x16x32xf32>) -> tensor<8x16x32xi64>
// CHECK-NEXT: return [[VAR_0_]] : tensor<8x16x32xi64>
// CHECK-NEXT: }
}
11 changes: 0 additions & 11 deletions test/mlir/tosa/onnx_lowering.mlir

This file was deleted.