Skip to content

Commit

Permalink
Fix pad axes
Browse files Browse the repository at this point in the history
  • Loading branch information
josel-amd committed Jun 12, 2024
1 parent 55d5d4b commit 1541fcd
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::OpTrait::util;
Expand All @@ -28,11 +29,25 @@ LogicalResult ONNXPadOpShapeHelper::computeShape() {
ONNXPadOpAdaptor operandAdaptor(operands);
Value dataOperand = operandAdaptor.getData();
Value padsOperand = operandAdaptor.getPads();
Value axesOperand = operandAdaptor.getAxes();
DimsExpr outputDims;

// Get info about input data operand.
uint64_t dataRank = createIE->getShapedTypeRank(dataOperand);

// If the axes operand is provided, the output shape is at least guaranteed to
// keep the same rank as the input. But nothing can be said about the actual
// size of each dimension
if (!isNoneValue(axesOperand)) {
bool isFloat = isa<FloatType>(getElementType(dataOperand.getType()));
llvm::for_each(llvm::iota_range<int64_t>(0, dataRank, /*Inclusive=*/false),
[&outputDims, isFloat](const auto /*idx*/) {
outputDims.push_back(QuestionmarkIndexExpr(/*IsFloat=*/isFloat));
});
setOutputDims(outputDims);
return success();
}

// Initialize context and results (pads & output)
pads.resize(2 * dataRank); // pads two sides of each axis.
outputDims.resize(dataRank);
Expand All @@ -47,14 +62,17 @@ LogicalResult ONNXPadOpShapeHelper::computeShape() {
// Get begin/end pads.
SymbolIndexExpr padBegin(createIE->getIntFromArrayAsSymbol(padsOperand, i));
SymbolIndexExpr padEnd(
createIE->getIntFromArrayAsSymbol(padsOperand, i + dataRank));
createIE->getIntFromArrayAsSymbol(padsOperand, i + dataRank - 1));
if (padBegin.isUndefined() || padEnd.isUndefined())
return op->emitError("pad parameter could not be processed");
// Get input dim.
DimIndexExpr dimInput(createIE->getShapeAsDim(dataOperand, i));

// Calculation for output size.
IndexExpr dimOutputFinal = (padBegin + dimInput) + padEnd;
std::string debug;
dimOutputFinal.debugPrint(debug);
llvm::errs() << debug << "\n";

// Save results.
pads[i] = padBegin;
Expand Down Expand Up @@ -86,10 +104,6 @@ LogicalResult ONNXPadOp::verify() {
}
}

if (!isNoneValue(getAxes())) {
return emitOpError("Axes input is not currently supported");
}

return success();
}

Expand Down

0 comments on commit 1541fcd

Please sign in to comment.