From 71dbc567eca1cc905dd29ca544d523738aa40663 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 7 Oct 2024 16:03:44 -0500 Subject: [PATCH] MLIR: improve dataflow interface (#2110) * MLIR: improve dataflow interface * Fixup * fix * fix * fix * fix * fix --- .../Enzyme/MLIR/Analysis/ActivityAnalysis.cpp | 18 ++++++++++-------- .../MLIR/Interfaces/AutoDiffOpInterface.td | 7 +++++++ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp index ef54f3c385a..c7b7a256fb9 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp @@ -941,9 +941,9 @@ getPotentialTerminatorUsers(Operation *op, Value parent) { if (isFunctionReturn(op)) return {}; - SmallVector results; - - if (isa(op->getParentOp())) + if (auto termIface = dyn_cast(op->getParentOp())) { + return termIface.getPotentialTerminatorUsers(op, parent); + } else if (isa(op->getParentOp())) { if (auto termIface = dyn_cast(op)) { SmallVector successors; termIface.getSuccessorRegions( @@ -966,6 +966,8 @@ getPotentialTerminatorUsers(Operation *op, Value parent) { } return std::move(results); } + } + SmallVector results; if (auto iface = dyn_cast(op)) { for (auto &operand : op->getOpOperands()) if (operand.get() == parent) @@ -1002,7 +1004,11 @@ static SmallVector getPotentialIncomingValues(OpResult res) { auto resultNo = res.getResultNumber(); - if (auto iface = dyn_cast(owner)) { + if (auto iface = dyn_cast(owner)) { + for (auto val : iface.getPotentialIncomingValuesRes(res)) + potentialSources.push_back(val); + return potentialSources; + } else if (auto iface = dyn_cast(owner)) { SmallVector successors; iface.getSuccessorRegions(RegionBranchPoint::parent(), successors); for (auto &succ : successors) { @@ -1019,10 +1025,6 @@ static SmallVector getPotentialIncomingValues(OpResult res) { potentialSources.push_back(successorOperands[resultNo]); } - } else if (auto iface = dyn_cast(owner)) { - for (auto val : iface.getPotentialIncomingValuesRes(res)) - potentialSources.push_back(val); - return potentialSources; } else { // assume all inputs potentially flow into all op results for (auto operand : owner->getOperands()) { diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td index 82a4e4e168c..8804580fbd8 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td @@ -151,6 +151,13 @@ def ADDataFlowOpInterface /*retTy=*/"SmallVector", /*methodName=*/"getPotentialIncomingValuesArg", /*args=*/(ins "::mlir::BlockArgument":$v) + >, + InterfaceMethod< + /*desc=*/[{ + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getPotentialTerminatorUsers", + /*args=*/(ins "::mlir::Operation*":$terminator, "::mlir::Value":$val) > ]; }