Skip to content

Commit

Permalink
Address reviewer comments
Browse files Browse the repository at this point in the history
  • Loading branch information
philippb-amd committed Jul 27, 2022
1 parent 06a87c4 commit f8bdb7a
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 20 deletions.
13 changes: 1 addition & 12 deletions src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,6 @@ using namespace mlir;

namespace onnx_mlir {

static bool isSignedInt(Type type) {
IntegerType intType = type.dyn_cast<IntegerType>();
std::set<unsigned> intWidth{8, 16, 32, 48, 64};
return intType && intType.isSigned() &&
(intWidth.find(intType.getWidth()) != intWidth.end());
}

static bool isFloat(Type type) {
return type.isa<BFloat16Type, Float16Type, Float32Type>();
}

void populateONNXToTOSAConversionPattern(ConversionTarget &target,
RewritePatternSet &patterns, TypeConverter &typeConverter,
MLIRContext *ctx) {
Expand Down Expand Up @@ -65,7 +54,7 @@ void FrontendToTosaLoweringPass::runOnOperation() {
// conversion failures. Quantized types are not supported right now.
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) -> Optional<Type> {
if (isSignedInt(type) || isFloat(type))
if (isTOSASignedInt(type) || isTOSAFloat(type))
return type;
return llvm::None;
});
Expand Down
33 changes: 25 additions & 8 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
// This file lowers ONNX element-wise operators to TOSA dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/TypeUtilities.h"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"

using namespace mlir;
Expand All @@ -32,6 +32,25 @@ class ONNXUnaryOpLoweringToTOSA : public OpConversionPattern<ONNXOpT> {
}
};

template <typename ONNXOpT>
class ONNXUnaryOpLoweringToTOSA<ONNXOpT, tosa::FloorOp>
: public OpConversionPattern<ONNXOpT> {
public:
using OpConversionPattern<ONNXOpT>::OpConversionPattern;
using OpAdaptor = typename ONNXOpT::Adaptor;
LogicalResult matchAndRewrite(ONNXOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto scalarType = getElementTypeOrSelf(adaptor.X());
if (!isTOSAFloat(scalarType))
return rewriter.notifyMatchFailure(
op, "`tosa.floor` only supports float types");

rewriter.replaceOpWithNewOp<tosa::FloorOp>(op, op.getType(), adaptor.X());
return success();
}
};

class ONNXReluOpLoweringToTOSA : public OpConversionPattern<ONNXReluOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand All @@ -58,14 +77,12 @@ void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target,
RewritePatternSet &patterns, TypeConverter &typeConverter,
MLIRContext *ctx) {
patterns.insert<ONNXReluOpLoweringToTOSA>(typeConverter, ctx);

#define INSERT_UNARY_PATTERN(ONNXOp, TOSAOp) \
target.addIllegalOp<ONNXOp>(); \
patterns.insert<ONNXUnaryOpLoweringToTOSA<ONNXOp, TOSAOp>>( \
target.addIllegalOp<ONNXNegOp>();
patterns.insert<ONNXUnaryOpLoweringToTOSA<ONNXNegOp, tosa::NegateOp>>(
typeConverter, ctx);
target.addIllegalOp<ONNXFloorOp>();
patterns.insert<ONNXUnaryOpLoweringToTOSA<ONNXFloorOp, tosa::FloorOp>>(
typeConverter, ctx);
INSERT_UNARY_PATTERN(ONNXNegOp, tosa::NegateOp)
INSERT_UNARY_PATTERN(ONNXFloorOp, tosa::FloorOp)
#undef INSERT_UNARY_PATTERN
}

} // namespace onnx_mlir
15 changes: 15 additions & 0 deletions src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,21 @@

namespace onnx_mlir {

//===----------------------------------------------------------------------===//
// Check for valid TOSA types.
//===----------------------------------------------------------------------===//

inline bool isTOSASignedInt(Type type) {
IntegerType intType = type.dyn_cast<IntegerType>();
std::set<unsigned> intWidth{8, 16, 32, 48, 64};
return intType && intType.isSigned() &&
(intWidth.find(intType.getWidth()) != intWidth.end());
}

inline bool isTOSAFloat(Type type) {
return type.isa<BFloat16Type, Float16Type, Float32Type>();
}

//===----------------------------------------------------------------------===//
// This is to get a TOSA operation of a given type for a specific operation.
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit f8bdb7a

Please sign in to comment.