Skip to content

Commit

Permalink
Merge pull request #103 from Xilinx/feature/onnx-to-tosa
Browse files Browse the repository at this point in the history
Merge main to release
  • Loading branch information
mgehre-amd authored Jun 13, 2024
2 parents 8de2a7c + 67cb3a6 commit aa3d84b
Show file tree
Hide file tree
Showing 17 changed files with 584 additions and 329 deletions.
1 change: 1 addition & 0 deletions src/Conversion/ONNXToTOSA/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ add_onnx_mlir_library(OMONNXToTOSA
Tensor/Resize.cpp
Tensor/Shrink.cpp
Tensor/Slice.cpp
Tensor/Split.cpp
Tensor/Squeeze.cpp
Tensor/Tile.cpp
Tensor/Transpose.cpp
Expand Down
2 changes: 2 additions & 0 deletions src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ void populateONNXToTOSAConversionPattern(ConversionTarget &target,
populateLoweringONNXPadOpToTOSAPattern(target, patterns, typeConverter, ctx);
populateLoweringONNXSliceOpToTOSAPattern(
target, patterns, typeConverter, ctx);
populateLoweringONNXSplitOpToTOSAPattern(
target, patterns, typeConverter, ctx);
populateLoweringONNXSqueezeOpToTOSAPattern(
target, patterns, typeConverter, ctx);
populateLoweringONNXTileOpToTOSAPattern(target, patterns, typeConverter, ctx);
Expand Down
71 changes: 18 additions & 53 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ struct IsBool {
}
};

template <typename OpAdaptorT, typename TypeChecker>
template <typename OpAdaptorT, typename TypeChecker, typename TosaOpT>
LogicalResult checkBasicTosaRequirementsForBinaryOps(
ConversionPatternRewriter &rewriter, Operation *op, OpAdaptorT adaptor,
Type resultType) {
Expand All @@ -92,50 +92,22 @@ LogicalResult checkBasicTosaRequirementsForBinaryOps(

Type resultElementType = resultTensorType.getElementType();

if (TosaOpT::template hasTrait<
::mlir::OpTrait::SameOperandsAndResultElementType>()) {
if (lhsType.getElementType() != rhsType.getElementType() ||
lhsType.getElementType() != resultElementType) {
return rewriter.notifyMatchFailure(
op, "lhs, rhs and result must have the same type");
}
}

if (failed(TypeChecker::checkType(rewriter, resultElementType, op))) {
return failure();
}

return success();
}

// Element-wise unary ops lowering to custom op of TOSA dialect.
//===----------------------------------------------------------------------===//
template <typename ONNXOp>
class ConvertONNXUnaryOpToTosaCustomOp : public OpConversionPattern<ONNXOp> {
public:
using OpConversionPattern<ONNXOp>::OpConversionPattern;
using OpAdaptor = typename ONNXOp::Adaptor;

ConvertONNXUnaryOpToTosaCustomOp(TypeConverter &typeConverter,
MLIRContext *context, std::string opName,
std::string implementedWithOpAttr = "UNDEF")
: OpConversionPattern<ONNXOp>(typeConverter, context),
opName(std::move(opName)),
implementedWithOpAttr(std::move(implementedWithOpAttr)) {}

LogicalResult matchAndRewrite(ONNXOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// Set tosa.custom_op attributes.
// Only identifier needs to be known. Other attributes are not used.
auto *ctx = op->getContext();
auto identifier = StringAttr::get(ctx, opName);
auto implementAttr = StringAttr::get(ctx, implementedWithOpAttr);
auto config = StringAttr::get(ctx, "UNDEF");

rewriter.replaceOpWithNewOp<mlir::tosa::CustomOp>(op,
TypeRange{OpConversionPattern<ONNXOp>::getTypeConverter()->convertType(
op.getType())},
identifier, config, implementAttr, adaptor.getOperands());
return success();
}

private:
std::string opName;
std::string implementedWithOpAttr;
};

// Element-wise unary ops lowering to TOSA dialect.
//===----------------------------------------------------------------------===//
template <typename ElementwiseUnaryOpONNX, typename ElementwiseUnaryOpTOSA,
Expand Down Expand Up @@ -181,8 +153,8 @@ class ONNXBinaryElementwiseOpLoweringToTOSA
LogicalResult matchAndRewrite(ONNXOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

if (failed(checkBasicTosaRequirementsForBinaryOps<OpAdaptor, TypeChecker>(
rewriter, op, adaptor, op.getResult().getType())))
if (failed(checkBasicTosaRequirementsForBinaryOps<OpAdaptor, TypeChecker,
TosaOpT>(rewriter, op, adaptor, op.getResult().getType())))
return failure();

auto loc = op.getLoc();
Expand Down Expand Up @@ -216,7 +188,8 @@ class ONNXMulOpLoweringToTosa : public OpConversionPattern<ONNXMulOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXMulOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(checkBasicTosaRequirementsForBinaryOps<OpAdaptor, IsIntOrFloat>(
if (failed(checkBasicTosaRequirementsForBinaryOps<OpAdaptor, IsIntOrFloat,
mlir::tosa::MulOp>(
rewriter, op, adaptor, op.getResult().getType())))
return failure();

Expand Down Expand Up @@ -720,19 +693,11 @@ static void populateLoweringONNXElementwiseUnaryTemplateOpToTOSAPattern(
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXAbsOp, mlir::tosa::AbsOp,
IsIntOrFloat, IsIntOrFloat>,
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXErfOp, mlir::tosa::ErfOp,
IsFloat, IsFloat>,
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXSinOp, mlir::tosa::SinOp,
IsFloat, IsFloat>,
ONNXElementwiseUnaryOpLoweringToTOSA<ONNXCosOp, mlir::tosa::CosOp,
IsFloat, IsFloat>>(typeConverter, ctx);

// Tosa custom ops
#define INSERT_ONNX_UNARY_TO_TOSA_CUSTOMOP_PATTERN( \
ONNXOp, opName, implementedWith) \
patterns.add<ConvertONNXUnaryOpToTosaCustomOp<ONNXOp>>( \
typeConverter, ctx, opName, implementedWith);

INSERT_ONNX_UNARY_TO_TOSA_CUSTOMOP_PATTERN(
ONNXSinOp, "math.sin", "linalg.generic");
INSERT_ONNX_UNARY_TO_TOSA_CUSTOMOP_PATTERN(
ONNXCosOp, "math.cos", "linalg.generic");
#undef INSERT_ONNX_UNARY_TO_TOSA_CUSTOMOP_PATTERN
}

void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target,
Expand Down
19 changes: 13 additions & 6 deletions src/Conversion/ONNXToTOSA/NN/DequantizeLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,24 @@ class ONNXDequantizeLinearOpLoweringToTOSA

// Dequantization formula is (x - zero_point) * scale
// Cast into the destination type first

// Cast the operands of (x - zero_point) to float32 to avoid underflows
Type arithType = rewriter.getF32Type();
Value subOpA = tosaBuilder.castToNewTensorElementType(x, arithType);
Value subOpB = tosaBuilder.castToNewTensorElementType(zpConst, arithType);
Value subOp = tosa::CreateOpAndInfer<mlir::tosa::SubOp>(
rewriter, loc, x.getType(), x, zpConst)
rewriter, loc, subOpA.getType(), subOpA, subOpB)
.getResult();
Value castOp = tosa::CreateOpAndInfer<mlir::tosa::CastOp>(
rewriter, loc, resultType, subOp)
.getResult();
// There are no guarantees about the bitwith of the scale factor
Value scaleFactorCast =
tosaBuilder.castToNewTensorElementType(scaleFactorConst, arithType);
Value mulOp = tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
rewriter, loc, resultType, castOp, scaleFactorConst, 0)
rewriter, loc, subOp.getType(), subOp, scaleFactorCast, 0)
.getResult();
Value castOp = tosaBuilder.castToNewTensorElementType(
mulOp, resultType.getElementType());

rewriter.replaceOp(op, mulOp);
rewriter.replaceOp(op, castOp);
return success();
}
};
Expand Down
20 changes: 13 additions & 7 deletions src/Conversion/ONNXToTOSA/NN/QuantizeLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ class ONNXQuantizeLinearOpLoweringToTOSA
LogicalResult matchAndRewrite(ONNXQuantizeLinearOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value x = op.getX();
Type xType = x.getType();
auto resultType = dyn_cast_if_present<ShapedType>(
getTypeConverter()->convertType(op.getResult().getType()));
if (!resultType || !resultType.hasStaticShape()) {
Expand Down Expand Up @@ -78,6 +76,9 @@ class ONNXQuantizeLinearOpLoweringToTOSA
auto scaleFactorConst = tosa::expandShape(
rewriter, loc, adaptor.getYScale(), axis, resultType.getRank());

Value x = adaptor.getX();
Type xType = x.getType();

// Quantization formula is ((x / y_scale) + y_zero_point)
// Replace the division by a reciprocal followed by a mul
Value recOp = tosa::CreateOpAndInfer<mlir::tosa::ReciprocalOp>(
Expand All @@ -86,15 +87,20 @@ class ONNXQuantizeLinearOpLoweringToTOSA
Value mulOp = tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
rewriter, loc, xType, x, recOp, 0)
.getResult();
// Cast into the result type
Value castOp = tosa::CreateOpAndInfer<mlir::tosa::CastOp>(
rewriter, loc, resultType, mulOp)
// zpConst has the same type as the result of QLinear which is always
// smaller than the input type. Cast it to the input type.
Value castZp = tosa::CreateOpAndInfer<mlir::tosa::CastOp>(
rewriter, loc, scaleFactorConst.getType(), zpConst)
.getResult();
Value addOp = tosa::CreateOpAndInfer<mlir::tosa::AddOp>(
rewriter, loc, resultType, castOp, zpConst)
rewriter, loc, xType, mulOp, castZp)
.getResult();
// Cast into the result type
Value castOp = tosa::CreateOpAndInfer<mlir::tosa::CastOp>(
rewriter, loc, resultType, addOp)
.getResult();

rewriter.replaceOp(op, addOp);
rewriter.replaceOp(op, castOp);
return success();
}
};
Expand Down
2 changes: 2 additions & 0 deletions src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ void populateLoweringONNXFlattenOpToTOSAPattern(mlir::ConversionTarget &,
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
void populateLoweringONNXSliceOpToTOSAPattern(mlir::ConversionTarget &,
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
void populateLoweringONNXSplitOpToTOSAPattern(mlir::ConversionTarget &,
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
void populateLoweringONNXSqueezeOpToTOSAPattern(mlir::ConversionTarget &,
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
void populateLoweringONNXTileOpToTOSAPattern(mlir::ConversionTarget &,
Expand Down
78 changes: 78 additions & 0 deletions src/Conversion/ONNXToTOSA/Tensor/Split.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===------------- Split.cpp - Split Op---------===//
//
// Copyright (c) 2023 Advanced Micro Devices, Inc.
//
// =============================================================================
//
// This file lowers ONNX SplitOp operator to TOSA dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;
namespace onnx_mlir {
namespace {
class ONNXSplitOpLoweringToTOSA : public OpConversionPattern<ONNXSplitOp> {
public:
using OpConversionPattern::OpConversionPattern;
using OpAdaptor = typename ONNXSplitOp::Adaptor;
LogicalResult matchAndRewrite(ONNXSplitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getInput();
ShapedType inputType = cast<ShapedType>(input.getType());

// tosa.slice does not allow a dynamic entry in the size attribute
if (!hasStaticShape(inputType))
return rewriter.notifyMatchFailure(
op, "only static shapes are supported");

uint64_t rank = inputType.getRank();
int64_t splitAxis = adaptor.getAxis();
if (splitAxis < 0)
splitAxis += rank;

IndexExprBuilderForTosa createTosaIE(rewriter, op->getLoc());
ONNXSplitOpShapeHelper shapeHelper(
op, adaptor.getOperands(), &createTosaIE);

// compute shape
if (failed(shapeHelper.computeShape()))
return rewriter.notifyMatchFailure(op, "could not compute shape.");

TosaBuilder tosaBuilder(rewriter, op->getLoc());
uint64_t outputNum = op.getNumResults();
SmallVector<Value, 4> slices;
slices.reserve(outputNum);

llvm::SmallVector<int64_t, 4> size;
llvm::SmallVector<int64_t, 4> starts(rank, 0);
int64_t start = 0;

for (uint64_t i = 0; i < outputNum; i++) {
DimsExpr outputDim = shapeHelper.getOutputDims(i);
IndexExpr::getShape(outputDim, size);
starts[splitAxis] = start;
slices.push_back(tosaBuilder.slice(input, size, starts));
start += size[splitAxis];
}
rewriter.replaceOp(op, slices);
return success();
}
};
} // namespace

void populateLoweringONNXSplitOpToTOSAPattern(ConversionTarget &target,
RewritePatternSet &patterns, TypeConverter &typeConverter,
MLIRContext *ctx) {
patterns.insert<ONNXSplitOpLoweringToTOSA>(typeConverter, ctx);
}

} // namespace onnx_mlir
5 changes: 5 additions & 0 deletions src/Dialect/ONNX/ONNX.td
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,11 @@ def ONNXConstantOpFromDenseAttr: NativeCodeCall<
class ONNX_Op<string mnemonic, list<Trait> traits = []> :
Op<ONNX_Dialect, mnemonic, traits> ;

// Trait to specify which operation set introduced a revision of an operator.
// For multi-versioned operators, the version also appears in the operator's name.
class OpVersionTrait<int version>
: ParamNativeOpTrait<"OpVersionTrait", !cast<string>(version)>;

// The tablegen code onnxop.in is generated with gen_doc.py
// clone and install onnx
// git clone --recursive https://github.com/onnx/onnx.git
Expand Down
1 change: 1 addition & 0 deletions src/Dialect/ONNX/ONNXOps.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "src/Dialect/ONNX/ONNXAttributes.hpp"
#include "src/Dialect/ONNX/ONNXDialect.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
#include "src/Dialect/ONNX/ONNXTraits.hpp"
#include "src/Dialect/ONNX/ONNXTypes.hpp"
#include "src/Interface/HasOnnxSubgraphOpInterface.hpp"
#include "src/Interface/ResultTypeInferenceOpInterface.hpp"
Expand Down
Loading

0 comments on commit aa3d84b

Please sign in to comment.