Skip to content

Commit

Permalink
Skip external consts when walking operands
Browse files Browse the repository at this point in the history
  • Loading branch information
jorickert committed Oct 4, 2024
1 parent 5793d27 commit a760d41
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions lib/Transform/XTenMinimizeLiveTensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,14 +192,25 @@ FailureOr<SmallVector<Value>> getFmOperands(Operation *op) {
if (isa<func::FuncOp>(op))
return {{}};

const auto filterOutExternalConst = [](SmallVector<Value> operands) {
SmallVector<Value> filteredOperands;
for (const auto operand : operands) {
if (!isa_and_nonnull<amd::xten_nn::LoadExternalConstOp>(
operand.getDefiningOp())) {
filteredOperands.push_back(operand);
}
}
return filteredOperands;
};

if (isInCoreChain(op))
return {getSubgraphIFMs(op)};

if (isConcatSubgraph(op))
return {getSubgraphIFMs(op)};

if (isTemplatedGraph(op))
return {op->getOperands()};
return {filterOutExternalConst(op->getOperands())};

// Otherwise, this is a PseudoOp and IFM is the first operand.
if (!(isAnyPseudoOp(op) || isInterfaceOp(op))) {
Expand All @@ -225,7 +236,8 @@ size_t getSize(Value val) {

if (auto complexType = elementType.dyn_cast<ComplexType>()) {
elementType = complexType.getElementType();
return (elementType.getIntOrFloatBitWidth() * type.getNumElements() * 2) / 8;
return (elementType.getIntOrFloatBitWidth() * type.getNumElements() * 2) /
8;
}
llvm_unreachable("Does not know how to compute size");
}
Expand Down Expand Up @@ -299,7 +311,8 @@ class XTenMinimizeLiveTensorsPass
} else {
fmResults = SmallVector<Value>(currFn.getBody().front().getArguments());
}
std::optional<Value> const sharesResultMemory = sharesMemoryWithResult(defOp);
std::optional<Value> const sharesResultMemory =
sharesMemoryWithResult(defOp);
OpInfo info = {.op = defOp,
.operands = *fmOperands,
.results = fmResults,
Expand Down

0 comments on commit a760d41

Please sign in to comment.