diff --git a/src/Conversion/ONNXToTOSA/CMakeLists.txt b/src/Conversion/ONNXToTOSA/CMakeLists.txt index 80e61c5038..f0082e59f2 100644 --- a/src/Conversion/ONNXToTOSA/CMakeLists.txt +++ b/src/Conversion/ONNXToTOSA/CMakeLists.txt @@ -28,6 +28,7 @@ add_onnx_mlir_library(OMONNXToTOSA Tensor/Resize.cpp Tensor/Shrink.cpp Tensor/Slice.cpp + Tensor/Split.cpp Tensor/Squeeze.cpp Tensor/Tile.cpp Tensor/Transpose.cpp diff --git a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp index 19396270db..f0f7ffaa9d 100644 --- a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp +++ b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp @@ -55,6 +55,8 @@ void populateONNXToTOSAConversionPattern(ConversionTarget &target, populateLoweringONNXPadOpToTOSAPattern(target, patterns, typeConverter, ctx); populateLoweringONNXSliceOpToTOSAPattern( target, patterns, typeConverter, ctx); + populateLoweringONNXSplitOpToTOSAPattern( + target, patterns, typeConverter, ctx); populateLoweringONNXSqueezeOpToTOSAPattern( target, patterns, typeConverter, ctx); populateLoweringONNXTileOpToTOSAPattern(target, patterns, typeConverter, ctx); diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp index 52ab249930..c78aa99ae5 100644 --- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -75,7 +75,7 @@ struct IsBool { } }; -template +template LogicalResult checkBasicTosaRequirementsForBinaryOps( ConversionPatternRewriter &rewriter, Operation *op, OpAdaptorT adaptor, Type resultType) { @@ -92,6 +92,15 @@ LogicalResult checkBasicTosaRequirementsForBinaryOps( Type resultElementType = resultTensorType.getElementType(); + if (TosaOpT::template hasTrait< + ::mlir::OpTrait::SameOperandsAndResultElementType>()) { + if (lhsType.getElementType() != rhsType.getElementType() || + lhsType.getElementType() != resultElementType) { + return rewriter.notifyMatchFailure( + op, "lhs, rhs and result must have the same type"); + } + } + if (failed(TypeChecker::checkType(rewriter, resultElementType, op))) { return failure(); } @@ -99,43 +108,6 @@ LogicalResult checkBasicTosaRequirementsForBinaryOps( return success(); } -// Element-wise unary ops lowering to custom op of TOSA dialect. -//===----------------------------------------------------------------------===// -template -class ConvertONNXUnaryOpToTosaCustomOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename ONNXOp::Adaptor; - - ConvertONNXUnaryOpToTosaCustomOp(TypeConverter &typeConverter, - MLIRContext *context, std::string opName, - std::string implementedWithOpAttr = "UNDEF") - : OpConversionPattern(typeConverter, context), - opName(std::move(opName)), - implementedWithOpAttr(std::move(implementedWithOpAttr)) {} - - LogicalResult matchAndRewrite(ONNXOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - // Set tosa.custom_op attributes. - // Only identifier needs to be known. Other attributes are not used. - auto *ctx = op->getContext(); - auto identifier = StringAttr::get(ctx, opName); - auto implementAttr = StringAttr::get(ctx, implementedWithOpAttr); - auto config = StringAttr::get(ctx, "UNDEF"); - - rewriter.replaceOpWithNewOp(op, - TypeRange{OpConversionPattern::getTypeConverter()->convertType( - op.getType())}, - identifier, config, implementAttr, adaptor.getOperands()); - return success(); - } - -private: - std::string opName; - std::string implementedWithOpAttr; -}; - // Element-wise unary ops lowering to TOSA dialect. //===----------------------------------------------------------------------===// template ( - rewriter, op, adaptor, op.getResult().getType()))) + if (failed(checkBasicTosaRequirementsForBinaryOps(rewriter, op, adaptor, op.getResult().getType()))) return failure(); auto loc = op.getLoc(); @@ -216,7 +188,8 @@ class ONNXMulOpLoweringToTosa : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ONNXMulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (failed(checkBasicTosaRequirementsForBinaryOps( + if (failed(checkBasicTosaRequirementsForBinaryOps( rewriter, op, adaptor, op.getResult().getType()))) return failure(); @@ -720,19 +693,11 @@ static void populateLoweringONNXElementwiseUnaryTemplateOpToTOSAPattern( ONNXElementwiseUnaryOpLoweringToTOSA, ONNXElementwiseUnaryOpLoweringToTOSA, + ONNXElementwiseUnaryOpLoweringToTOSA, + ONNXElementwiseUnaryOpLoweringToTOSA>(typeConverter, ctx); - -// Tosa custom ops -#define INSERT_ONNX_UNARY_TO_TOSA_CUSTOMOP_PATTERN( \ - ONNXOp, opName, implementedWith) \ - patterns.add>( \ - typeConverter, ctx, opName, implementedWith); - - INSERT_ONNX_UNARY_TO_TOSA_CUSTOMOP_PATTERN( - ONNXSinOp, "math.sin", "linalg.generic"); - INSERT_ONNX_UNARY_TO_TOSA_CUSTOMOP_PATTERN( - ONNXCosOp, "math.cos", "linalg.generic"); -#undef INSERT_ONNX_UNARY_TO_TOSA_CUSTOMOP_PATTERN } void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target, diff --git a/src/Conversion/ONNXToTOSA/NN/DequantizeLinear.cpp b/src/Conversion/ONNXToTOSA/NN/DequantizeLinear.cpp index 26302f75a7..0897deb37c 100644 --- a/src/Conversion/ONNXToTOSA/NN/DequantizeLinear.cpp +++ b/src/Conversion/ONNXToTOSA/NN/DequantizeLinear.cpp @@ -73,17 +73,24 @@ class ONNXDequantizeLinearOpLoweringToTOSA // Dequantization formula is (x - zero_point) * scale // Cast into the destination type first + + // Cast the operands of (x - zero_point) to float32 to avoid underflows + Type arithType = rewriter.getF32Type(); + Value subOpA = tosaBuilder.castToNewTensorElementType(x, arithType); + Value subOpB = tosaBuilder.castToNewTensorElementType(zpConst, arithType); Value subOp = tosa::CreateOpAndInfer( - rewriter, loc, x.getType(), x, zpConst) + rewriter, loc, subOpA.getType(), subOpA, subOpB) .getResult(); - Value castOp = tosa::CreateOpAndInfer( - rewriter, loc, resultType, subOp) - .getResult(); + // There are no guarantees about the bitwith of the scale factor + Value scaleFactorCast = + tosaBuilder.castToNewTensorElementType(scaleFactorConst, arithType); Value mulOp = tosa::CreateOpAndInfer( - rewriter, loc, resultType, castOp, scaleFactorConst, 0) + rewriter, loc, subOp.getType(), subOp, scaleFactorCast, 0) .getResult(); + Value castOp = tosaBuilder.castToNewTensorElementType( + mulOp, resultType.getElementType()); - rewriter.replaceOp(op, mulOp); + rewriter.replaceOp(op, castOp); return success(); } }; diff --git a/src/Conversion/ONNXToTOSA/NN/QuantizeLinear.cpp b/src/Conversion/ONNXToTOSA/NN/QuantizeLinear.cpp index 1e90e984cf..acc6497483 100644 --- a/src/Conversion/ONNXToTOSA/NN/QuantizeLinear.cpp +++ b/src/Conversion/ONNXToTOSA/NN/QuantizeLinear.cpp @@ -35,8 +35,6 @@ class ONNXQuantizeLinearOpLoweringToTOSA LogicalResult matchAndRewrite(ONNXQuantizeLinearOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - Value x = op.getX(); - Type xType = x.getType(); auto resultType = dyn_cast_if_present( getTypeConverter()->convertType(op.getResult().getType())); if (!resultType || !resultType.hasStaticShape()) { @@ -78,6 +76,9 @@ class ONNXQuantizeLinearOpLoweringToTOSA auto scaleFactorConst = tosa::expandShape( rewriter, loc, adaptor.getYScale(), axis, resultType.getRank()); + Value x = adaptor.getX(); + Type xType = x.getType(); + // Quantization formula is ((x / y_scale) + y_zero_point) // Replace the division by a reciprocal followed by a mul Value recOp = tosa::CreateOpAndInfer( @@ -86,15 +87,20 @@ class ONNXQuantizeLinearOpLoweringToTOSA Value mulOp = tosa::CreateOpAndInfer( rewriter, loc, xType, x, recOp, 0) .getResult(); - // Cast into the result type - Value castOp = tosa::CreateOpAndInfer( - rewriter, loc, resultType, mulOp) + // zpConst has the same type as the result of QLinear which is always + // smaller than the input type. Cast it to the input type. + Value castZp = tosa::CreateOpAndInfer( + rewriter, loc, scaleFactorConst.getType(), zpConst) .getResult(); Value addOp = tosa::CreateOpAndInfer( - rewriter, loc, resultType, castOp, zpConst) + rewriter, loc, xType, mulOp, castZp) .getResult(); + // Cast into the result type + Value castOp = tosa::CreateOpAndInfer( + rewriter, loc, resultType, addOp) + .getResult(); - rewriter.replaceOp(op, addOp); + rewriter.replaceOp(op, castOp); return success(); } }; diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp index 4c2f504b75..084ab4cf30 100644 --- a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp @@ -159,6 +159,8 @@ void populateLoweringONNXFlattenOpToTOSAPattern(mlir::ConversionTarget &, mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); void populateLoweringONNXSliceOpToTOSAPattern(mlir::ConversionTarget &, mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXSplitOpToTOSAPattern(mlir::ConversionTarget &, + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); void populateLoweringONNXSqueezeOpToTOSAPattern(mlir::ConversionTarget &, mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); void populateLoweringONNXTileOpToTOSAPattern(mlir::ConversionTarget &, diff --git a/src/Conversion/ONNXToTOSA/Tensor/Split.cpp b/src/Conversion/ONNXToTOSA/Tensor/Split.cpp new file mode 100644 index 0000000000..1de7dd9263 --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Tensor/Split.cpp @@ -0,0 +1,78 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------------- Split.cpp - Split Op---------===// +// +// Copyright (c) 2023 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNX SplitOp operator to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; +namespace onnx_mlir { +namespace { +class ONNXSplitOpLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename ONNXSplitOp::Adaptor; + LogicalResult matchAndRewrite(ONNXSplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = adaptor.getInput(); + ShapedType inputType = cast(input.getType()); + + // tosa.slice does not allow a dynamic entry in the size attribute + if (!hasStaticShape(inputType)) + return rewriter.notifyMatchFailure( + op, "only static shapes are supported"); + + uint64_t rank = inputType.getRank(); + int64_t splitAxis = adaptor.getAxis(); + if (splitAxis < 0) + splitAxis += rank; + + IndexExprBuilderForTosa createTosaIE(rewriter, op->getLoc()); + ONNXSplitOpShapeHelper shapeHelper( + op, adaptor.getOperands(), &createTosaIE); + + // compute shape + if (failed(shapeHelper.computeShape())) + return rewriter.notifyMatchFailure(op, "could not compute shape."); + + TosaBuilder tosaBuilder(rewriter, op->getLoc()); + uint64_t outputNum = op.getNumResults(); + SmallVector slices; + slices.reserve(outputNum); + + llvm::SmallVector size; + llvm::SmallVector starts(rank, 0); + int64_t start = 0; + + for (uint64_t i = 0; i < outputNum; i++) { + DimsExpr outputDim = shapeHelper.getOutputDims(i); + IndexExpr::getShape(outputDim, size); + starts[splitAxis] = start; + slices.push_back(tosaBuilder.slice(input, size, starts)); + start += size[splitAxis]; + } + rewriter.replaceOp(op, slices); + return success(); + } +}; +} // namespace + +void populateLoweringONNXSplitOpToTOSAPattern(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/Dialect/ONNX/ONNX.td b/src/Dialect/ONNX/ONNX.td index d9f7259ee1..1f231b1a11 100644 --- a/src/Dialect/ONNX/ONNX.td +++ b/src/Dialect/ONNX/ONNX.td @@ -232,6 +232,11 @@ def ONNXConstantOpFromDenseAttr: NativeCodeCall< class ONNX_Op traits = []> : Op ; +// Trait to specify which operation set introduced a revision of an operator. +// For multi-versioned operators, the version also appears in the operator's name. +class OpVersionTrait + : ParamNativeOpTrait<"OpVersionTrait", !cast(version)>; + // The tablegen code onnxop.in is generated with gen_doc.py // clone and install onnx // git clone --recursive https://github.com/onnx/onnx.git diff --git a/src/Dialect/ONNX/ONNXOps.hpp b/src/Dialect/ONNX/ONNXOps.hpp index f016557608..2552279f27 100644 --- a/src/Dialect/ONNX/ONNXOps.hpp +++ b/src/Dialect/ONNX/ONNXOps.hpp @@ -18,6 +18,7 @@ #include "src/Dialect/ONNX/ONNXAttributes.hpp" #include "src/Dialect/ONNX/ONNXDialect.hpp" #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" +#include "src/Dialect/ONNX/ONNXTraits.hpp" #include "src/Dialect/ONNX/ONNXTypes.hpp" #include "src/Interface/HasOnnxSubgraphOpInterface.hpp" #include "src/Interface/ResultTypeInferenceOpInterface.hpp" diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index 35512ffe73..a9e484c620 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -5,7 +5,7 @@ //******************************************************** def ONNXAbsOp:ONNX_Op<"Abs", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Abs operation"; let description = [{ Absolute takes one input data (Tensor) and produces one output data @@ -46,7 +46,7 @@ def ONNXAbsOp:ONNX_Op<"Abs", } def ONNXAcosOp:ONNX_Op<"Acos", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<7>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Acos operation"; let description = [{ Calculates the arccosine (inverse of cosine) of the given input tensor, element-wise. @@ -75,7 +75,7 @@ def ONNXAcosOp:ONNX_Op<"Acos", } def ONNXAcoshOp:ONNX_Op<"Acosh", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<9>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Acosh operation"; let description = [{ Calculates the hyperbolic arccosine of the given input tensor element-wise. @@ -104,7 +104,7 @@ def ONNXAcoshOp:ONNX_Op<"Acosh", } def ONNXAddOp:ONNX_Op<"Add", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { + [Pure, OpVersionTrait<14>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let hasCanonicalizer = 1; let summary = "ONNX Add operation"; let description = [{ @@ -160,7 +160,7 @@ def ONNXAddOp:ONNX_Op<"Add", } def ONNXAndOp:ONNX_Op<"And", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<7>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX And operation"; let description = [{ @@ -215,7 +215,7 @@ def ONNXAndOp:ONNX_Op<"And", } def ONNXArgMaxOp:ONNX_Op<"ArgMax", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ArgMax operation"; let description = [{ Computes the indices of the max elements of the input tensor's element along the @@ -254,7 +254,7 @@ def ONNXArgMaxOp:ONNX_Op<"ArgMax", } def ONNXArgMinOp:ONNX_Op<"ArgMin", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ArgMin operation"; let description = [{ Computes the indices of the min elements of the input tensor's element along the @@ -293,7 +293,7 @@ def ONNXArgMinOp:ONNX_Op<"ArgMin", } def ONNXAsinOp:ONNX_Op<"Asin", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<7>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Asin operation"; let description = [{ Calculates the arcsine (inverse of sine) of the given input tensor, element-wise. @@ -322,7 +322,7 @@ def ONNXAsinOp:ONNX_Op<"Asin", } def ONNXAsinhOp:ONNX_Op<"Asinh", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<9>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Asinh operation"; let description = [{ Calculates the hyperbolic arcsine of the given input tensor element-wise. @@ -351,7 +351,7 @@ def ONNXAsinhOp:ONNX_Op<"Asinh", } def ONNXAtanOp:ONNX_Op<"Atan", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<7>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Atan operation"; let description = [{ Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise. @@ -380,7 +380,7 @@ def ONNXAtanOp:ONNX_Op<"Atan", } def ONNXAtanhOp:ONNX_Op<"Atanh", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<9>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Atanh operation"; let description = [{ Calculates the hyperbolic arctangent of the given input tensor element-wise. @@ -409,7 +409,7 @@ def ONNXAtanhOp:ONNX_Op<"Atanh", } def ONNXAveragePoolOp:ONNX_Op<"AveragePool", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX AveragePool operation"; let description = [{ AveragePool consumes an input tensor X and applies average pooling across @@ -478,7 +478,7 @@ def ONNXAveragePoolOp:ONNX_Op<"AveragePool", } def ONNXBatchNormalizationOp:ONNX_Op<"BatchNormalization", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<15>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX BatchNormalization operation"; let description = [{ Carries out batch normalization as described in the paper @@ -554,7 +554,7 @@ def ONNXBatchNormalizationOp:ONNX_Op<"BatchNormalization", } def ONNXBernoulliOp:ONNX_Op<"Bernoulli", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<15>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Bernoulli operation"; let description = [{ Draws binary random numbers (0 or 1) from a Bernoulli distribution. The input tensor should be a tensor @@ -590,7 +590,7 @@ def ONNXBernoulliOp:ONNX_Op<"Bernoulli", } def ONNXBitShiftOp:ONNX_Op<"BitShift", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX BitShift operation"; let description = [{ Bitwise shift operator performs element-wise operation. For each input element, if the @@ -633,7 +633,7 @@ def ONNXBitShiftOp:ONNX_Op<"BitShift", } def ONNXBitwiseAndOp:ONNX_Op<"BitwiseAnd", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX BitwiseAnd operation"; let description = [{ Returns the tensor resulting from performing the bitwise `and` operation @@ -667,7 +667,7 @@ def ONNXBitwiseAndOp:ONNX_Op<"BitwiseAnd", } def ONNXBitwiseNotOp:ONNX_Op<"BitwiseNot", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX BitwiseNot operation"; let description = [{ Returns the bitwise not of the input tensor element-wise. @@ -696,7 +696,7 @@ def ONNXBitwiseNotOp:ONNX_Op<"BitwiseNot", } def ONNXBitwiseOrOp:ONNX_Op<"BitwiseOr", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX BitwiseOr operation"; let description = [{ Returns the tensor resulting from performing the bitwise `or` operation @@ -730,7 +730,7 @@ def ONNXBitwiseOrOp:ONNX_Op<"BitwiseOr", } def ONNXBitwiseXorOp:ONNX_Op<"BitwiseXor", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX BitwiseXor operation"; let description = [{ Returns the tensor resulting from performing the bitwise `xor` operation @@ -764,7 +764,7 @@ def ONNXBitwiseXorOp:ONNX_Op<"BitwiseXor", } def ONNXBlackmanWindowOp:ONNX_Op<"BlackmanWindow", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<17>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX BlackmanWindow operation"; let description = [{ Generates a Blackman window as described in the paper https://ieeexplore.ieee.org/document/1455106. @@ -795,7 +795,7 @@ def ONNXBlackmanWindowOp:ONNX_Op<"BlackmanWindow", } def ONNXCastOp:ONNX_Op<"Cast", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Cast operation"; let description = [{ @@ -896,7 +896,7 @@ def ONNXCastOp:ONNX_Op<"Cast", } def ONNXCastLikeOp:ONNX_Op<"CastLike", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX CastLike operation"; let description = [{ The operator casts the elements of a given input tensor (the first input) to @@ -929,7 +929,7 @@ def ONNXCastLikeOp:ONNX_Op<"CastLike", } def ONNXCeilOp:ONNX_Op<"Ceil", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Ceil operation"; let description = [{ Ceil takes one input data (Tensor) and produces one output data @@ -960,7 +960,7 @@ def ONNXCeilOp:ONNX_Op<"Ceil", } def ONNXCeluOp:ONNX_Op<"Celu", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<12>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Celu operation"; let description = [{ Continuously Differentiable Exponential Linear Units: @@ -997,7 +997,7 @@ def ONNXCeluOp:ONNX_Op<"Celu", } def ONNXCenterCropPadOp:ONNX_Op<"CenterCropPad", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX CenterCropPad operation"; let description = [{ Center crop or pad an input to given dimensions. @@ -1035,7 +1035,7 @@ def ONNXCenterCropPadOp:ONNX_Op<"CenterCropPad", } def ONNXClipOp:ONNX_Op<"Clip", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Clip operation"; let description = [{ Clip operator limits the given input within an interval. The interval is @@ -1068,7 +1068,7 @@ def ONNXClipOp:ONNX_Op<"Clip", } def ONNXClipV12Op:ONNX_Op<"ClipV12", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<12>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Clip operation"; let description = [{ Clip operator limits the given input within an interval. The interval is @@ -1101,7 +1101,7 @@ def ONNXClipV12Op:ONNX_Op<"ClipV12", } def ONNXClipV11Op:ONNX_Op<"ClipV11", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Clip operation"; let description = [{ Clip operator limits the given input within an interval. The interval is @@ -1134,7 +1134,7 @@ def ONNXClipV11Op:ONNX_Op<"ClipV11", } def ONNXClipV6Op:ONNX_Op<"ClipV6", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<6>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Clip operation"; let description = [{ Clip operator limits the given input within an interval. The interval is @@ -1167,7 +1167,7 @@ def ONNXClipV6Op:ONNX_Op<"ClipV6", } def ONNXCol2ImOp:ONNX_Op<"Col2Im", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Col2Im operation"; let description = [{ The operator rearranges column blocks back into a multidimensional image @@ -1210,7 +1210,7 @@ def ONNXCol2ImOp:ONNX_Op<"Col2Im", } def ONNXCompressOp:ONNX_Op<"Compress", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Compress operation"; let description = [{ Selects slices from an input tensor along a given axis where condition evaluates to True for each axis index. @@ -1245,7 +1245,7 @@ def ONNXCompressOp:ONNX_Op<"Compress", } def ONNXConcatOp:ONNX_Op<"Concat", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Concat operation"; let description = [{ Concatenate a list of tensors into a single tensor. All input tensors must have the same shape, except for the dimension size of the axis to concatenate on. @@ -1276,7 +1276,7 @@ def ONNXConcatOp:ONNX_Op<"Concat", } def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ConcatFromSequence operation"; let description = [{ Concatenate a sequence of tensors into a single tensor. @@ -1311,7 +1311,7 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence", } def ONNXConstantOp:ONNX_Op<"Constant", - [Pure, ConstantLike, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, ConstantLike, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCustomAssemblyFormat = 1; let hasCanonicalizer = 1; let summary = "ONNX Constant operation"; @@ -1364,7 +1364,7 @@ def ONNXConstantOp:ONNX_Op<"Constant", } def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<20>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCustomAssemblyFormat = 1; let summary = "ONNX ConstantOfShape operation"; let description = [{ @@ -1396,7 +1396,7 @@ def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape", } def ONNXConvOp:ONNX_Op<"Conv", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Conv operation"; let description = [{ The convolution operator consumes an input tensor and a filter, and @@ -1446,7 +1446,7 @@ def ONNXConvOp:ONNX_Op<"Conv", } def ONNXConvIntegerOp:ONNX_Op<"ConvInteger", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<10>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ConvInteger operation"; let description = [{ The integer convolution operator consumes an input tensor, its zero-point, a filter, and its zero-point, @@ -1485,7 +1485,7 @@ def ONNXConvIntegerOp:ONNX_Op<"ConvInteger", } def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ConvTranspose operation"; let description = [{ The convolution transpose operator consumes an input tensor and a filter, @@ -1538,7 +1538,7 @@ def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose", } def ONNXCosOp:ONNX_Op<"Cos", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<7>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Cos operation"; let description = [{ Calculates the cosine of the given input tensor, element-wise. @@ -1567,7 +1567,7 @@ def ONNXCosOp:ONNX_Op<"Cos", } def ONNXCoshOp:ONNX_Op<"Cosh", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<9>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Cosh operation"; let description = [{ Calculates the hyperbolic cosine of the given input tensor element-wise. @@ -1596,7 +1596,7 @@ def ONNXCoshOp:ONNX_Op<"Cosh", } def ONNXCumSumOp:ONNX_Op<"CumSum", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<14>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX CumSum operation"; let description = [{ Performs cumulative sum of the input elements along the given axis. @@ -1647,7 +1647,7 @@ def ONNXCumSumOp:ONNX_Op<"CumSum", } def ONNXDFTOp:ONNX_Op<"DFT", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<17>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX DFT operation"; let description = [{ Computes the discrete Fourier transform of input. @@ -1680,7 +1680,7 @@ def ONNXDFTOp:ONNX_Op<"DFT", } def ONNXDeformConvOp:ONNX_Op<"DeformConv", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX DeformConv operation"; let description = [{ Performs deformable convolution as described in https://arxiv.org/abs/1703.06211 and https://arxiv.org/abs/1811.11168. @@ -1720,7 +1720,7 @@ def ONNXDeformConvOp:ONNX_Op<"DeformConv", } def ONNXDepthToSpaceOp:ONNX_Op<"DepthToSpace", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX DepthToSpace operation"; let description = [{ @@ -1775,7 +1775,7 @@ def ONNXDepthToSpaceOp:ONNX_Op<"DepthToSpace", } def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX DequantizeLinear operation"; let description = [{ The linear dequantization operator. It consumes a quantized tensor, a scale, and a zero point to compute the full precision tensor. @@ -1814,7 +1814,7 @@ def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear", } def ONNXDetOp:ONNX_Op<"Det", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Det operation"; let description = [{ Det calculates determinant of a square matrix or batches of square matrices. @@ -1847,7 +1847,7 @@ def ONNXDetOp:ONNX_Op<"Det", } def ONNXDivOp:ONNX_Op<"Div", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { + [Pure, OpVersionTrait<14>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let hasCanonicalizer = 1; let summary = "ONNX Div operation"; let description = [{ @@ -1903,7 +1903,7 @@ def ONNXDivOp:ONNX_Op<"Div", } def ONNXDropoutOp:ONNX_Op<"Dropout", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Dropout operation"; let description = [{ @@ -1948,7 +1948,7 @@ def ONNXDropoutOp:ONNX_Op<"Dropout", } def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX DynamicQuantizeLinear operation"; let description = [{ A Function to fuse calculation for Scale, Zero Point and FP32->8Bit conversion of FP32 Input data. @@ -2005,7 +2005,7 @@ def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear", } def ONNXEinsumOp:ONNX_Op<"Einsum", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<12>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Einsum operation"; let description = [{ An einsum of the form `term1, term2 -> output-term` produces an output tensor using the following equation @@ -2060,7 +2060,7 @@ def ONNXEinsumOp:ONNX_Op<"Einsum", } def ONNXEluOp:ONNX_Op<"Elu", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<6>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Elu operation"; let description = [{ Elu takes one input data (Tensor) and produces one output data @@ -2094,7 +2094,7 @@ def ONNXEluOp:ONNX_Op<"Elu", } def ONNXEqualOp:ONNX_Op<"Equal", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsElementType]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsElementType]> { let hasCanonicalizer = 1; let summary = "ONNX Equal operation"; let description = [{ @@ -2151,7 +2151,7 @@ def ONNXEqualOp:ONNX_Op<"Equal", } def ONNXErfOp:ONNX_Op<"Erf", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Erf operation"; let description = [{ Computes the error function of the given input tensor element-wise. @@ -2180,7 +2180,7 @@ def ONNXErfOp:ONNX_Op<"Erf", } def ONNXExpOp:ONNX_Op<"Exp", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Exp operation"; let description = [{ Calculates the exponential of the given input tensor, element-wise. @@ -2219,7 +2219,7 @@ def ONNXExpOp:ONNX_Op<"Exp", } def ONNXExpandOp:ONNX_Op<"Expand", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Expand operation"; let description = [{ Broadcast the input tensor following the given shape and the broadcast rule. @@ -2257,7 +2257,7 @@ def ONNXExpandOp:ONNX_Op<"Expand", } def ONNXEyeLikeOp:ONNX_Op<"EyeLike", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<9>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX EyeLike operation"; let description = [{ Generate a 2D tensor (matrix) with ones on the diagonal and zeros everywhere else. Only 2D @@ -2294,7 +2294,7 @@ def ONNXEyeLikeOp:ONNX_Op<"EyeLike", } def ONNXFlattenOp:ONNX_Op<"Flatten", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Flatten operation"; let description = [{ Flattens the input tensor into a 2D matrix. If input tensor has shape @@ -2327,7 +2327,7 @@ def ONNXFlattenOp:ONNX_Op<"Flatten", } def ONNXFloorOp:ONNX_Op<"Floor", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Floor operation"; let description = [{ Floor takes one input data (Tensor) and produces one output data @@ -2358,7 +2358,7 @@ def ONNXFloorOp:ONNX_Op<"Floor", } def ONNXGRUOp:ONNX_Op<"GRU", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<14>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX GRU operation"; let description = [{ @@ -2448,7 +2448,7 @@ def ONNXGRUOp:ONNX_Op<"GRU", } def ONNXGatherOp:ONNX_Op<"Gather", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Gather operation"; let description = [{ Given `data` tensor of rank r >= 1, and `indices` tensor of rank q, gather @@ -2527,7 +2527,7 @@ def ONNXGatherOp:ONNX_Op<"Gather", } def ONNXGatherElementsOp:ONNX_Op<"GatherElements", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX GatherElements operation"; let description = [{ GatherElements takes two inputs `data` and `indices` of the same rank r >= 1 @@ -2609,7 +2609,7 @@ def ONNXGatherElementsOp:ONNX_Op<"GatherElements", } def ONNXGatherNDOp:ONNX_Op<"GatherND", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX GatherND operation"; let description = [{ Given `data` tensor of rank `r` >= 1, `indices` tensor of rank `q` >= 1, and `batch_dims` integer `b`, this operator gathers @@ -2724,7 +2724,7 @@ def ONNXGatherNDOp:ONNX_Op<"GatherND", } def ONNXGeluOp:ONNX_Op<"Gelu", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<20>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Gelu operation"; let description = [{ Gelu takes one input data (Tensor) and produces one @@ -2761,7 +2761,7 @@ def ONNXGeluOp:ONNX_Op<"Gelu", } def ONNXGemmOp:ONNX_Op<"Gemm", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Gemm operation"; let description = [{ General Matrix multiplication: @@ -2807,7 +2807,7 @@ def ONNXGemmOp:ONNX_Op<"Gemm", } def ONNXGlobalAveragePoolOp:ONNX_Op<"GlobalAveragePool", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX GlobalAveragePool operation"; let description = [{ @@ -2840,7 +2840,7 @@ def ONNXGlobalAveragePoolOp:ONNX_Op<"GlobalAveragePool", } def ONNXGlobalLpPoolOp:ONNX_Op<"GlobalLpPool", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<2>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX GlobalLpPool operation"; let description = [{ GlobalLpPool consumes an input tensor X and applies lp pool pooling across @@ -2872,7 +2872,7 @@ def ONNXGlobalLpPoolOp:ONNX_Op<"GlobalLpPool", } def ONNXGlobalMaxPoolOp:ONNX_Op<"GlobalMaxPool", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX GlobalMaxPool operation"; let description = [{ @@ -2904,7 +2904,7 @@ def ONNXGlobalMaxPoolOp:ONNX_Op<"GlobalMaxPool", } def ONNXGreaterOp:ONNX_Op<"Greater", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsElementType]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsElementType]> { let hasCanonicalizer = 1; let summary = "ONNX Greater operation"; let description = [{ @@ -2961,7 +2961,7 @@ def ONNXGreaterOp:ONNX_Op<"Greater", } def ONNXGreaterOrEqualOp:ONNX_Op<"GreaterOrEqual", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsElementType]> { + [Pure, OpVersionTrait<16>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsElementType]> { let summary = "ONNX GreaterOrEqual operation"; let description = [{ Returns the tensor resulted from performing the `greater_equal` logical operation @@ -3017,7 +3017,7 @@ def ONNXGreaterOrEqualOp:ONNX_Op<"GreaterOrEqual", } def ONNXGridSampleOp:ONNX_Op<"GridSample", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<16>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX GridSample operation"; let description = [{ Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from `grid`. @@ -3063,7 +3063,7 @@ def ONNXGridSampleOp:ONNX_Op<"GridSample", } def ONNXGroupNormalizationOp:ONNX_Op<"GroupNormalization", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX GroupNormalization operation"; let description = [{ A GroupNormalization function. Carries out group normalization as described in @@ -3110,7 +3110,7 @@ def ONNXGroupNormalizationOp:ONNX_Op<"GroupNormalization", } def ONNXHammingWindowOp:ONNX_Op<"HammingWindow", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<17>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX HammingWindow operation"; let description = [{ Generates a Hamming window as described in the paper https://ieeexplore.ieee.org/document/1455106. @@ -3141,7 +3141,7 @@ def ONNXHammingWindowOp:ONNX_Op<"HammingWindow", } def ONNXHannWindowOp:ONNX_Op<"HannWindow", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<17>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX HannWindow operation"; let description = [{ Generates a Hann window as described in the paper https://ieeexplore.ieee.org/document/1455106. @@ -3172,7 +3172,7 @@ def ONNXHannWindowOp:ONNX_Op<"HannWindow", } def ONNXHardSigmoidOp:ONNX_Op<"HardSigmoid", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<6>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX HardSigmoid operation"; let description = [{ HardSigmoid takes one input data (Tensor) and produces one output data @@ -3206,7 +3206,7 @@ def ONNXHardSigmoidOp:ONNX_Op<"HardSigmoid", } def ONNXHardSwishOp:ONNX_Op<"HardSwish", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<14>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX HardSwish operation"; let description = [{ HardSwish takes one input data (Tensor) and produces one output data (Tensor) where @@ -3237,7 +3237,7 @@ def ONNXHardSwishOp:ONNX_Op<"HardSwish", } def ONNXHardmaxOp:ONNX_Op<"Hardmax", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Hardmax operation"; let description = [{ The operator computes the hardmax values for the given input: @@ -3274,7 +3274,7 @@ def ONNXHardmaxOp:ONNX_Op<"Hardmax", } def ONNXIdentityOp:ONNX_Op<"Identity", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Identity operation"; let description = [{ @@ -3314,7 +3314,7 @@ def ONNXIdentityOp:ONNX_Op<"Identity", } def ONNXIfOp:ONNX_Op<"If", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpInterface<"HasOnnxSubgraphOpInterface">]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpInterface<"HasOnnxSubgraphOpInterface">]> { let summary = "ONNX If operation"; let description = [{ If conditional @@ -3351,7 +3351,7 @@ def ONNXIfOp:ONNX_Op<"If", } def ONNXInstanceNormalizationOp:ONNX_Op<"InstanceNormalization", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<6>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX InstanceNormalization operation"; let description = [{ Carries out instance normalization as described in the paper @@ -3389,7 +3389,7 @@ def ONNXInstanceNormalizationOp:ONNX_Op<"InstanceNormalization", } def ONNXIsInfOp:ONNX_Op<"IsInf", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<20>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX IsInf operation"; let description = [{ Map infinity to true and other values to false. @@ -3421,7 +3421,7 @@ def ONNXIsInfOp:ONNX_Op<"IsInf", } def ONNXIsNaNOp:ONNX_Op<"IsNaN", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<20>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX IsNaN operation"; let description = [{ Returns which elements of the input are NaN. @@ -3450,7 +3450,7 @@ def ONNXIsNaNOp:ONNX_Op<"IsNaN", } def ONNXLRNOp:ONNX_Op<"LRN", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX LRN operation"; let description = [{ Local Response Normalization proposed in the [AlexNet paper](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf). @@ -3492,7 +3492,7 @@ def ONNXLRNOp:ONNX_Op<"LRN", } def ONNXLSTMOp:ONNX_Op<"LSTM", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<14>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX LSTM operation"; let description = [{ @@ -3588,7 +3588,7 @@ def ONNXLSTMOp:ONNX_Op<"LSTM", } def ONNXLayerNormalizationOp:ONNX_Op<"LayerNormalization", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<17>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX LayerNormalization operation"; let description = [{ This is layer normalization defined in ONNX as function. @@ -3663,7 +3663,7 @@ def ONNXLayerNormalizationOp:ONNX_Op<"LayerNormalization", } def ONNXLeakyReluOp:ONNX_Op<"LeakyRelu", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<16>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX LeakyRelu operation"; let description = [{ LeakyRelu takes input data (Tensor) and an argument alpha, and produces one @@ -3695,7 +3695,7 @@ def ONNXLeakyReluOp:ONNX_Op<"LeakyRelu", } def ONNXLessOp:ONNX_Op<"Less", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsElementType]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsElementType]> { let hasCanonicalizer = 1; let summary = "ONNX Less operation"; let description = [{ @@ -3752,7 +3752,7 @@ def ONNXLessOp:ONNX_Op<"Less", } def ONNXLessOrEqualOp:ONNX_Op<"LessOrEqual", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsElementType]> { + [Pure, OpVersionTrait<16>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsElementType]> { let summary = "ONNX LessOrEqual operation"; let description = [{ Returns the tensor resulted from performing the `less_equal` logical operation @@ -3808,7 +3808,7 @@ def ONNXLessOrEqualOp:ONNX_Op<"LessOrEqual", } def ONNXLogOp:ONNX_Op<"Log", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Log operation"; let description = [{ Calculates the natural log of the given input tensor, element-wise. @@ -3837,7 +3837,7 @@ def ONNXLogOp:ONNX_Op<"Log", } def ONNXLogSoftmaxOp:ONNX_Op<"LogSoftmax", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX LogSoftmax operation"; let description = [{ The operator computes the log of softmax values for the given input: @@ -3874,7 +3874,7 @@ def ONNXLogSoftmaxOp:ONNX_Op<"LogSoftmax", } def ONNXLoopOp:ONNX_Op<"Loop", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpInterface<"HasOnnxSubgraphOpInterface">]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpInterface<"HasOnnxSubgraphOpInterface">]> { let hasCanonicalizer = 1; let summary = "ONNX Loop operation"; let description = [{ @@ -4048,7 +4048,7 @@ def ONNXLoopOp:ONNX_Op<"Loop", } def ONNXLpNormalizationOp:ONNX_Op<"LpNormalization", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX LpNormalization operation"; let description = [{ Given a matrix, apply Lp-normalization along the provided axis. @@ -4079,7 +4079,7 @@ def ONNXLpNormalizationOp:ONNX_Op<"LpNormalization", } def ONNXLpPoolOp:ONNX_Op<"LpPool", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX LpPool operation"; let description = [{ LpPool consumes an input tensor X and applies Lp pooling across @@ -4137,7 +4137,7 @@ def ONNXLpPoolOp:ONNX_Op<"LpPool", } def ONNXMatMulOp:ONNX_Op<"MatMul", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX MatMul operation"; let description = [{ Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html @@ -4167,7 +4167,7 @@ def ONNXMatMulOp:ONNX_Op<"MatMul", } def ONNXMatMulIntegerOp:ONNX_Op<"MatMulInteger", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<10>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX MatMulInteger operation"; let description = [{ Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html. @@ -4201,7 +4201,7 @@ def ONNXMatMulIntegerOp:ONNX_Op<"MatMulInteger", } def ONNXMaxOp:ONNX_Op<"Max", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let summary = "ONNX Max operation"; let description = [{ Element-wise max of each of the input tensors (with Numpy-style broadcasting support). @@ -4233,7 +4233,7 @@ def ONNXMaxOp:ONNX_Op<"Max", } def ONNXMaxPoolOp:ONNX_Op<"MaxPool", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<12>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX MaxPool operation"; let description = [{ MaxPool consumes an input tensor X and applies max pooling across @@ -4302,7 +4302,7 @@ def ONNXMaxPoolOp:ONNX_Op<"MaxPool", } def ONNXMaxRoiPoolOp:ONNX_Op<"MaxRoiPool", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX MaxRoiPool operation"; let description = [{ ROI max pool consumes an input tensor X and region of interests (RoIs) to @@ -4336,7 +4336,7 @@ def ONNXMaxRoiPoolOp:ONNX_Op<"MaxRoiPool", } def ONNXMaxUnpoolOp:ONNX_Op<"MaxUnpool", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX MaxUnpool operation"; let description = [{ MaxUnpool essentially computes the partial inverse of the MaxPool op. @@ -4387,7 +4387,7 @@ def ONNXMaxUnpoolOp:ONNX_Op<"MaxUnpool", } def ONNXMeanOp:ONNX_Op<"Mean", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Mean operation"; let description = [{ Element-wise mean of each of the input tensors (with Numpy-style broadcasting support). @@ -4419,7 +4419,7 @@ def ONNXMeanOp:ONNX_Op<"Mean", } def ONNXMeanVarianceNormalizationOp:ONNX_Op<"MeanVarianceNormalization", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX MeanVarianceNormalization operation"; let description = [{ A MeanVarianceNormalization Function: Perform mean variance normalization @@ -4450,7 +4450,7 @@ def ONNXMeanVarianceNormalizationOp:ONNX_Op<"MeanVarianceNormalization", } def ONNXMelWeightMatrixOp:ONNX_Op<"MelWeightMatrix", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<17>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX MelWeightMatrix operation"; let description = [{ Generate a MelWeightMatrix that can be used to re-weight a Tensor containing a linearly sampled frequency spectra (from DFT or STFT) into num_mel_bins frequency information based on the [lower_edge_hertz, upper_edge_hertz] range on the mel scale. @@ -4491,7 +4491,7 @@ def ONNXMelWeightMatrixOp:ONNX_Op<"MelWeightMatrix", } def ONNXMinOp:ONNX_Op<"Min", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let summary = "ONNX Min operation"; let description = [{ Element-wise min of each of the input tensors (with Numpy-style broadcasting support). @@ -4523,7 +4523,7 @@ def ONNXMinOp:ONNX_Op<"Min", } def ONNXMishOp:ONNX_Op<"Mish", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let summary = "ONNX Mish operation"; let description = [{ Mish: A Self Regularized Non-Monotonic Neural Activation Function. @@ -4558,7 +4558,7 @@ def ONNXMishOp:ONNX_Op<"Mish", } def ONNXModOp:ONNX_Op<"Mod", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let summary = "ONNX Mod operation"; let description = [{ Performs element-wise binary modulus (with Numpy-style broadcasting support). @@ -4602,7 +4602,7 @@ def ONNXModOp:ONNX_Op<"Mod", } def ONNXMulOp:ONNX_Op<"Mul", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { + [Pure, OpVersionTrait<14>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let hasCanonicalizer = 1; let summary = "ONNX Mul operation"; let description = [{ @@ -4658,7 +4658,7 @@ def ONNXMulOp:ONNX_Op<"Mul", } def ONNXMultinomialOp:ONNX_Op<"Multinomial", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<7>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Multinomial operation"; let description = [{ Generate a tensor of samples from a multinomial distribution according to the probabilities @@ -4691,7 +4691,7 @@ def ONNXMultinomialOp:ONNX_Op<"Multinomial", } def ONNXNegOp:ONNX_Op<"Neg", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Neg operation"; let description = [{ Neg takes one input data (Tensor) and produces one output data @@ -4732,7 +4732,7 @@ def ONNXNegOp:ONNX_Op<"Neg", } def ONNXNegativeLogLikelihoodLossOp:ONNX_Op<"NegativeLogLikelihoodLoss", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX NegativeLogLikelihoodLoss operation"; let description = [{ A NegativeLogLikelihoodLoss operator computes (weighted) negative log likelihood loss. @@ -4865,7 +4865,7 @@ def ONNXNegativeLogLikelihoodLossOp:ONNX_Op<"NegativeLogLikelihoodLoss", } def ONNXNonMaxSuppressionOp:ONNX_Op<"NonMaxSuppression", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX NonMaxSuppression operation"; let description = [{ Filter out boxes that have high intersection-over-union (IOU) overlap with previously selected boxes. @@ -4906,7 +4906,7 @@ def ONNXNonMaxSuppressionOp:ONNX_Op<"NonMaxSuppression", } def ONNXNonZeroOp:ONNX_Op<"NonZero", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX NonZero operation"; let description = [{ Returns the indices of the elements that are non-zero @@ -4939,7 +4939,7 @@ def ONNXNonZeroOp:ONNX_Op<"NonZero", } def ONNXNotOp:ONNX_Op<"Not", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Not operation"; let description = [{ Returns the negation of the input tensor element-wise. @@ -4968,7 +4968,7 @@ def ONNXNotOp:ONNX_Op<"Not", } def ONNXOneHotOp:ONNX_Op<"OneHot", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX OneHot operation"; let description = [{ Produces a one-hot tensor based on inputs. @@ -5019,7 +5019,7 @@ def ONNXOneHotOp:ONNX_Op<"OneHot", } def ONNXOptionalOp:ONNX_Op<"Optional", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<15>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Optional operation"; let description = [{ Constructs an optional-type value containing either an empty optional of a certain type specified by the attribute, @@ -5051,7 +5051,7 @@ def ONNXOptionalOp:ONNX_Op<"Optional", } def ONNXOptionalGetElementOp:ONNX_Op<"OptionalGetElement", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX OptionalGetElement operation"; let description = [{ If the input is a tensor or sequence type, it returns the input. @@ -5083,7 +5083,7 @@ def ONNXOptionalGetElementOp:ONNX_Op<"OptionalGetElement", } def ONNXOptionalHasElementOp:ONNX_Op<"OptionalHasElement", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX OptionalHasElement operation"; let description = [{ Returns true if (1) the input is an optional-type and contains an element, @@ -5115,7 +5115,7 @@ def ONNXOptionalHasElementOp:ONNX_Op<"OptionalHasElement", } def ONNXOrOp:ONNX_Op<"Or", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<7>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Or operation"; let description = [{ @@ -5170,7 +5170,7 @@ def ONNXOrOp:ONNX_Op<"Or", } def ONNXPReluOp:ONNX_Op<"PRelu", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<16>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX PRelu operation"; let description = [{ PRelu takes input data (Tensor) and slope tensor as input, and produces one @@ -5204,7 +5204,7 @@ def ONNXPReluOp:ONNX_Op<"PRelu", } def ONNXPadOp:ONNX_Op<"Pad", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Pad operation"; let description = [{ Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`, @@ -5347,7 +5347,7 @@ def ONNXPadOp:ONNX_Op<"Pad", } def ONNXPadV18Op:ONNX_Op<"PadV18", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Pad operation"; let description = [{ Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`, @@ -5454,7 +5454,7 @@ def ONNXPadV18Op:ONNX_Op<"PadV18", } def ONNXPadV13Op:ONNX_Op<"PadV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Pad operation"; let description = [{ Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`, @@ -5560,7 +5560,7 @@ def ONNXPadV13Op:ONNX_Op<"PadV13", } def ONNXPadV11Op:ONNX_Op<"PadV11", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Pad operation"; let description = [{ Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`, @@ -5666,7 +5666,7 @@ def ONNXPadV11Op:ONNX_Op<"PadV11", } def ONNXPadV2Op:ONNX_Op<"PadV2", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<2>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Pad operation"; let description = [{ Given `data` tensor, pads, mode, and value. @@ -5713,7 +5713,7 @@ def ONNXPadV2Op:ONNX_Op<"PadV2", } def ONNXPowOp:ONNX_Op<"Pow", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<15>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Pow operation"; let description = [{ @@ -5768,7 +5768,7 @@ def ONNXPowOp:ONNX_Op<"Pow", } def ONNXQLinearConvOp:ONNX_Op<"QLinearConv", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<10>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX QLinearConv operation"; let description = [{ The convolution operator consumes a quantized input tensor, its scale and zero point, @@ -5817,7 +5817,7 @@ def ONNXQLinearConvOp:ONNX_Op<"QLinearConv", } def ONNXQLinearMatMulOp:ONNX_Op<"QLinearMatMul", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<10>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX QLinearMatMul operation"; let description = [{ Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html. @@ -5863,7 +5863,7 @@ def ONNXQLinearMatMulOp:ONNX_Op<"QLinearMatMul", } def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX QuantizeLinear operation"; let description = [{ The linear quantization operator. It consumes a high precision tensor, a scale, and a zero point to compute the low precision / quantized tensor. @@ -5904,7 +5904,7 @@ def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear", } def ONNXRNNOp:ONNX_Op<"RNN", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<14>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX RNN operation"; let description = [{ @@ -5986,7 +5986,7 @@ def ONNXRNNOp:ONNX_Op<"RNN", } def ONNXRandomNormalOp:ONNX_Op<"RandomNormal", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX RandomNormal operation"; let description = [{ Generate a tensor with random values drawn from a normal distribution. The shape @@ -6025,7 +6025,7 @@ def ONNXRandomNormalOp:ONNX_Op<"RandomNormal", } def ONNXRandomNormalLikeOp:ONNX_Op<"RandomNormalLike", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX RandomNormalLike operation"; let description = [{ Generate a tensor with random values drawn from a normal distribution. @@ -6065,7 +6065,7 @@ def ONNXRandomNormalLikeOp:ONNX_Op<"RandomNormalLike", } def ONNXRandomUniformOp:ONNX_Op<"RandomUniform", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX RandomUniform operation"; let description = [{ Generate a tensor with random values drawn from a uniform distribution. The shape @@ -6103,7 +6103,7 @@ def ONNXRandomUniformOp:ONNX_Op<"RandomUniform", } def ONNXRandomUniformLikeOp:ONNX_Op<"RandomUniformLike", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX RandomUniformLike operation"; let description = [{ Generate a tensor with random values drawn from a uniform distribution. @@ -6142,7 +6142,7 @@ def ONNXRandomUniformLikeOp:ONNX_Op<"RandomUniformLike", } def ONNXRangeOp:ONNX_Op<"Range", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Range operation"; let description = [{ Generate a tensor containing a sequence of numbers that begin at `start` and extends by increments of `delta` @@ -6203,7 +6203,7 @@ def ONNXRangeOp:ONNX_Op<"Range", } def ONNXReciprocalOp:ONNX_Op<"Reciprocal", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Reciprocal operation"; let description = [{ Reciprocal takes one input data (Tensor) and produces one output data @@ -6234,7 +6234,7 @@ def ONNXReciprocalOp:ONNX_Op<"Reciprocal", } def ONNXReduceL1Op:ONNX_Op<"ReduceL1", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceL1 operation"; let description = [{ Computes the L1 norm of the input tensor's elements along the provided axes. The resulting @@ -6273,7 +6273,7 @@ def ONNXReduceL1Op:ONNX_Op<"ReduceL1", } def ONNXReduceL1V13Op:ONNX_Op<"ReduceL1V13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceL1 operation"; let description = [{ Computes the L1 norm of the input tensor's elements along the provided axes. The resulting @@ -6311,7 +6311,7 @@ def ONNXReduceL1V13Op:ONNX_Op<"ReduceL1V13", } def ONNXReduceL2Op:ONNX_Op<"ReduceL2", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceL2 operation"; let description = [{ Computes the L2 norm of the input tensor's elements along the provided axes. The resulting @@ -6350,7 +6350,7 @@ def ONNXReduceL2Op:ONNX_Op<"ReduceL2", } def ONNXReduceL2V13Op:ONNX_Op<"ReduceL2V13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceL2 operation"; let description = [{ Computes the L2 norm of the input tensor's elements along the provided axes. The resulting @@ -6388,7 +6388,7 @@ def ONNXReduceL2V13Op:ONNX_Op<"ReduceL2V13", } def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceLogSum operation"; let description = [{ Computes the log sum of the input tensor's elements along the provided axes. The resulting @@ -6437,7 +6437,7 @@ def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum", } def ONNXReduceLogSumV13Op:ONNX_Op<"ReduceLogSumV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceLogSum operation"; let description = [{ Computes the log sum of the input tensor's elements along the provided axes. The resulting @@ -6475,7 +6475,7 @@ def ONNXReduceLogSumV13Op:ONNX_Op<"ReduceLogSumV13", } def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceLogSumExp operation"; let description = [{ Computes the log sum exponent of the input tensor's elements along the provided axes. The resulting @@ -6514,7 +6514,7 @@ def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp", } def ONNXReduceLogSumExpV13Op:ONNX_Op<"ReduceLogSumExpV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceLogSumExp operation"; let description = [{ Computes the log sum exponent of the input tensor's elements along the provided axes. The resulting @@ -6552,7 +6552,7 @@ def ONNXReduceLogSumExpV13Op:ONNX_Op<"ReduceLogSumExpV13", } def ONNXReduceMaxOp:ONNX_Op<"ReduceMax", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<20>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceMax operation"; let description = [{ Computes the max of the input tensor's elements along the provided axes. The resulting @@ -6603,7 +6603,7 @@ def ONNXReduceMaxOp:ONNX_Op<"ReduceMax", } def ONNXReduceMaxV18Op:ONNX_Op<"ReduceMaxV18", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceMax operation"; let description = [{ Computes the max of the input tensor's elements along the provided axes. The resulting @@ -6652,7 +6652,7 @@ def ONNXReduceMaxV18Op:ONNX_Op<"ReduceMaxV18", } def ONNXReduceMaxV13Op:ONNX_Op<"ReduceMaxV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceMax operation"; let description = [{ Computes the max of the input tensor's elements along the provided axes. The resulting @@ -6700,7 +6700,7 @@ def ONNXReduceMaxV13Op:ONNX_Op<"ReduceMaxV13", } def ONNXReduceMeanOp:ONNX_Op<"ReduceMean", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceMean operation"; let description = [{ Computes the mean of the input tensor's elements along the provided axes. The resulting @@ -6739,7 +6739,7 @@ def ONNXReduceMeanOp:ONNX_Op<"ReduceMean", } def ONNXReduceMeanV13Op:ONNX_Op<"ReduceMeanV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceMean operation"; let description = [{ Computes the mean of the input tensor's elements along the provided axes. The resulting @@ -6777,7 +6777,7 @@ def ONNXReduceMeanV13Op:ONNX_Op<"ReduceMeanV13", } def ONNXReduceMinOp:ONNX_Op<"ReduceMin", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<20>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceMin operation"; let description = [{ Computes the min of the input tensor's elements along the provided axes. The resulting @@ -6818,7 +6818,7 @@ def ONNXReduceMinOp:ONNX_Op<"ReduceMin", } def ONNXReduceMinV18Op:ONNX_Op<"ReduceMinV18", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceMin operation"; let description = [{ Computes the min of the input tensor's elements along the provided axes. The resulting @@ -6857,7 +6857,7 @@ def ONNXReduceMinV18Op:ONNX_Op<"ReduceMinV18", } def ONNXReduceMinV13Op:ONNX_Op<"ReduceMinV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceMin operation"; let description = [{ Computes the min of the input tensor's elements along the provided axes. The resulting @@ -6895,7 +6895,7 @@ def ONNXReduceMinV13Op:ONNX_Op<"ReduceMinV13", } def ONNXReduceProdOp:ONNX_Op<"ReduceProd", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceProd operation"; let description = [{ Computes the product of the input tensor's elements along the provided axes. The resulting @@ -6934,7 +6934,7 @@ def ONNXReduceProdOp:ONNX_Op<"ReduceProd", } def ONNXReduceProdV13Op:ONNX_Op<"ReduceProdV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceProd operation"; let description = [{ Computes the product of the input tensor's elements along the provided axes. The resulting @@ -6972,7 +6972,7 @@ def ONNXReduceProdV13Op:ONNX_Op<"ReduceProdV13", } def ONNXReduceSumOp:ONNX_Op<"ReduceSum", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceSum operation"; let description = [{ Computes the sum of the input tensor's elements along the provided axes. The resulting @@ -7021,7 +7021,7 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum", } def ONNXReduceSumV11Op:ONNX_Op<"ReduceSumV11", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceSum operation"; let description = [{ Computes the sum of the input tensor's element along the provided axes. The resulting @@ -7067,7 +7067,7 @@ def ONNXReduceSumV11Op:ONNX_Op<"ReduceSumV11", } def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceSumSquare operation"; let description = [{ Computes the sum square of the input tensor's elements along the provided axes. The resulting @@ -7116,7 +7116,7 @@ def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare", } def ONNXReduceSumSquareV13Op:ONNX_Op<"ReduceSumSquareV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceSumSquare operation"; let description = [{ Computes the sum square of the input tensor's elements along the provided axes. The resulting @@ -7154,7 +7154,7 @@ def ONNXReduceSumSquareV13Op:ONNX_Op<"ReduceSumSquareV13", } def ONNXReluOp:ONNX_Op<"Relu", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<14>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Relu operation"; let description = [{ Relu takes one input data (Tensor) and produces one output data @@ -7185,7 +7185,7 @@ def ONNXReluOp:ONNX_Op<"Relu", } def ONNXReshapeOp:ONNX_Op<"Reshape", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Reshape operation"; let description = [{ @@ -7230,7 +7230,7 @@ def ONNXReshapeOp:ONNX_Op<"Reshape", } def ONNXResizeOp:ONNX_Op<"Resize", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Resize operation"; let description = [{ @@ -7278,7 +7278,7 @@ def ONNXResizeOp:ONNX_Op<"Resize", } def ONNXResizeV18Op:ONNX_Op<"ResizeV18", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Resize operation"; let description = [{ Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor. @@ -7322,7 +7322,7 @@ def ONNXResizeV18Op:ONNX_Op<"ResizeV18", } def ONNXResizeV13Op:ONNX_Op<"ResizeV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Resize operation"; let description = [{ Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor. @@ -7362,7 +7362,7 @@ def ONNXResizeV13Op:ONNX_Op<"ResizeV13", } def ONNXResizeV11Op:ONNX_Op<"ResizeV11", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Resize operation"; let description = [{ Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor. @@ -7402,7 +7402,7 @@ def ONNXResizeV11Op:ONNX_Op<"ResizeV11", } def ONNXResizeV10Op:ONNX_Op<"ResizeV10", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<10>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Resize operation"; let description = [{ Resize the input tensor. @@ -7435,7 +7435,7 @@ def ONNXResizeV10Op:ONNX_Op<"ResizeV10", } def ONNXReverseSequenceOp:ONNX_Op<"ReverseSequence", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<10>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReverseSequence operation"; let description = [{ Reverse batch of sequences having different lengths specified by `sequence_lens`. @@ -7500,7 +7500,7 @@ def ONNXReverseSequenceOp:ONNX_Op<"ReverseSequence", } def ONNXRoiAlignOp:ONNX_Op<"RoiAlign", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<16>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX RoiAlign operation"; let description = [{ Region of Interest (RoI) align operation described in the @@ -7548,7 +7548,7 @@ def ONNXRoiAlignOp:ONNX_Op<"RoiAlign", } def ONNXRoundOp:ONNX_Op<"Round", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Round operation"; let description = [{ Round takes one input Tensor and rounds the values, element-wise, meaning @@ -7591,7 +7591,7 @@ def ONNXRoundOp:ONNX_Op<"Round", } def ONNXSTFTOp:ONNX_Op<"STFT", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<17>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX STFT operation"; let description = [{ Computes the Short-time Fourier Transform of the signal. @@ -7624,7 +7624,7 @@ def ONNXSTFTOp:ONNX_Op<"STFT", } def ONNXScanOp:ONNX_Op<"Scan", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpInterface<"HasOnnxSubgraphOpInterface">]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpInterface<"HasOnnxSubgraphOpInterface">]> { let summary = "ONNX Scan operation"; let description = [{ Scan can be used to iterate over one or more scan_input tensors, @@ -7788,7 +7788,7 @@ def ONNXScanOp:ONNX_Op<"Scan", } def ONNXScatterOp:ONNX_Op<"Scatter", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Scatter operation"; let description = [{ This operator is deprecated. Please use ScatterElements, which provides the same functionality. @@ -7872,7 +7872,7 @@ def ONNXScatterOp:ONNX_Op<"Scatter", } def ONNXScatterElementsOp:ONNX_Op<"ScatterElements", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ScatterElements operation"; let description = [{ ScatterElements takes three inputs `data`, `updates`, and `indices` of the same @@ -7967,7 +7967,7 @@ def ONNXScatterElementsOp:ONNX_Op<"ScatterElements", } def ONNXScatterNDOp:ONNX_Op<"ScatterND", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ScatterND operation"; let description = [{ ScatterND takes three inputs `data` tensor of rank r >= 1, `indices` tensor of rank q >= 1, @@ -8074,7 +8074,7 @@ def ONNXScatterNDOp:ONNX_Op<"ScatterND", } def ONNXSeluOp:ONNX_Op<"Selu", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<6>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Selu operation"; let description = [{ Selu takes one input data (Tensor) and produces one output data @@ -8109,7 +8109,7 @@ def ONNXSeluOp:ONNX_Op<"Selu", } def ONNXSequenceAtOp:ONNX_Op<"SequenceAt", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SequenceAt operation"; let description = [{ Outputs a tensor copy from the tensor at 'position' in 'input_sequence'. @@ -8141,7 +8141,7 @@ def ONNXSequenceAtOp:ONNX_Op<"SequenceAt", } def ONNXSequenceConstructOp:ONNX_Op<"SequenceConstruct", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SequenceConstruct operation"; let description = [{ Construct a tensor sequence containing 'inputs' tensors. @@ -8171,7 +8171,7 @@ def ONNXSequenceConstructOp:ONNX_Op<"SequenceConstruct", } def ONNXSequenceEmptyOp:ONNX_Op<"SequenceEmpty", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SequenceEmpty operation"; let description = [{ Construct an empty tensor sequence, with given data type. @@ -8201,7 +8201,7 @@ def ONNXSequenceEmptyOp:ONNX_Op<"SequenceEmpty", } def ONNXSequenceEraseOp:ONNX_Op<"SequenceErase", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SequenceErase operation"; let description = [{ Outputs a tensor sequence that removes the tensor at 'position' from 'input_sequence'. @@ -8234,7 +8234,7 @@ def ONNXSequenceEraseOp:ONNX_Op<"SequenceErase", } def ONNXSequenceInsertOp:ONNX_Op<"SequenceInsert", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SequenceInsert operation"; let description = [{ Outputs a tensor sequence that inserts 'tensor' into 'input_sequence' at 'position'. @@ -8270,7 +8270,7 @@ def ONNXSequenceInsertOp:ONNX_Op<"SequenceInsert", } def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SequenceLength operation"; let description = [{ Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'. @@ -8299,7 +8299,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength", } def ONNXSequenceMapOp:ONNX_Op<"SequenceMap", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpInterface<"HasOnnxSubgraphOpInterface">]> { + [Pure, OpVersionTrait<17>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpInterface<"HasOnnxSubgraphOpInterface">]> { let summary = "ONNX SequenceMap operation"; let description = [{ Applies a sub-graph to each sample in the input sequence(s). @@ -8347,7 +8347,7 @@ def ONNXSequenceMapOp:ONNX_Op<"SequenceMap", } def ONNXShapeOp:ONNX_Op<"Shape", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Shape operation"; let description = [{ @@ -8417,7 +8417,7 @@ def ONNXShapeOp:ONNX_Op<"Shape", } def ONNXShrinkOp:ONNX_Op<"Shrink", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<9>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Shrink operation"; let description = [{ Shrink takes one input data (Tensor) and produces one Tensor output, @@ -8451,7 +8451,7 @@ def ONNXShrinkOp:ONNX_Op<"Shrink", } def ONNXSigmoidOp:ONNX_Op<"Sigmoid", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Sigmoid operation"; let description = [{ Sigmoid takes one input data (Tensor) and produces one output data @@ -8482,7 +8482,7 @@ def ONNXSigmoidOp:ONNX_Op<"Sigmoid", } def ONNXSignOp:ONNX_Op<"Sign", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Sign operation"; let description = [{ Calculate the sign of the given input tensor element-wise. @@ -8512,7 +8512,7 @@ def ONNXSignOp:ONNX_Op<"Sign", } def ONNXSinOp:ONNX_Op<"Sin", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<7>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Sin operation"; let description = [{ Calculates the sine of the given input tensor, element-wise. @@ -8542,7 +8542,7 @@ def ONNXSinOp:ONNX_Op<"Sin", } def ONNXSinhOp:ONNX_Op<"Sinh", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<9>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Sinh operation"; let description = [{ Calculates the hyperbolic sine of the given input tensor element-wise. @@ -8571,7 +8571,7 @@ def ONNXSinhOp:ONNX_Op<"Sinh", } def ONNXSizeOp:ONNX_Op<"Size", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Size operation"; let description = [{ @@ -8601,7 +8601,7 @@ def ONNXSizeOp:ONNX_Op<"Size", } def ONNXSliceOp:ONNX_Op<"Slice", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Slice operation"; let description = [{ Produces a slice of the input tensor along multiple axes. Similar to numpy: @@ -8695,7 +8695,7 @@ def ONNXSliceOp:ONNX_Op<"Slice", } def ONNXSoftmaxOp:ONNX_Op<"Softmax", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Softmax operation"; let description = [{ The operator computes the normalized exponential values for the given input: @@ -8741,7 +8741,7 @@ def ONNXSoftmaxOp:ONNX_Op<"Softmax", } def ONNXSoftmaxV11Op:ONNX_Op<"SoftmaxV11", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Softmax operation"; let description = [{ @@ -8785,7 +8785,7 @@ def ONNXSoftmaxV11Op:ONNX_Op<"SoftmaxV11", } def ONNXSoftmaxCrossEntropyLossOp:ONNX_Op<"SoftmaxCrossEntropyLoss", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SoftmaxCrossEntropyLoss operation"; let description = [{ Loss function that measures the softmax cross entropy @@ -8858,7 +8858,7 @@ def ONNXSoftmaxCrossEntropyLossOp:ONNX_Op<"SoftmaxCrossEntropyLoss", } def ONNXSoftplusOp:ONNX_Op<"Softplus", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Softplus operation"; let description = [{ Softplus takes one input data (Tensor) and produces one output data @@ -8890,7 +8890,7 @@ def ONNXSoftplusOp:ONNX_Op<"Softplus", } def ONNXSoftsignOp:ONNX_Op<"Softsign", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Softsign operation"; let description = [{ Calculates the softsign (x/(1+|x|)) of the given input tensor element-wise. @@ -8919,7 +8919,7 @@ def ONNXSoftsignOp:ONNX_Op<"Softsign", } def ONNXSpaceToDepthOp:ONNX_Op<"SpaceToDepth", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX SpaceToDepth operation"; let description = [{ @@ -8953,7 +8953,7 @@ def ONNXSpaceToDepthOp:ONNX_Op<"SpaceToDepth", } def ONNXSplitOp:ONNX_Op<"Split", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Split operation"; let description = [{ Split a tensor into a list of tensors, along the specified 'axis'. @@ -9000,7 +9000,7 @@ def ONNXSplitOp:ONNX_Op<"Split", } def ONNXSplitV13Op:ONNX_Op<"SplitV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Split operation"; let description = [{ Split a tensor into a list of tensors, along the specified @@ -9043,7 +9043,7 @@ def ONNXSplitV13Op:ONNX_Op<"SplitV13", } def ONNXSplitV11Op:ONNX_Op<"SplitV11", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Split operation"; let description = [{ Split a tensor into a list of tensors, along the specified @@ -9076,7 +9076,7 @@ def ONNXSplitV11Op:ONNX_Op<"SplitV11", } def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SplitToSequence operation"; let description = [{ Split a tensor into a sequence of tensors, along the specified 'axis'. @@ -9120,7 +9120,7 @@ def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence", } def ONNXSqrtOp:ONNX_Op<"Sqrt", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Sqrt operation"; let description = [{ Square root takes one input data (Tensor) and produces one output data @@ -9161,7 +9161,7 @@ def ONNXSqrtOp:ONNX_Op<"Sqrt", } def ONNXSqueezeOp:ONNX_Op<"Squeeze", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Squeeze operation"; let description = [{ @@ -9206,7 +9206,7 @@ def ONNXSqueezeOp:ONNX_Op<"Squeeze", } def ONNXSqueezeV11Op:ONNX_Op<"SqueezeV11", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Squeeze operation"; let description = [{ @@ -9251,7 +9251,7 @@ def ONNXSqueezeV11Op:ONNX_Op<"SqueezeV11", } def ONNXStringNormalizerOp:ONNX_Op<"StringNormalizer", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<10>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX StringNormalizer operation"; let description = [{ StringNormalization performs string operations for basic cleaning. @@ -9292,7 +9292,7 @@ def ONNXStringNormalizerOp:ONNX_Op<"StringNormalizer", } def ONNXSubOp:ONNX_Op<"Sub", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { + [Pure, OpVersionTrait<14>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let hasCanonicalizer = 1; let summary = "ONNX Sub operation"; let description = [{ @@ -9348,7 +9348,7 @@ def ONNXSubOp:ONNX_Op<"Sub", } def ONNXSumOp:ONNX_Op<"Sum", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let summary = "ONNX Sum operation"; let description = [{ Element-wise sum of each of the input tensors (with Numpy-style broadcasting support). @@ -9380,7 +9380,7 @@ def ONNXSumOp:ONNX_Op<"Sum", } def ONNXTanOp:ONNX_Op<"Tan", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<7>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Tan operation"; let description = [{ Calculates the tangent of the given input tensor, element-wise. @@ -9409,7 +9409,7 @@ def ONNXTanOp:ONNX_Op<"Tan", } def ONNXTanhOp:ONNX_Op<"Tanh", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Tanh operation"; let description = [{ Calculates the hyperbolic tangent of the given input tensor element-wise. @@ -9438,7 +9438,7 @@ def ONNXTanhOp:ONNX_Op<"Tanh", } def ONNXTfIdfVectorizerOp:ONNX_Op<"TfIdfVectorizer", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<9>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX TfIdfVectorizer operation"; let description = [{ This transform extracts n-grams from the input sequence and save them as a vector. Input can @@ -9502,7 +9502,7 @@ def ONNXTfIdfVectorizerOp:ONNX_Op<"TfIdfVectorizer", } def ONNXThresholdedReluOp:ONNX_Op<"ThresholdedRelu", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<10>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ThresholdedRelu operation"; let description = [{ ThresholdedRelu takes one input data (Tensor) and produces one output data @@ -9534,7 +9534,7 @@ def ONNXThresholdedReluOp:ONNX_Op<"ThresholdedRelu", } def ONNXTileOp:ONNX_Op<"Tile", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Tile operation"; let description = [{ Constructs a tensor by tiling a given tensor. @@ -9566,7 +9566,7 @@ def ONNXTileOp:ONNX_Op<"Tile", } def ONNXTopKOp:ONNX_Op<"TopK", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX TopK operation"; let description = [{ Retrieve the top-K largest or smallest elements along a specified axis. Given an input tensor of @@ -9615,7 +9615,7 @@ def ONNXTopKOp:ONNX_Op<"TopK", } def ONNXTransposeOp:ONNX_Op<"Transpose", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Transpose operation"; let description = [{ @@ -9648,7 +9648,7 @@ def ONNXTransposeOp:ONNX_Op<"Transpose", } def ONNXTriluOp:ONNX_Op<"Trilu", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<14>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Trilu operation"; let description = [{ Given a 2-D matrix or batches of 2-D matrices, returns the upper or lower triangular part of the tensor(s). @@ -9690,7 +9690,7 @@ def ONNXTriluOp:ONNX_Op<"Trilu", } def ONNXUniqueOp:ONNX_Op<"Unique", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Unique operation"; let description = [{ Find the unique elements of a tensor. When an optional attribute 'axis' is provided, unique subtensors sliced along the 'axis' are returned. @@ -9821,7 +9821,7 @@ def ONNXUniqueOp:ONNX_Op<"Unique", } def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Unsqueeze operation"; let description = [{ @@ -9871,7 +9871,7 @@ def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze", } def ONNXUnsqueezeV11Op:ONNX_Op<"UnsqueezeV11", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Unsqueeze operation"; let description = [{ @@ -9923,7 +9923,7 @@ def ONNXUnsqueezeV11Op:ONNX_Op<"UnsqueezeV11", } def ONNXUpsampleOp:ONNX_Op<"Upsample", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<9>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Upsample operation"; let description = [{ Upsample the input tensor. @@ -9957,7 +9957,7 @@ def ONNXUpsampleOp:ONNX_Op<"Upsample", } def ONNXUpsampleV7Op:ONNX_Op<"UpsampleV7", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<7>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Upsample operation"; let description = [{ Upsample the input tensor. @@ -9990,7 +9990,7 @@ def ONNXUpsampleV7Op:ONNX_Op<"UpsampleV7", } def ONNXWhereOp:ONNX_Op<"Where", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<16>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Where operation"; let description = [{ Return elements, either from X or Y, depending on condition. @@ -10027,7 +10027,7 @@ def ONNXWhereOp:ONNX_Op<"Where", } def ONNXXorOp:ONNX_Op<"Xor", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<7>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Xor operation"; let description = [{ @@ -10082,7 +10082,7 @@ def ONNXXorOp:ONNX_Op<"Xor", } def ONNXArrayFeatureExtractorOp:ONNX_Op<"ArrayFeatureExtractor", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ArrayFeatureExtractor operation"; let description = [{ Select elements of the input tensor based on the indices passed.
@@ -10113,7 +10113,7 @@ def ONNXArrayFeatureExtractorOp:ONNX_Op<"ArrayFeatureExtractor", } def ONNXBinarizerOp:ONNX_Op<"Binarizer", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Binarizer operation"; let description = [{ Maps the values of the input tensor to either 0 or 1, element-wise, based on the outcome of a comparison against a threshold value. @@ -10143,7 +10143,7 @@ def ONNXBinarizerOp:ONNX_Op<"Binarizer", } def ONNXCastMapOp:ONNX_Op<"CastMap", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX CastMap operation"; let description = [{ Converts a map to a tensor.
The map key must be an int64 and the values will be ordered @@ -10177,7 +10177,7 @@ def ONNXCastMapOp:ONNX_Op<"CastMap", } def ONNXCategoryMapperOp:ONNX_Op<"CategoryMapper", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX CategoryMapper operation"; let description = [{ Converts strings to integers and vice versa.
@@ -10218,7 +10218,7 @@ def ONNXCategoryMapperOp:ONNX_Op<"CategoryMapper", } def ONNXDictVectorizerOp:ONNX_Op<"DictVectorizer", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX DictVectorizer operation"; let description = [{ Uses an index mapping to convert a dictionary to an array.
@@ -10260,7 +10260,7 @@ def ONNXDictVectorizerOp:ONNX_Op<"DictVectorizer", } def ONNXFeatureVectorizerOp:ONNX_Op<"FeatureVectorizer", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX FeatureVectorizer operation"; let description = [{ Concatenates input tensors into one continuous output.
@@ -10293,7 +10293,7 @@ def ONNXFeatureVectorizerOp:ONNX_Op<"FeatureVectorizer", } def ONNXImputerOp:ONNX_Op<"Imputer", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Imputer operation"; let description = [{ Replaces inputs that equal one value with another, leaving all other elements alone.
@@ -10333,7 +10333,7 @@ def ONNXImputerOp:ONNX_Op<"Imputer", } def ONNXLabelEncoderOp:ONNX_Op<"LabelEncoder", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<2>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX LabelEncoder operation"; let description = [{ Maps each element in the input tensor to another value.
@@ -10387,7 +10387,7 @@ def ONNXLabelEncoderOp:ONNX_Op<"LabelEncoder", } def ONNXLinearClassifierOp:ONNX_Op<"LinearClassifier", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX LinearClassifier operation"; let description = [{ Linear classifier @@ -10423,7 +10423,7 @@ def ONNXLinearClassifierOp:ONNX_Op<"LinearClassifier", } def ONNXLinearRegressorOp:ONNX_Op<"LinearRegressor", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX LinearRegressor operation"; let description = [{ Generalized linear regression evaluation.
@@ -10461,7 +10461,7 @@ def ONNXLinearRegressorOp:ONNX_Op<"LinearRegressor", } def ONNXNormalizerOp:ONNX_Op<"Normalizer", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Normalizer operation"; let description = [{ Normalize the input. There are three normalization modes, which have the corresponding formulas, @@ -10500,7 +10500,7 @@ def ONNXNormalizerOp:ONNX_Op<"Normalizer", } def ONNXOneHotEncoderOp:ONNX_Op<"OneHotEncoder", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX OneHotEncoder operation"; let description = [{ Replace each input element with an array of ones and zeros, where a single @@ -10540,7 +10540,7 @@ def ONNXOneHotEncoderOp:ONNX_Op<"OneHotEncoder", } def ONNXSVMClassifierOp:ONNX_Op<"SVMClassifier", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SVMClassifier operation"; let description = [{ Support Vector Machine classifier @@ -10581,7 +10581,7 @@ def ONNXSVMClassifierOp:ONNX_Op<"SVMClassifier", } def ONNXSVMRegressorOp:ONNX_Op<"SVMRegressor", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SVMRegressor operation"; let description = [{ Support Vector Machine regression prediction and one-class SVM anomaly detection. @@ -10618,7 +10618,7 @@ def ONNXSVMRegressorOp:ONNX_Op<"SVMRegressor", } def ONNXScalerOp:ONNX_Op<"Scaler", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Scaler operation"; let description = [{ Rescale input data, for example to standardize features by removing the mean and scaling to unit variance. @@ -10649,7 +10649,7 @@ def ONNXScalerOp:ONNX_Op<"Scaler", } def ONNXTreeEnsembleClassifierOp:ONNX_Op<"TreeEnsembleClassifier", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX TreeEnsembleClassifier operation"; let description = [{ Tree Ensemble classifier. Returns the top class for each of N inputs.
@@ -10704,7 +10704,7 @@ def ONNXTreeEnsembleClassifierOp:ONNX_Op<"TreeEnsembleClassifier", } def ONNXTreeEnsembleRegressorOp:ONNX_Op<"TreeEnsembleRegressor", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX TreeEnsembleRegressor operation"; let description = [{ Tree Ensemble regressor. Returns the regressed values for each input in N.
@@ -10759,7 +10759,7 @@ def ONNXTreeEnsembleRegressorOp:ONNX_Op<"TreeEnsembleRegressor", } def ONNXZipMapOp:ONNX_Op<"ZipMap", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ZipMap operation"; let description = [{ Creates a map from the input and the attributes.
@@ -10793,7 +10793,7 @@ def ONNXZipMapOp:ONNX_Op<"ZipMap", } def ONNXAdagradOp:ONNX_Op<"Adagrad", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Adagrad operation"; let description = [{ Compute one iteration of ADAGRAD, a stochastic gradient based optimization @@ -10876,7 +10876,7 @@ def ONNXAdagradOp:ONNX_Op<"Adagrad", } def ONNXAdamOp:ONNX_Op<"Adam", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Adam operation"; let description = [{ Compute one iteration of Adam, a stochastic gradient based optimization @@ -10972,7 +10972,7 @@ def ONNXAdamOp:ONNX_Op<"Adam", } def ONNXGradientOp:ONNX_Op<"Gradient", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Gradient operation"; let description = [{ Gradient operator computes the partial derivatives of a specific tensor w.r.t. @@ -11126,7 +11126,7 @@ def ONNXGradientOp:ONNX_Op<"Gradient", } def ONNXMomentumOp:ONNX_Op<"Momentum", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Momentum operation"; let description = [{ Compute one iteration of stochastic gradient update with momentum. diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp index 708503fa90..ee36c01afa 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp @@ -108,8 +108,8 @@ LogicalResult ONNXSplitOpShapeHelper::computeShape() { // None is fine, indexExprArray will be empty. } else { createIE->getIntFromArrayAsSymbols(split, indexExprArray); - assert(IndexExpr::isLiteral(indexExprArray) && - "dynamic split not yet supported"); + if (!IndexExpr::isLiteral(indexExprArray)) + return failure(); } return customComputeShape(indexExprArray); } @@ -124,8 +124,8 @@ LogicalResult ONNXSplitV13OpShapeHelper::computeShape() { // None is fine, indexExprArray will be empty. } else { createIE->getIntFromArrayAsSymbols(split, indexExprArray); - assert(IndexExpr::isLiteral(indexExprArray) && - "dynamic split not yet supported"); + if (!IndexExpr::isLiteral(indexExprArray)) + return failure(); } return customComputeShape(indexExprArray); } diff --git a/src/Dialect/ONNX/ONNXTraits.hpp b/src/Dialect/ONNX/ONNXTraits.hpp new file mode 100644 index 0000000000..cf04bd37e8 --- /dev/null +++ b/src/Dialect/ONNX/ONNXTraits.hpp @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===----------------- ONNXTraits.hpp - ONNX Op Traits --------------------===// +// +// Copyright (C) 2024, Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file defines traits of ONNX ops. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace OpTrait { + +template +class OpVersionTrait { +public: + template + class Impl : public OpTrait::TraitBase { + public: + int getOpVersion() { return version; } + }; +}; + +} // namespace OpTrait +} // namespace mlir diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir index 25f643f419..39f701a9d7 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir @@ -513,6 +513,16 @@ func.func @test_pow_f64(%arg0: tensor<13x21x1xf64>, %arg1: tensor<13x21x1xf64>) // ----- +func.func @test_pow_mixed_types(%arg0: tensor<3xf32>, %arg1: tensor<3xi32>) -> (tensor<3xf32>) { + // CHECK-LABEL: func @test_pow_mixed_types + // CHECK-SAME: ([[PARAM_0:%.*]]: tensor<3xf32>, [[PARAM_1:%.*]]: tensor<3xi32>) -> tensor<3xf32> + // CHECK: "onnx.Pow"([[PARAM_0]], [[PARAM_1]]) {onnx_node_name = "onnx.Pow_0"} : (tensor<3xf32>, tensor<3xi32>) -> tensor<3xf32> + %0 = "onnx.Pow"(%arg0, %arg1) {onnx_node_name = "onnx.Pow_0"} : (tensor<3xf32>, tensor<3xi32>) -> tensor<3xf32> + return %0 : tensor<3xf32> +} + +// ----- + func.func @test_sqrt(%arg0: tensor<3xf32>) -> tensor<3xf32> { %0 = "onnx.Sqrt"(%arg0) : (tensor<3xf32>) -> tensor<3xf32> return %0 : tensor<3xf32> @@ -870,7 +880,7 @@ func.func @test_sin(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { "func.return"(%0) : (tensor<10x10xf32>) -> () // CHECK-LABEL: func @test_sin // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { -// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.custom [[PARAM_0_]] {domain_name = "UNDEF", implementation_attrs = "linalg.generic", operator_name = "math.sin"} : (tensor<10x10xf32>) -> tensor<10x10xf32> +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.sin [[PARAM_0_]] : (tensor<10x10xf32>) -> tensor<10x10xf32> // CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32> // CHECK-NEXT: } } @@ -882,7 +892,7 @@ func.func @test_sin_dynamic(%arg0 : tensor) -> tensor<*xf32> { "func.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: func @test_sin_dynamic // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { -// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.custom [[PARAM_0_]] {domain_name = "UNDEF", implementation_attrs = "linalg.generic", operator_name = "math.sin"} : (tensor) -> tensor +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.sin [[PARAM_0_]] : (tensor) -> tensor // CHECK-NEXT: return [[VAR_0_]] : tensor // CHECK-NEXT: } } @@ -894,7 +904,7 @@ func.func @test_cos(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { "func.return"(%0) : (tensor<10x10xf32>) -> () // CHECK-LABEL: func @test_cos // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { -// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.custom [[PARAM_0_]] {domain_name = "UNDEF", implementation_attrs = "linalg.generic", operator_name = "math.cos"} : (tensor<10x10xf32>) -> tensor<10x10xf32> +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.cos [[PARAM_0_]] : (tensor<10x10xf32>) -> tensor<10x10xf32> // CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32> // CHECK-NEXT: } } @@ -906,7 +916,7 @@ func.func @test_cos_dynamic(%arg0 : tensor) -> tensor<*xf32> { "func.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: func @test_cos_dynamic // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { -// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.custom [[PARAM_0_]] {domain_name = "UNDEF", implementation_attrs = "linalg.generic", operator_name = "math.cos"} : (tensor) -> tensor +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.cos [[PARAM_0_]] : (tensor) -> tensor // CHECK-NEXT: return [[VAR_0_]] : tensor // CHECK-NEXT: } } diff --git a/test/mlir/conversion/onnx_to_tosa/NN/DequantizeLinear.mlir b/test/mlir/conversion/onnx_to_tosa/NN/DequantizeLinear.mlir index 775bdca6cd..746c3a6989 100644 --- a/test/mlir/conversion/onnx_to_tosa/NN/DequantizeLinear.mlir +++ b/test/mlir/conversion/onnx_to_tosa/NN/DequantizeLinear.mlir @@ -6,16 +6,39 @@ func.func @test_dequantizeLinear(%arg0 : tensor<32x3x224x224xi8>) -> tensor<32x3 %2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64} : (tensor<32x3x224x224xi8>, tensor, tensor) -> tensor<32x3x224x224xf32> "func.return"(%2) : (tensor<32x3x224x224xf32>) -> () } -// CHECK-LABEL: @test_dequantizeLinear(%arg0: tensor<32x3x224x224xi8>) -> tensor<32x3x224x224xf32> +// CHECK-LABEL: @test_dequantizeLinear +// CHECK-SAME: (%[[ARG_0:.*]]: tensor<32x3x224x224xi8>) -> tensor<32x3x224x224xf32> // CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x1x1xi8>}> : () -> tensor<1x1x1x1xi8> // CHECK-DAG: %[[SCALE:.*]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK-DAG: %[[SUB:.*]] = tosa.sub %arg0, %[[ZP]] : (tensor<32x3x224x224xi8>, tensor<1x1x1x1xi8>) -> tensor<32x3x224x224xi8> -// CHECK-DAG: %[[CAST:.*]] = tosa.cast %[[SUB]] : (tensor<32x3x224x224xi8>) -> tensor<32x3x224x224xf32> -// CHECK-DAG: %[[MUL:.*]] = tosa.mul %[[CAST]], %[[SCALE]] {shift = 0 : i8} : (tensor<32x3x224x224xf32>, tensor<1x1x1x1xf32>) -> tensor<32x3x224x224xf32> +// CHECK-DAG: %[[CAST_0:.*]] = tosa.cast %[[ARG_0]] : (tensor<32x3x224x224xi8>) -> tensor<32x3x224x224xf32> +// CHECK-DAG: %[[CASTZP:.*]] = tosa.cast %[[ZP]] : (tensor<1x1x1x1xi8>) -> tensor<1x1x1x1xf32> +// CHECK-DAG: %[[SUB:.*]] = tosa.sub %[[CAST_0]], %[[CASTZP]] : (tensor<32x3x224x224xf32>, tensor<1x1x1x1xf32>) -> tensor<32x3x224x224xf32> +// CHECK-DAG: %[[MUL:.*]] = tosa.mul %[[SUB]], %[[SCALE]] {shift = 0 : i8} : (tensor<32x3x224x224xf32>, tensor<1x1x1x1xf32>) -> tensor<32x3x224x224xf32> // CHECK-DAG: return %[[MUL]] : tensor<32x3x224x224xf32> // ----- +func.func @test_dequantizeLinear_f16(%arg0 : tensor<32x3x224x224xi8>) -> tensor<32x3x224x224xf16> { + %0 = onnx.Constant dense<3.125000e-02> : tensor + %1 = onnx.Constant dense<0> : tensor + %2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64} : (tensor<32x3x224x224xi8>, tensor, tensor) -> tensor<32x3x224x224xf16> + "func.return"(%2) : (tensor<32x3x224x224xf16>) -> () +} + +// CHECK-LABEL: @test_dequantizeLinear_f16 +// CHECK-SAME: (%[[ARG_0:.*]]: tensor<32x3x224x224xi8>) -> tensor<32x3x224x224xf16> +// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x1x1xi8>}> : () -> tensor<1x1x1x1xi8> +// CHECK-DAG: %[[SCALE:.*]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf16>}> : () -> tensor<1x1x1x1xf16> +// CHECK-DAG: %[[CAST_0:.*]] = tosa.cast %[[ARG_0]] : (tensor<32x3x224x224xi8>) -> tensor<32x3x224x224xf32> +// CHECK-DAG: %[[CASTZP:.*]] = tosa.cast %[[ZP]] : (tensor<1x1x1x1xi8>) -> tensor<1x1x1x1xf32> +// CHECK-DAG: %[[SUB:.*]] = tosa.sub %[[CAST_0]], %[[CASTZP]] : (tensor<32x3x224x224xf32>, tensor<1x1x1x1xf32>) -> tensor<32x3x224x224xf32> +// CHECK-DAG: %[[CASTSCALE:.*]] = tosa.cast %[[SCALE]] : (tensor<1x1x1x1xf16>) -> tensor<1x1x1x1xf32> +// CHECK-DAG: %[[MUL:.*]] = tosa.mul %[[SUB]], %[[CASTSCALE]] {shift = 0 : i8} : (tensor<32x3x224x224xf32>, tensor<1x1x1x1xf32>) -> tensor<32x3x224x224xf32> +// CHECK-DAG: %[[CAST:.*]] = tosa.cast %[[MUL]] : (tensor<32x3x224x224xf32>) -> tensor<32x3x224x224xf16> +// CHECK-DAG: return %[[CAST]] : tensor<32x3x224x224xf16> + +// ----- + func.func @per_axis(%arg0: tensor<8x2xi8>) -> tensor<8x2xf32> { %0 = onnx.Constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32> %1 = onnx.Constant dense<[0, 1]> : tensor<2xi8> diff --git a/test/mlir/conversion/onnx_to_tosa/NN/QuantizeLinear.mlir b/test/mlir/conversion/onnx_to_tosa/NN/QuantizeLinear.mlir index 0f7c8a9106..32ebf768fd 100644 --- a/test/mlir/conversion/onnx_to_tosa/NN/QuantizeLinear.mlir +++ b/test/mlir/conversion/onnx_to_tosa/NN/QuantizeLinear.mlir @@ -6,14 +6,16 @@ func.func @test_quantizeLinear(%arg0 : tensor<32x3x224x224xf32>) -> tensor<32x3x %2 = "onnx.QuantizeLinear"(%arg0, %0, %1) {axis = 1 : si64} : (tensor<32x3x224x224xf32>, tensor, tensor) -> tensor<32x3x224x224xi8> "func.return"(%2) : (tensor<32x3x224x224xi8>) -> () } -// CHECK-LABEL: @test_quantizeLinear(%arg0: tensor<32x3x224x224xf32>) -> tensor<32x3x224x224xi8> -// CHECK-DAG: %[[SCALE:.*]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-LABEL: @test_quantizeLinear +// CHECK-SAME: (%[[ARG_0:.*]]: tensor<32x3x224x224xf32>) -> tensor<32x3x224x224xi8> // CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x1x1xi8>}> : () -> tensor<1x1x1x1xi8> +// CHECK-DAG: %[[SCALE:.*]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> // CHECK-DAG: %[[REC:.*]] = tosa.reciprocal %[[SCALE]] : (tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> -// CHECK-DAG: %[[MUL:.*]] = tosa.mul %arg0, %[[REC]] {shift = 0 : i8} : (tensor<32x3x224x224xf32>, tensor<1x1x1x1xf32>) -> tensor<32x3x224x224xf32> -// CHECK-DAG: %[[CAST:.*]] = tosa.cast %[[MUL]] : (tensor<32x3x224x224xf32>) -> tensor<32x3x224x224xi8> -// CHECK-DAG: %[[ADD:.*]] = tosa.add %[[CAST]], %[[ZP]] : (tensor<32x3x224x224xi8>, tensor<1x1x1x1xi8>) -> tensor<32x3x224x224xi8> -// CHECK-DAG: return %[[ADD]] : tensor<32x3x224x224xi8> +// CHECK-DAG: %[[MUL:.*]] = tosa.mul %[[ARG_0]], %[[REC]] {shift = 0 : i8} : (tensor<32x3x224x224xf32>, tensor<1x1x1x1xf32>) -> tensor<32x3x224x224xf32> +// CHECK-DAG: %[[ZPCAST:.*]] = tosa.cast %[[ZP]] : (tensor<1x1x1x1xi8>) -> tensor<1x1x1x1xf32> +// CHECK-DAG: %[[ADD:.*]] = tosa.add %3, %[[ZPCAST]] : (tensor<32x3x224x224xf32>, tensor<1x1x1x1xf32>) -> tensor<32x3x224x224xf32> +// CHECK-DAG: %[[CAST:.*]] = tosa.cast %[[ADD]] : (tensor<32x3x224x224xf32>) -> tensor<32x3x224x224xi8> +// CHECK-DAG: return %[[CAST]] : tensor<32x3x224x224xi8> // ----- @@ -29,11 +31,12 @@ func.func @test_quantizeLinear_per_axis(%arg0: tensor<8x2xf32>) -> tensor<8x2xi8 // CHECK-SAME: %[[VAL_0:.*]]: tensor<8x2xf32>) -> tensor<8x2xi8> { // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, 1]]> : tensor<1x2xi8>}> : () -> tensor<1x2xi8> // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1.000000e+00, 2.000000e+00]]> : tensor<1x2xf32>}> : () -> tensor<1x2xf32> -// CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<1x2xf32>) -> tensor<1x2xf32> -// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_0]], %[[VAL_3]] {shift = 0 : i8} : (tensor<8x2xf32>, tensor<1x2xf32>) -> tensor<8x2xf32> -// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<8x2xf32>) -> tensor<8x2xi8> -// CHECK: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_1]] : (tensor<8x2xi8>, tensor<1x2xi8>) -> tensor<8x2xi8> -// CHECK: return %[[VAL_6]] : tensor<8x2xi8> +// CHECK: %[[REC:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<1x2xf32>) -> tensor<1x2xf32> +// CHECK: %[[MUL:.*]] = tosa.mul %[[VAL_0]], %[[REC]] {shift = 0 : i8} : (tensor<8x2xf32>, tensor<1x2xf32>) -> tensor<8x2xf32> +// CHECK: %[[ZPCAST:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x2xi8>) -> tensor<1x2xf32> +// CHECK: %[[ADD:.*]] = tosa.add %[[MUL]], %[[ZPCAST]] : (tensor<8x2xf32>, tensor<1x2xf32>) -> tensor<8x2xf32> +// CHECK: %[[CAST:.*]] = tosa.cast %[[ADD]] : (tensor<8x2xf32>) -> tensor<8x2xi8> +// CHECK: return %[[CAST]] : tensor<8x2xi8> // CHECK: } // ----- diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Split.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Split.mlir new file mode 100644 index 0000000000..367d3ce988 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Split.mlir @@ -0,0 +1,119 @@ +// RUN: onnx-mlir-opt --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s + +func.func @test_split_equal(%arg0 : tensor<16x32x64xf32>) -> (tensor<8x32x64xf32>, tensor<8x32x64xf32>) { + %cst = "onnx.NoValue"() {value} : () -> none + %0, %1 = "onnx.Split"(%arg0, %cst) { axis = 0 : si64} : (tensor<16x32x64xf32>, none) -> (tensor<8x32x64xf32>, tensor<8x32x64xf32>) + return %0, %1 : tensor<8x32x64xf32>, tensor<8x32x64xf32> +} + +// CHECK-LABEL: func.func @test_split_equal +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x32x64xf32>) -> (tensor<8x32x64xf32>, tensor<8x32x64xf32>) { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xf32>) -> tensor<8x32x64xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xf32>) -> tensor<8x32x64xf32> +// CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<8x32x64xf32>, tensor<8x32x64xf32> + +// ----- + +func.func @test_split_variable(%arg0 : tensor<16x32x64xf16>) -> (tensor<16x2x64xf16>, tensor<16x30x64xf16>) { + %split = "onnx.Constant"() {value = dense<[2, 30]> : tensor<2xi64>} : () -> tensor<2xi64> + %0, %1 = "onnx.Split"(%arg0, %split) {axis = 1 : si64} : (tensor<16x32x64xf16>, tensor<2xi64>) -> (tensor<16x2x64xf16>, tensor<16x30x64xf16>) + "func.return"(%0, %1) : (tensor<16x2x64xf16>, tensor<16x30x64xf16>) -> () +} + +// CHECK-LABEL: func.func @test_split_variable +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x32x64xf16>) -> (tensor<16x2x64xf16>, tensor<16x30x64xf16>) { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xf16>) -> tensor<16x2x64xf16> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xf16>) -> tensor<16x30x64xf16> +// CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<16x2x64xf16>, tensor<16x30x64xf16> + +// ----- + +func.func @test_split_multiple(%arg0 : tensor<16x32x64xf16>) -> (tensor<16x4x64xf16>, tensor<16x8x64xf16>, tensor<16x20x64xf16>) { + %split = "onnx.Constant"() {value = dense<[4, 8, 20]> : tensor<3xi64>} : () -> tensor<3xi64> + %0, %1, %2 = "onnx.Split"(%arg0, %split) {axis = 1 : si64} : (tensor<16x32x64xf16>, tensor<3xi64>) -> (tensor<16x4x64xf16>, tensor<16x8x64xf16>, tensor<16x20x64xf16>) + "func.return"(%0, %1, %2) : (tensor<16x4x64xf16>, tensor<16x8x64xf16>, tensor<16x20x64xf16>) -> () +} + +// CHECK-LABEL: func.func @test_split_multiple +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x32x64xf16>) -> (tensor<16x4x64xf16>, tensor<16x8x64xf16>, tensor<16x20x64xf16>) { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xf16>) -> tensor<16x4x64xf16> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xf16>) -> tensor<16x8x64xf16> +// CHECK-DAG: [[VAR_2_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xf16>) -> tensor<16x20x64xf16> +// CHECK: return [[VAR_0_]], [[VAR_1_]], [[VAR_2_]] : tensor<16x4x64xf16>, tensor<16x8x64xf16>, tensor<16x20x64xf16> + + +// ----- + +func.func @test_no_split(%arg0 : tensor<16x32x64xi32>) -> tensor<16x16x64xi32> { + %cst = "onnx.NoValue"() {value} : () -> none + %0, %1 = "onnx.Split"(%arg0, %cst) { axis = 1 : si64} : (tensor<16x32x64xi32>, none) -> (tensor<16x16x64xi32>, tensor<16x16x64xi32>) + "func.return"(%0) : (tensor<16x16x64xi32>) -> () +} + +// CHECK-LABEL: func.func @test_no_split +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x32x64xi32>) -> tensor<16x16x64xi32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xi32>) -> tensor<16x16x64xi32> +// CHECK: return [[VAR_0_]] : tensor<16x16x64xi32> + + +// ----- + +func.func @test_split_negative_axis(%arg0 : tensor<16x32x64xbf16>) -> (tensor<16x16x64xbf16>, tensor<16x16x64xbf16>) { + %cst = "onnx.NoValue"() {value} : () -> none + %0, %1 = "onnx.Split"(%arg0, %cst) { axis = -2 : si64} : (tensor<16x32x64xbf16>, none) -> (tensor<16x16x64xbf16>, tensor<16x16x64xbf16>) + "func.return"(%0, %1) : (tensor<16x16x64xbf16>, tensor<16x16x64xbf16>) -> () +} + +// CHECK-LABEL: func.func @test_split_negative_axis +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x32x64xbf16>) -> (tensor<16x16x64xbf16>, tensor<16x16x64xbf16>) { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xbf16>) -> tensor<16x16x64xbf16> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xbf16>) -> tensor<16x16x64xbf16> +// CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<16x16x64xbf16>, tensor<16x16x64xbf16> + +// ----- + +func.func @test_non_constant_split(%arg0 : tensor<16x32x64xi16>, %arg1 : tensor<2xi64>) -> tensor<16x?x64xi16> { + %0, %1 = "onnx.Split"(%arg0, %arg1) {axis = 1 : si64} : (tensor<16x32x64xi16>, tensor<2xi64>) -> (tensor<16x?x64xi16>, tensor<16x?x64xi16>) + "func.return"(%0) : (tensor<16x?x64xi16>) -> () +} + +// CHECK-LABEL: func.func @test_non_constant_split +// CHECK-NOT: tosa.slice + +// ----- + +func.func @test_zero_split(%arg0 : tensor<16x32x64xi16>) -> tensor<16x0x64xi16> { + %split = "onnx.Constant"() {value = dense<[32, 0]> : tensor<2xi64>} : () -> tensor<2xi64> + %0, %1 = "onnx.Split"(%arg0, %split) {axis = 1 : si64} : (tensor<16x32x64xi16>, tensor<2xi64>) -> (tensor<16x32x64xi16>, tensor<16x0x64xi16>) + "func.return"(%1) : (tensor<16x0x64xi16>) -> () +} + +// CHECK-LABEL: func.func @test_zero_split +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x32x64xi16>) -> tensor<16x0x64xi16> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xi16>) -> tensor<16x0x64xi16> +// CHECK: return [[VAR_0_]] : tensor<16x0x64xi16> + +// ----- +// Legalization won't happen since tosa.slice doesn't +// allow dynamic entry in 'size' attribute +func.func @test_dynamic_shapes(%arg0 : tensor<16x32x?xf32>) -> tensor<16x2x?xf32> { + %split = "onnx.Constant"() {value = dense<[2, 30]> : tensor<2xi64>} : () -> tensor<2xi64> + %0, %1 = "onnx.Split"(%arg0, %split) {axis = 1 : si64} : (tensor<16x32x?xf32>, tensor<2xi64>) -> (tensor<16x2x?xf32>, tensor<16x30x?xf32>) + return %0 : tensor<16x2x?xf32> +} + +// CHECK-LABEL: func.func @test_dynamic_shapes +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x32x?xf32>) -> tensor<16x2x?xf32> { +// CHECK-NOT: tosa.slice + +// ----- +func.func @test_num_outputs(%arg0 : tensor<16x32x64xf32>) -> tensor<8x32x64xf32> { + %cst = "onnx.NoValue"() {value} : () -> none + %0, %1 = "onnx.Split"(%arg0, %cst) {axis = 0 : si64, num_outputs = 2 : si64} : (tensor<16x32x64xf32>, none) -> (tensor<8x32x64xf32>, tensor<8x32x64xf32>) + return %0 : tensor<8x32x64xf32> +} + +// CHECK-LABEL: func.func @test_num_outputs +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x32x64xf32>) -> tensor<8x32x64xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xf32>) -> tensor<8x32x64xf32> +// CHECK: return [[VAR_0_]] : tensor<8x32x64xf32> diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index f2eb538bea..fbcd2714f6 100755 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -1166,7 +1166,7 @@ def gen_op_def(schema, with_version=False): regions[attr.name] = "AnyRegion" # Generate decl for op traits. - traits = ["Pure"] + traits = ["Pure", f"OpVersionTrait<{schema.since_version}>"] # Generate ConstantLike traits. if opName in OpsWithConstantLike: