forked from onnx/onnx-mlir
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TOSA] Update type converter and unary ops
Signed-off-by: Philipp Braun <philipp.braun@amd.com>
- Loading branch information
1 parent
096d449
commit 06a87c4
Showing
6 changed files
with
197 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//===---------------- Elementwise.cpp - Elementwise Op --------------------===// | ||
// | ||
// Copyright (c) 2022 Advanced Micro Devices, Inc. | ||
// | ||
// ============================================================================= | ||
// | ||
// This file lowers ONNX element-wise operators to TOSA dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" | ||
|
||
using namespace mlir; | ||
|
||
namespace onnx_mlir { | ||
|
||
namespace { | ||
|
||
template <typename ONNXOpT, typename TOSAOpT> | ||
class ONNXUnaryOpLoweringToTOSA : public OpConversionPattern<ONNXOpT> { | ||
public: | ||
using OpConversionPattern<ONNXOpT>::OpConversionPattern; | ||
using OpAdaptor = typename ONNXOpT::Adaptor; | ||
LogicalResult matchAndRewrite(ONNXOpT op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
rewriter.replaceOpWithNewOp<TOSAOpT>(op, op.getType(), adaptor.X()); | ||
return success(); | ||
} | ||
}; | ||
|
||
class ONNXReluOpLoweringToTOSA : public OpConversionPattern<ONNXReluOp> { | ||
public: | ||
using OpConversionPattern::OpConversionPattern; | ||
LogicalResult matchAndRewrite(ONNXReluOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
|
||
Value input = adaptor.X(); | ||
|
||
// Quantized types are not supported right now (in type conversion). | ||
// Once they are, the input should be rescaled for quantized types. (TBD) | ||
// Maps to `tosa.clamp` which has both int and fp limits. | ||
rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, op.getType(), input, | ||
rewriter.getI64IntegerAttr(0), | ||
rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()), | ||
rewriter.getF32FloatAttr(0.0f), | ||
rewriter.getF32FloatAttr(std::numeric_limits<float>::max())); | ||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target, | ||
RewritePatternSet &patterns, TypeConverter &typeConverter, | ||
MLIRContext *ctx) { | ||
patterns.insert<ONNXReluOpLoweringToTOSA>(typeConverter, ctx); | ||
|
||
#define INSERT_UNARY_PATTERN(ONNXOp, TOSAOp) \ | ||
target.addIllegalOp<ONNXOp>(); \ | ||
patterns.insert<ONNXUnaryOpLoweringToTOSA<ONNXOp, TOSAOp>>( \ | ||
typeConverter, ctx); | ||
INSERT_UNARY_PATTERN(ONNXNegOp, tosa::NegateOp) | ||
INSERT_UNARY_PATTERN(ONNXFloorOp, tosa::FloorOp) | ||
#undef INSERT_UNARY_PATTERN | ||
} | ||
|
||
} // namespace onnx_mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//====------ ONNXToTOSACommon.hpp - ONNX dialects to TOSA lowering --------===// | ||
// | ||
// Copyright (c) 2022 Advanced Micro Devices, Inc. | ||
// | ||
// ============================================================================= | ||
// | ||
// This file contains common code shared by the functions performing the | ||
// lowering to the TOSA dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Quant/QuantTypes.h" | ||
#include "mlir/Dialect/Tosa/IR/TosaOps.h" | ||
|
||
#include "mlir/IR/MLIRContext.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
#include "src/Dialect/ONNX/DialectBuilder.hpp" | ||
#include "src/Dialect/ONNX/ONNXOps.hpp" | ||
#include "src/Dialect/ONNX/ONNXOpsHelper.hpp" | ||
#include "src/Pass/Passes.hpp" | ||
#include "src/Transform/ONNX/ConstPropHelper.hpp" | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Functions to add lowering patterns for frontend operations. | ||
//===----------------------------------------------------------------------===// | ||
|
||
namespace onnx_mlir { | ||
|
||
//===----------------------------------------------------------------------===// | ||
// This is to get a TOSA operation of a given type for a specific operation. | ||
//===----------------------------------------------------------------------===// | ||
template <typename ONNXOp> | ||
struct TOSADialectOp { | ||
using Op = void; | ||
}; | ||
|
||
template <typename Op> | ||
using TOSAOp = typename TOSADialectOp<Op>::Op; | ||
|
||
// `Math` directory methods: | ||
void populateLoweringONNXElementwiseOpToTOSAPattern( | ||
ConversionTarget &, RewritePatternSet &, TypeConverter &, MLIRContext *); | ||
} // namespace onnx_mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s | ||
|
||
func.func @test_relu(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { | ||
%0 = "onnx.Relu"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> | ||
"func.return"(%0) : (tensor<10x10xf32>) -> () | ||
// CHECK-LABEL: func @test_relu | ||
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { | ||
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"([[PARAM_0_]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<10x10xf32>) -> tensor<10x10xf32> | ||
// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32> | ||
// CHECK-NEXT: } | ||
} | ||
|
||
func.func @test_relu_dynamic(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { | ||
%0 = "onnx.Relu"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32> | ||
"func.return"(%0) : (tensor<*xf32>) -> () | ||
// CHECK-LABEL: func @test_relu_dynamic | ||
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x10xf32>) -> tensor<?x10xf32> { | ||
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"([[PARAM_0_]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<?x10xf32>) -> tensor<?x10xf32> | ||
// CHECK-NEXT: return [[VAR_0_]] : tensor<?x10xf32> | ||
// CHECK-NEXT: } | ||
} | ||
|
||
func.func @test_neg(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { | ||
%0 = "onnx.Neg"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> | ||
"func.return"(%0) : (tensor<10x10xf32>) -> () | ||
// CHECK-LABEL: func @test_neg | ||
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { | ||
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.negate"([[PARAM_0_]]) : (tensor<10x10xf32>) -> tensor<10x10xf32> | ||
} | ||
|
||
func.func @test_floor(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { | ||
%0 = "onnx.Floor"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> | ||
"func.return"(%0) : (tensor<10x10xf32>) -> () | ||
// CHECK-LABEL: func @test_floor | ||
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { | ||
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.floor"([[PARAM_0_]]) : (tensor<10x10xf32>) -> tensor<10x10xf32> | ||
} |
This file was deleted.
Oops, something went wrong.