Skip to content

Commit

Permalink
addressed comments
Browse files Browse the repository at this point in the history
  • Loading branch information
p-lanza committed Dec 19, 2024
1 parent 72216f0 commit a8f48db
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions src/Conversion/ONNXToTOSA/Tensor/Where.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@ namespace {
class ONNXWhereLoweringToTOSA : public OpConversionPattern<ONNXWhereOp> {
public:
using OpConversionPattern::OpConversionPattern;
using OpAdaptor = typename ONNXWhereOp::Adaptor;

LogicalResult matchAndRewrite(ONNXWhereOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto loc = op.getLoc();
Value pred = adaptor.getOperands()[0];
Value lhs = adaptor.getOperands()[1];
Value rhs = adaptor.getOperands()[2];
Value pred = adaptor.getCondition();
Value lhs = adaptor.getX();
Value rhs = adaptor.getY();

// Check types are compatible
auto predType = dyn_cast<TensorType>(pred.getType());
Expand All @@ -33,10 +32,6 @@ class ONNXWhereLoweringToTOSA : public OpConversionPattern<ONNXWhereOp> {
if (!predType || !lhsType || !rhsType || !resultType) {
return rewriter.notifyMatchFailure(op, "Tosa only supports TensorTypes");
}
if (!isTOSABool(predType.getElementType())) {
return rewriter.notifyMatchFailure(
op, "Expected bool type for condition to onnx.Where");
}
if (lhsType.getElementType() != rhsType.getElementType() ||
lhsType.getElementType() != resultType.getElementType()) {
return rewriter.notifyMatchFailure(op,
Expand Down

0 comments on commit a8f48db

Please sign in to comment.