forked from onnx/onnx-mlir
[TOSA] Add ONNX to TOSA ArgMax conversion pass #45
Open: ghost wants to merge 14 commits into feature/onnx_to_torch from philippb.tosa_argmax

Commits:
a5b801e Update type converter for ONNXToTOSA conversion passes
733cf9c Add unary op conversion passes
f9b4e48 Add ONNXToTOSA unary lit tests
fe8b4dc Add ONNX to TOSA ArgMax conversion pass
c070a5e Update ONNXToTOSA unary ops lit test
b1bb575 Merge remote-tracking branch 'origin/philippb.refactor_torch_to_tosa'…
c10d29b Update ArgMax title
e86ace9 Mark unary ONNX ops illegal & fail on quantized types
de1f661 Merge unary op changes
dd933f5 Align TOSA lit tests
c4cd0f4 Merge remote-tracking branch 'origin/philippb.refactor_torch_to_tosa'…
73bbeac Remove quantization check
5064875 Merge branch 'philippb.refactor_torch_to_tosa' into philippb.tosa_argmax
105cdc2 Update comments
@@ -0,0 +1,71 @@
/*
 * SPDX-License-Identifier: Apache-2.0
 */

//===---------------- Elementwise.cpp - Elementwise Op --------------------===//
//
// Copyright (c) 2022 Advanced Micro Devices, Inc.
//
// =============================================================================
//
// This file lowers ONNX element-wise operators to TOSA dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

template <typename ONNXOpT, typename TOSAOpT>
class ONNXUnaryOpLoweringToTOSA : public OpConversionPattern<ONNXOpT> {
public:
  using OpConversionPattern<ONNXOpT>::OpConversionPattern;
  using OpAdaptor = typename ONNXOpT::Adaptor;
  LogicalResult matchAndRewrite(ONNXOpT op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TOSAOpT>(op, op.getType(), adaptor.X());
    return success();
  }
};

class ONNXReluOpLoweringToTOSA : public OpConversionPattern<ONNXReluOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(ONNXReluOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {

    Value input = adaptor.X();

    // Quantized types are not supported right now (in type conversion).
    // Once they are, the input should be rescaled for quantized types. (TBD)
    // Maps to `tosa.clamp` which has both int and fp limits.
    rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, op.getType(), input,
        rewriter.getI64IntegerAttr(0),
        rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
        rewriter.getF32FloatAttr(0.0f),
        rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
    return success();
  }
};

} // namespace

void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target,
    RewritePatternSet &patterns, TypeConverter &typeConverter,
    MLIRContext *ctx) {
  patterns.insert<ONNXReluOpLoweringToTOSA>(typeConverter, ctx);

#define INSERT_UNARY_PATTERN(ONNXOp, TOSAOp) \
  target.addIllegalOp<ONNXOp>(); \
  patterns.insert<ONNXUnaryOpLoweringToTOSA<ONNXOp, TOSAOp>>( \
      typeConverter, ctx);
  INSERT_UNARY_PATTERN(ONNXNegOp, tosa::NegateOp)
  INSERT_UNARY_PATTERN(ONNXFloorOp, tosa::FloorOp)
#undef INSERT_UNARY_PATTERN
}

} // namespace onnx_mlir
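
Note on the Relu pattern in this file: it lowers `onnx.Relu` to `tosa.clamp` with floating-point bounds 0 and the largest finite `float`. The standalone sketch below has no MLIR dependency and only models the float path element by element. The names `reluViaClamp` and `reluReference` are hypothetical, and the sketch assumes the clamp behaves like `std::clamp` (C++17); it is not the TOSA implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <limits>
#include <vector>

// Hypothetical model of the lowered op: clamp every element to
// [0, FLT_MAX], mirroring the min_fp/max_fp attributes in the pattern.
std::vector<float> reluViaClamp(const std::vector<float> &input) {
  const float lo = 0.0f;
  const float hi = std::numeric_limits<float>::max();
  std::vector<float> out;
  out.reserve(input.size());
  for (float x : input)
    out.push_back(std::clamp(x, lo, hi));
  // Element-wise op: the output has as many elements as the input.
  assert(out.size() == input.size());
  return out;
}

// Reference Relu for comparison: max(x, 0) per element.
std::vector<float> reluReference(const std::vector<float> &input) {
  std::vector<float> out;
  out.reserve(input.size());
  for (float x : input)
    out.push_back(std::max(x, 0.0f));
  return out;
}
```

For finite inputs the two functions agree. Under this model, one difference is that +infinity is clamped to the largest finite `float`, whereas max(x, 0) would return +infinity.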
@@ -0,0 +1,53 @@
/*
 * SPDX-License-Identifier: Apache-2.0
 */

//====------ ONNXToTOSACommon.hpp - ONNX dialects to TOSA lowering --------===//
//
// Copyright (c) 2022 Advanced Micro Devices, Inc.
//
// =============================================================================
//
// This file contains common code shared by the functions performing the
// lowering to the TOSA dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"

#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOpsHelper.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Transform/ONNX/ConstPropHelper.hpp"

//===----------------------------------------------------------------------===//
// Functions to add lowering patterns for frontend operations.
//===----------------------------------------------------------------------===//

namespace onnx_mlir {

//===----------------------------------------------------------------------===//
// This is to get a TOSA operation of a given type for a specific operation.
//===----------------------------------------------------------------------===//
template <typename ONNXOp>
struct TOSADialectOp {
  using Op = void;
};

template <typename Op>
using TOSAOp = typename TOSADialectOp<Op>::Op;

// `Math` directory methods:
void populateLoweringONNXElementwiseOpToTOSAPattern(
    ConversionTarget &, RewritePatternSet &, TypeConverter &, MLIRContext *);
// `Tensor` directory methods:
void populateLoweringONNXArgMaxOpToTOSAPattern(
    RewritePatternSet &, TypeConverter &, MLIRContext *);

} // namespace onnx_mlir
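
Note on `TOSADialectOp`: the header only declares the primary template, which maps every op to `void`, and this diff shows no specializations. The sketch below illustrates how such a trait is usually filled in, with one explicit specialization per supported op. The `Fake...` op types and the `hasTosaMapping` helper are hypothetical stand-ins so the example compiles without MLIR; whether the PR intends to use the trait this way is an assumption.

```cpp
#include <type_traits>

// Placeholder op types standing in for the real ONNX and TOSA op classes
// (hypothetical; the real ones come from the dialect headers).
struct FakeONNXNegOp {};
struct FakeONNXFloorOp {};
struct FakeONNXReluOp {};
struct FakeTosaNegateOp {};
struct FakeTosaFloorOp {};

// Primary template, as in the header: no mapping by default.
template <typename ONNXOp>
struct TOSADialectOp {
  using Op = void;
};

// Assumed usage: one explicit specialization per supported ONNX op.
template <>
struct TOSADialectOp<FakeONNXNegOp> {
  using Op = FakeTosaNegateOp;
};

template <>
struct TOSADialectOp<FakeONNXFloorOp> {
  using Op = FakeTosaFloorOp;
};

// Alias, as in the header: look up the mapped TOSA op type.
template <typename Op>
using TOSAOp = typename TOSADialectOp<Op>::Op;

// Hypothetical compile-time query: true when a mapping has been registered.
template <typename Op>
constexpr bool hasTosaMapping = !std::is_void<TOSAOp<Op>>::value;
```

With a trait like this, a generic pattern could take only the ONNX op as a template parameter and obtain the target op through `TOSAOp<ONNXOpT>`, instead of receiving both types explicitly as `ONNXUnaryOpLoweringToTOSA` does.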
@@ -0,0 +1,50 @@
/*
 * SPDX-License-Identifier: Apache-2.0
 */

//===------------------- ArgMax.cpp - ArgMax Op ---------------------------===//
//
// Copyright (c) 2022 Advanced Micro Devices, Inc.
//
// =============================================================================
//
// This file lowers ONNX ArgMax operator to TOSA dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

class ONNXArgMaxOpLoweringToTOSA : public OpConversionPattern<ONNXArgMaxOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(ONNXArgMaxOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {

    if (adaptor.keepdims() != 1)
      return rewriter.notifyMatchFailure(op, "keepdims != 1 is not supported");

    if (adaptor.select_last_index() != 0)
      return rewriter.notifyMatchFailure(
          op, "select_last_index != 0 is not supported");

    IntegerAttr axis = rewriter.getI64IntegerAttr(adaptor.axis());
    rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(
        op, op.getType(), adaptor.data(), axis);
    return success();
  }
};

} // namespace

void populateLoweringONNXArgMaxOpToTOSAPattern(RewritePatternSet &patterns,
    TypeConverter &typeConverter, MLIRContext *ctx) {
  patterns.insert<ONNXArgMaxOpLoweringToTOSA>(typeConverter, ctx);
}

} // namespace onnx_mlir
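
Note on the two guards in this pattern: it only accepts `keepdims = 1` and `select_last_index = 0`. The standalone sketch below spells out what that subset of ONNX ArgMax computes for a row-major 2-D input reduced along axis 1. The reduced axis is kept with size 1, and ties resolve to the first occurrence. The function name `argMaxAxis1KeepDims` and the flat-vector layout are hypothetical choices for illustration only; the code does not model the TOSA op.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// ArgMax along axis 1 of a row-major [rows x cols] matrix with
// keepdims = 1 and select_last_index = 0. The result has logical shape
// [rows x 1], stored as `rows` consecutive indices.
std::vector<std::int64_t> argMaxAxis1KeepDims(
    const std::vector<float> &data, std::size_t rows, std::size_t cols) {
  assert(cols > 0 && data.size() == rows * cols);
  std::vector<std::int64_t> result(rows, 0);
  for (std::size_t r = 0; r < rows; ++r) {
    std::size_t best = 0;
    for (std::size_t c = 1; c < cols; ++c) {
      // Strict '>' keeps the first maximum on ties (select_last_index = 0).
      if (data[r * cols + c] > data[r * cols + best])
        best = c;
    }
    result[r] = static_cast<std::int64_t>(best);
  }
  return result;
}
```

Supporting `select_last_index = 1` in this model would only require changing the comparison to `>=`, which is the behaviour the pattern currently rejects.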
@@ -0,0 +1,37 @@
// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s

func.func @test_relu(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> {
  %0 = "onnx.Relu"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32>
  "func.return"(%0) : (tensor<10x10xf32>) -> ()
// CHECK-LABEL: func @test_relu
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"([[PARAM_0_]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<10x10xf32>) -> tensor<10x10xf32>
// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32>
// CHECK-NEXT: }
}

func.func @test_relu_dynamic(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
  %0 = "onnx.Relu"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
  "func.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: func @test_relu_dynamic
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x10xf32>) -> tensor<?x10xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"([[PARAM_0_]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<?x10xf32>) -> tensor<?x10xf32>
// CHECK-NEXT: return [[VAR_0_]] : tensor<?x10xf32>
// CHECK-NEXT: }
}

func.func @test_neg(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
  %0 = "onnx.Neg"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32>
  "func.return"(%0) : (tensor<10x10xf32>) -> ()
// CHECK-LABEL: func @test_neg
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.negate"([[PARAM_0_]]) : (tensor<10x10xf32>) -> tensor<10x10xf32>
}

func.func @test_floor(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
  %0 = "onnx.Floor"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32>
  "func.return"(%0) : (tensor<10x10xf32>) -> ()
// CHECK-LABEL: func @test_floor
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.floor"([[PARAM_0_]]) : (tensor<10x10xf32>) -> tensor<10x10xf32>
}
@@ -0,0 +1,11 @@
// RUN: onnx-mlir-opt --convert-onnx-to-tosa %s -split-input-file | FileCheck %s

func.func @test_argmax(%arg0: tensor<8x16x32xf32>) -> tensor<8x16x32xi64> {
  %0 = "onnx.ArgMax"(%arg0) {axis = 2 : si64, keepdims = 1 : si64, onnx_node_name = "ArgMax_0"} : (tensor<8x16x32xf32>) -> tensor<8x16x32xi64>
  return %0 : tensor<8x16x32xi64>
// CHECK-LABEL: func @test_argmax
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xi64> {
// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.argmax"([[PARAM_0_]]) {axis = 2 : i64} : (tensor<8x16x32xf32>) -> tensor<8x16x32xi64>
// CHECK-NEXT: return [[VAR_0_]] : tensor<8x16x32xi64>
// CHECK-NEXT: }
}
Wondering if this should be named `isTosaSignedInt` - I know it is in a convert-to-tosa file, but I still had to read the contents carefully to understand it was limiting to TOSA supported types. Similar comment for `isFloat`. Also wondering if there is any way to reuse `Tosa_SignedInt` from `mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td` directly? (I'm guessing not, but wondered if you knew for certain.)

Yeah, there isn't. They turn into anonymous constraint functions in the .cpp files where they are used. That's the buf thing that my concepts proposal was to solve.