[TOSA] Update type converter and unary ops
Signed-off-by: Philipp Braun <philipp.braun@amd.com>
philippb-amd committed Jul 27, 2022
1 parent 096d449 commit 06a87c4
Showing 6 changed files with 197 additions and 61 deletions.
3 changes: 3 additions & 0 deletions src/Conversion/ONNXToTOSA/CMakeLists.txt
@@ -2,6 +2,9 @@

add_onnx_mlir_library(OMONNXToTOSA
ConvertONNXToTOSA.cpp

Math/Elementwise.cpp

LINK_LIBS PUBLIC
OMONNXOps
MLIRTosaDialect
87 changes: 37 additions & 50 deletions src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp
@@ -12,54 +12,29 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOpsHelper.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"

using namespace mlir;

namespace onnx_mlir {

// This defines a template to construct ops whose legalizations are
// specialized.
template <typename OnnxOpT>
class ConvertOnnxOp : public OpConversionPattern<OnnxOpT> {
public:
using OpConversionPattern<OnnxOpT>::OpConversionPattern;
using OpAdaptor = typename OnnxOpT::Adaptor;
LogicalResult matchAndRewrite(OnnxOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

template <>
LogicalResult ConvertOnnxOp<ONNXReluOp>::matchAndRewrite(ONNXReluOp op,
OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const {
Value input = adaptor.X();
auto inputTy = input.getType().dyn_cast<TensorType>();

if (!inputTy)
return op.emitError("Only Tensor types supported in TOSA");

if (!inputTy.getElementType().isa<FloatType>()) {
return op.emitError(
"Only floating-point datatype legalization currently supported");
}
static bool isSignedInt(Type type) {
IntegerType intType = type.dyn_cast<IntegerType>();
std::set<unsigned> intWidth{8, 16, 32, 48, 64};
return intType && intType.isSigned() &&
(intWidth.find(intType.getWidth()) != intWidth.end());
}

// Rescale the clampIn for quantized types. TBD
// Maps to tosa.clamp which has both int and fp limits.
Value clampIn = input;
static bool isFloat(Type type) {
return type.isa<BFloat16Type, Float16Type, Float32Type>();
}

rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, op.getType(), clampIn,
rewriter.getI64IntegerAttr(0),
rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
rewriter.getF32FloatAttr(0.0f),
rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
return success();
void populateONNXToTOSAConversionPattern(ConversionTarget &target,
RewritePatternSet &patterns, TypeConverter &typeConverter,
MLIRContext *ctx) {
// Math
populateLoweringONNXElementwiseOpToTOSAPattern(
target, patterns, typeConverter, ctx);
}

// Performs lowering to TOSA dialect
@@ -79,24 +54,36 @@ struct FrontendToTosaLoweringPass
};

void FrontendToTosaLoweringPass::runOnOperation() {
ModuleOp module = getOperation();
// Define final conversion target
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
ConversionTarget target(*context);

// We use the type converter to legalize types before any conversion patterns
// are executed. This ensures that we do not need to trigger separate
// conversion failures. Quantized types are not supported right now.
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });

typeConverter.addConversion([](Type type) -> Optional<Type> {
if (isSignedInt(type) || isFloat(type))
return type;
return llvm::None;
});
typeConverter.addConversion([&](TensorType type) -> Optional<Type> {
if (typeConverter.isLegal(type.getElementType()))
return type;
return llvm::None;
});

// Define legal dialects and operations
target.addLegalDialect<tosa::TosaDialect, func::FuncDialect>();

#define INSERT_ONNXOP_PATTERN(OnnxOp) \
target.addIllegalOp<OnnxOp>(); \
patterns.add<ConvertOnnxOp<OnnxOp>>(typeConverter, context);
INSERT_ONNXOP_PATTERN(ONNXReluOp);
#undef INSERT_ONNXOP_PATTERN
// Define patterns
populateONNXToTOSAConversionPattern(target, patterns, typeConverter, context);

if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
}

std::unique_ptr<Pass> createConvertONNXToTOSAPass() {
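For reference, below is a minimal standalone sketch of how the two `addConversion` callbacks above compose, assuming an MLIR build roughly contemporary with this commit (member-style `dyn_cast`/`isa`, `llvm::Optional`). The `isSupportedElementType` helper and the `main()` driver are illustrative stand-ins for the pass's `isSignedInt()`/`isFloat()`, not code from this repository.

```cpp
// Standalone sketch (not part of the commit) of the type-converter behavior.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Optional.h"

using namespace mlir;

// Mirrors isSignedInt() || isFloat() from ConvertONNXToTOSA.cpp.
static bool isSupportedElementType(Type type) {
  if (auto intType = type.dyn_cast<IntegerType>())
    return intType.isSigned() &&
           (intType.getWidth() == 8 || intType.getWidth() == 16 ||
            intType.getWidth() == 32 || intType.getWidth() == 48 ||
            intType.getWidth() == 64);
  return type.isa<BFloat16Type, Float16Type, Float32Type>();
}

int main() {
  MLIRContext ctx;
  TypeConverter typeConverter;
  // Element types: pass through supported ones, reject everything else.
  typeConverter.addConversion([](Type type) -> llvm::Optional<Type> {
    if (isSupportedElementType(type))
      return type;
    return llvm::None;
  });
  // Tensors: legal exactly when their element type is legal. Registered
  // last, so it is tried first; it delegates to the rule above via isLegal().
  typeConverter.addConversion([&](TensorType type) -> llvm::Optional<Type> {
    if (typeConverter.isLegal(type.getElementType()))
      return type;
    return llvm::None;
  });

  auto f32Tensor = RankedTensorType::get({2, 3}, Float32Type::get(&ctx));
  auto f64Tensor = RankedTensorType::get({2, 3}, Float64Type::get(&ctx));
  bool a = typeConverter.isLegal(f32Tensor); // true: f32 is accepted.
  bool b = typeConverter.isLegal(f64Tensor); // false: f64 is not accepted.
  (void)a;
  (void)b;
  return 0;
}
```

Because callbacks are tried most-recently-registered first, the `TensorType` rule runs before the element-type rule; an unsupported element type therefore makes the whole tensor type illegal.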
71 changes: 71 additions & 0 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
@@ -0,0 +1,71 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===---------------- Elementwise.cpp - Elementwise Op --------------------===//
//
// Copyright (c) 2022 Advanced Micro Devices, Inc.
//
// =============================================================================
//
// This file lowers ONNX element-wise operators to TOSA dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

template <typename ONNXOpT, typename TOSAOpT>
class ONNXUnaryOpLoweringToTOSA : public OpConversionPattern<ONNXOpT> {
public:
using OpConversionPattern<ONNXOpT>::OpConversionPattern;
using OpAdaptor = typename ONNXOpT::Adaptor;
LogicalResult matchAndRewrite(ONNXOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TOSAOpT>(op, op.getType(), adaptor.X());
return success();
}
};

class ONNXReluOpLoweringToTOSA : public OpConversionPattern<ONNXReluOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXReluOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Value input = adaptor.X();

// Quantized types are not supported right now (in type conversion).
// Once they are, the input should be rescaled for quantized types. (TBD)
// Maps to `tosa.clamp` which has both int and fp limits.
rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, op.getType(), input,
rewriter.getI64IntegerAttr(0),
rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
rewriter.getF32FloatAttr(0.0f),
rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
return success();
}
};

} // namespace

void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target,
RewritePatternSet &patterns, TypeConverter &typeConverter,
MLIRContext *ctx) {
patterns.insert<ONNXReluOpLoweringToTOSA>(typeConverter, ctx);

#define INSERT_UNARY_PATTERN(ONNXOp, TOSAOp) \
target.addIllegalOp<ONNXOp>(); \
patterns.insert<ONNXUnaryOpLoweringToTOSA<ONNXOp, TOSAOp>>( \
typeConverter, ctx);
INSERT_UNARY_PATTERN(ONNXNegOp, tosa::NegateOp)
INSERT_UNARY_PATTERN(ONNXFloorOp, tosa::FloorOp)
#undef INSERT_UNARY_PATTERN
}

} // namespace onnx_mlir
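The `INSERT_UNARY_PATTERN` macro keeps further 1:1 unary lowerings cheap to add. As a hedged sketch (not part of this commit), additional invocations could sit next to the existing ones inside `populateLoweringONNXElementwiseOpToTOSAPattern`; the pairs below assume `ONNXAbsOp` and `ONNXCeilOp` expose their operand as `X()`, which the generic template requires.

```cpp
// Hypothetical additions, reusing the INSERT_UNARY_PATTERN macro defined
// above inside populateLoweringONNXElementwiseOpToTOSAPattern (not part of
// this commit). Both ONNX ops take a single operand named X, so the generic
// ONNXUnaryOpLoweringToTOSA template applies unchanged.
INSERT_UNARY_PATTERN(ONNXAbsOp, tosa::AbsOp)
INSERT_UNARY_PATTERN(ONNXCeilOp, tosa::CeilOp)
```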
49 changes: 49 additions & 0 deletions src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp
@@ -0,0 +1,49 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//====------ ONNXToTOSACommon.hpp - ONNX dialects to TOSA lowering --------===//
//
// Copyright (c) 2022 Advanced Micro Devices, Inc.
//
// =============================================================================
//
// This file contains common code shared by the functions performing the
// lowering to the TOSA dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"

#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOpsHelper.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Transform/ONNX/ConstPropHelper.hpp"

//===----------------------------------------------------------------------===//
// Functions to add lowering patterns for frontend operations.
//===----------------------------------------------------------------------===//

namespace onnx_mlir {

//===----------------------------------------------------------------------===//
// Maps an ONNX operation type to its corresponding TOSA operation type.
//===----------------------------------------------------------------------===//
template <typename ONNXOp>
struct TOSADialectOp {
using Op = void;
};

template <typename Op>
using TOSAOp = typename TOSADialectOp<Op>::Op;

// `Math` directory methods:
void populateLoweringONNXElementwiseOpToTOSAPattern(
ConversionTarget &, RewritePatternSet &, TypeConverter &, MLIRContext *);
} // namespace onnx_mlir
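The empty `TOSADialectOp` trait above is presumably intended to be specialized per ONNX op so that generic patterns can name their TOSA counterpart via `TOSAOp<>`. A hedged sketch of such a specialization, which this commit does not yet add:

```cpp
// Hypothetical specialization (not in this commit): associate onnx.Neg with
// tosa.negate so generic code can write TOSAOp<mlir::ONNXNegOp>.
namespace onnx_mlir {
template <>
struct TOSADialectOp<mlir::ONNXNegOp> {
  using Op = mlir::tosa::NegateOp;
};
} // namespace onnx_mlir

// A pattern could then target the mapped op generically, e.g.:
//   rewriter.replaceOpWithNewOp<TOSAOp<mlir::ONNXNegOp>>(
//       op, op.getType(), adaptor.X());
```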
37 changes: 37 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
@@ -0,0 +1,37 @@
// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s

func.func @test_relu(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> {
%0 = "onnx.Relu"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32>
"func.return"(%0) : (tensor<10x10xf32>) -> ()
// CHECK-LABEL: func @test_relu
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"([[PARAM_0_]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<10x10xf32>) -> tensor<10x10xf32>
// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32>
// CHECK-NEXT: }
}

func.func @test_relu_dynamic(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Relu"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
"func.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: func @test_relu_dynamic
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x10xf32>) -> tensor<?x10xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"([[PARAM_0_]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<?x10xf32>) -> tensor<?x10xf32>
// CHECK-NEXT: return [[VAR_0_]] : tensor<?x10xf32>
// CHECK-NEXT: }
}

func.func @test_neg(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
%0 = "onnx.Neg"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32>
"func.return"(%0) : (tensor<10x10xf32>) -> ()
// CHECK-LABEL: func @test_neg
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.negate"([[PARAM_0_]]) : (tensor<10x10xf32>) -> tensor<10x10xf32>
}

func.func @test_floor(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
%0 = "onnx.Floor"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32>
"func.return"(%0) : (tensor<10x10xf32>) -> ()
// CHECK-LABEL: func @test_floor
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.floor"([[PARAM_0_]]) : (tensor<10x10xf32>) -> tensor<10x10xf32>
}
11 changes: 0 additions & 11 deletions test/mlir/tosa/onnx_lowering.mlir

This file was deleted.
