Skip to content

Commit

Permalink
[mlir] correctly handle "unknown" state in activity analysis
Browse files Browse the repository at this point in the history
This requires injecting ModRef information about library functions.
  • Loading branch information
ftynse committed Jan 5, 2024
1 parent f928881 commit 34fbff8
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 23 deletions.
54 changes: 48 additions & 6 deletions enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,51 @@ static bool mayWriteArg(FunctionOpInterface callee, unsigned argNo,
return !hasReadOnlyAttr && !hasReadNoneAttr && funcMayWrite;
}

/// Describes how `func` may access the memory reachable through its pointer
/// arguments (read, write, both, or neither).
///
/// A small table of well-known library symbols is consulted first; otherwise
/// the answer is taken from the LLVM memory-effects attribute attached to the
/// function, if any. Yields `std::nullopt` when neither source is available.
static std::optional<LLVM::ModRefInfo>
getFunctionArgModRef(FunctionOpInterface func) {
  StringRef symbolName =
      cast<SymbolOpInterface>(func.getOperation()).getName();

  // Library functions whose behavior is known statically.
  // printf: only reads from arguments.
  if (symbolName == "printf")
    return LLVM::ModRefInfo::Ref;
  // operator delete(void *) doesn't read from arguments.
  if (symbolName == "_ZdlPv")
    return LLVM::ModRefInfo::NoModRef;

  // Fall back to the explicit memory-effects annotation.
  auto effects =
      func->getAttrOfType<LLVM::MemoryEffectsAttr>(kLLVMMemoryAttrName);
  if (!effects)
    return std::nullopt;
  return effects.getArgMem();
}

/// Describes how `func` may access memory that is *not* pointed to by its
/// arguments but is still accessible from (any) calling context.
///
/// A small table of well-known library symbols is consulted first; otherwise
/// the answer is taken from the LLVM memory-effects attribute attached to the
/// function, if any. Yields `std::nullopt` when neither source is available.
static std::optional<LLVM::ModRefInfo>
getFunctionOtherModRef(FunctionOpInterface func) {
  StringRef symbolName =
      cast<SymbolOpInterface>(func.getOperation()).getName();

  // Library functions whose behavior is known statically.
  // printf: doesn't access other memory. (Technically, stdout is
  // pointer-like, but we cannot flow information through it since it is
  // write-only.)
  if (symbolName == "printf")
    return LLVM::ModRefInfo::NoModRef;
  // operator delete(void *) doesn't access other memory.
  if (symbolName == "_ZdlPv")
    return LLVM::ModRefInfo::NoModRef;

  // Fall back to the explicit memory-effects annotation.
  auto effects =
      func->getAttrOfType<LLVM::MemoryEffectsAttr>(kLLVMMemoryAttrName);
  if (!effects)
    return std::nullopt;
  return effects.getOther();
}

void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
CallOpInterface call, CallControlFlowAction action,
const PointsToSets &before, PointsToSets *after) {
Expand Down Expand Up @@ -497,13 +542,10 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
// into pointers that are non-arguments.
if (auto callee = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
call, symbol.getLeafReference())) {
auto memoryAttr =
callee->getAttrOfType<LLVM::MemoryEffectsAttr>(kLLVMMemoryAttrName);
std::optional<LLVM::ModRefInfo> argModRef =
memoryAttr ? std::make_optional(memoryAttr.getArgMem())
: std::nullopt;
std::optional<LLVM::ModRefInfo> argModRef = getFunctionArgModRef(callee);
std::optional<LLVM::ModRefInfo> otherModRef =
memoryAttr ? std::make_optional(memoryAttr.getOther()) : std::nullopt;
getFunctionOtherModRef(callee);

SmallVector<int> pointerLikeOperands;
for (auto &&[i, operand] : llvm::enumerate(call.getArgOperands())) {
if (isPointerLike(operand.getType()))
Expand Down
62 changes: 45 additions & 17 deletions enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,20 @@ void traverseCallGraph(FunctionOpInterface root,
}
}

/// Returns the "default" points-to alias class set of `pointsToSets`, i.e. the
/// set looked up with a null key. This is the set that the "unknown" class,
/// and any other class without an explicit entry, is taken to point to.
static const enzyme::AliasClassSet &
getDefaultPointsTo(const enzyme::PointsToSets &pointsToSets) {
  const enzyme::AliasClassSet &fallback = pointsToSets.getPointsTo(nullptr);

  // Unless refined further, the unknown class can only point to "unknown" or
  // to nothing at all; any other state means a new AliasClassSet case was
  // added and callers of this helper need to be revisited.
  assert((fallback.isUnknown() || fallback.getAliasClasses().empty()) &&
         "new case introduced for AliasClassSet?");
  return fallback;
}

void printActivityAnalysisResults(const DataFlowSolver &solver,
FunctionOpInterface callee,
const SmallPtrSet<Operation *, 2> &returnOps,
Expand All @@ -815,19 +829,33 @@ void printActivityAnalysisResults(const DataFlowSolver &solver,
auto *bma = solver.lookupState<BackwardMemoryActivity>(
&callee.getFunctionBody().front().front());

auto *pointsToSets =
const enzyme::PointsToSets *pointsToSets =
solver.lookupState<enzyme::PointsToSets>(*returnOps.begin());
auto *aliasClassLattice = solver.lookupState<AliasClassLattice>(value);
// Traverse the points-to sets in a simple BFS
std::deque<DistinctAttr> frontier;
DenseSet<DistinctAttr> visited;
// TODO(zinenko): FIXME, handle unknown...
if (!aliasClassLattice->isUnknown()) {
auto scheduleVisit = [&](auto range) {
for (DistinctAttr neighbor : range) {
if (!visited.contains(neighbor)) {
visited.insert(neighbor);
frontier.push_back(neighbor);
}
}
};

if (aliasClassLattice->isUnknown()) {
// If this pointer is in unknown alias class, it may point to active
// data if the unknown alias class is known to point to something and
// may not point to active data if the unknown alias class is known not
// to point to anything.
auto &defaultPointsTo = getDefaultPointsTo(*pointsToSets);
return !defaultPointsTo.isUnknown() &&
defaultPointsTo.getAliasClasses().empty();
} else {
const DenseSet<DistinctAttr> &aliasClasses =
aliasClassLattice->getAliasClasses();
frontier.insert(frontier.end(), aliasClasses.begin(),
aliasClasses.end());
visited.insert(aliasClasses.begin(), aliasClasses.end());
scheduleVisit(aliasClasses);
}
while (!frontier.empty()) {
DistinctAttr aliasClass = frontier.front();
Expand All @@ -841,19 +869,19 @@ void printActivityAnalysisResults(const DataFlowSolver &solver,

// Or if it points to a pointer that points to active data.
if (pointsToSets->getPointsTo(aliasClass).isUnknown()) {
// TODO(zinenko): FIXME handle unknown. Conservative assumption here
// is to assume the value is active (or unknown if we can return
// that). Is there a less conservative option?
// If a pointer points to an unknown alias set, query the default
// points-to alias set (which also applies to the unknown alias set).
auto &defaultPointsTo = getDefaultPointsTo(*pointsToSets);
// If it is in turn unknown, conservatively assume the pointer may be
// pointing to some active data.
if (defaultPointsTo.isUnknown())
return false;
// Otherwise look at classes pointed to by unknown (which can only be
// an empty set as of time of writing).
scheduleVisit(defaultPointsTo.getAliasClasses());
continue;
}
const DenseSet<DistinctAttr> &neighbors =
pointsToSets->getPointsTo(aliasClass).getAliasClasses();
for (DistinctAttr neighbor : neighbors) {
if (!visited.contains(neighbor)) {
visited.insert(neighbor);
frontier.push_back(neighbor);
}
}
scheduleVisit(pointsToSets->getPointsTo(aliasClass).getAliasClasses());
}
// Otherwise, it's constant
return true;
Expand Down

0 comments on commit 34fbff8

Please sign in to comment.