Skip to content

Commit

Permalink
MLIR: improve dataflow interface (#2110)
Browse files Browse the repository at this point in the history
* MLIR: improve dataflow interface

* Fixup

* fix

* fix

* fix

* fix

* fix
  • Loading branch information
wsmoses authored Oct 7, 2024
1 parent 1c9aff0 commit 71dbc56
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
18 changes: 10 additions & 8 deletions enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -941,9 +941,9 @@ getPotentialTerminatorUsers(Operation *op, Value parent) {
if (isFunctionReturn(op))
return {};

SmallVector<Value> results;

if (isa<RegionBranchOpInterface>(op->getParentOp()))
if (auto termIface = dyn_cast<ADDataFlowOpInterface>(op->getParentOp())) {
return termIface.getPotentialTerminatorUsers(op, parent);
} else if (isa<RegionBranchOpInterface>(op->getParentOp())) {
if (auto termIface = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
SmallVector<RegionSuccessor> successors;
termIface.getSuccessorRegions(
Expand All @@ -966,6 +966,8 @@ getPotentialTerminatorUsers(Operation *op, Value parent) {
}
return std::move(results);
}
}
SmallVector<Value> results;
if (auto iface = dyn_cast<BranchOpInterface>(op)) {
for (auto &operand : op->getOpOperands())
if (operand.get() == parent)
Expand Down Expand Up @@ -1002,7 +1004,11 @@ static SmallVector<Value> getPotentialIncomingValues(OpResult res) {

auto resultNo = res.getResultNumber();

if (auto iface = dyn_cast<RegionBranchOpInterface>(owner)) {
if (auto iface = dyn_cast<ADDataFlowOpInterface>(owner)) {
for (auto val : iface.getPotentialIncomingValuesRes(res))
potentialSources.push_back(val);
return potentialSources;
} else if (auto iface = dyn_cast<RegionBranchOpInterface>(owner)) {
SmallVector<RegionSuccessor> successors;
iface.getSuccessorRegions(RegionBranchPoint::parent(), successors);
for (auto &succ : successors) {
Expand All @@ -1019,10 +1025,6 @@ static SmallVector<Value> getPotentialIncomingValues(OpResult res) {

potentialSources.push_back(successorOperands[resultNo]);
}
} else if (auto iface = dyn_cast<ADDataFlowOpInterface>(owner)) {
for (auto val : iface.getPotentialIncomingValuesRes(res))
potentialSources.push_back(val);
return potentialSources;
} else {
// assume all inputs potentially flow into all op results
for (auto operand : owner->getOperands()) {
Expand Down
7 changes: 7 additions & 0 deletions enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,13 @@ def ADDataFlowOpInterface
/*retTy=*/"SmallVector<Value>",
/*methodName=*/"getPotentialIncomingValuesArg",
/*args=*/(ins "::mlir::BlockArgument":$v)
>,
InterfaceMethod<
/*desc=*/[{
}],
/*retTy=*/"SmallVector<Value>",
/*methodName=*/"getPotentialTerminatorUsers",
/*args=*/(ins "::mlir::Operation*":$terminator, "::mlir::Value":$val)
>
];
}
Expand Down

0 comments on commit 71dbc56

Please sign in to comment.