Skip to content

Commit

Permalink
[mlir] better handling of function result attributes
Browse files Browse the repository at this point in the history
Specifically, handle the siutation where some alias classes are
associateded with pointer operands that cannot be written into. In such
a case, even if the results of the function may alias the operands, the
alias classes of the operands are known not to point to anything that
the function could have written.

Additionally, don't include results marked as `noalias` into the common
alias class of function results.
  • Loading branch information
ftynse committed Jan 5, 2024
1 parent cfe8cf1 commit d9b0cde
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 38 deletions.
152 changes: 132 additions & 20 deletions enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

// TODO: remove this once aliasing interface is factored out.
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "llvm/ADT/SetOperations.h"

using namespace mlir;
using namespace mlir::dataflow;
Expand Down Expand Up @@ -498,6 +499,17 @@ getFunctionOtherModRef(FunctionOpInterface func) {
return std::nullopt;
}

/// Returns information indicating whether the function may read or write into
/// memory previously inaccessible in the calling context. When unknown, returns
/// `nullopt`.
static std::optional<LLVM::ModRefInfo>
getFunctionInaccessibleModRef(FunctionOpInterface func) {
if (auto memoryAttr =
func->getAttrOfType<LLVM::MemoryEffectsAttr>(kLLVMMemoryAttrName))
return memoryAttr.getInaccessibleMem();
return std::nullopt;
}

void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
CallOpInterface call, CallControlFlowAction action,
const PointsToSets &before, PointsToSets *after) {
Expand Down Expand Up @@ -577,22 +589,25 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(

// For each alias class the function may write to, indicate potentially
// stored classes. Keep the set of writable alias classes for future.
AliasClassSet pointerOperandClasses = AliasClassSet::getUndefined();
AliasClassSet writableClasses = AliasClassSet::getUndefined();
AliasClassSet nonWritableOperandClasses = AliasClassSet::getUndefined();
ChangeResult changed = ChangeResult::NoChange;
for (int pointerOperand : pointerLikeOperands) {
auto *destClasses = getOrCreateFor<AliasClassLattice>(
call, call.getArgOperands()[pointerOperand]);
pointerOperandClasses.join(destClasses->getAliasClassesObject());

// If the argument cannot be stored into, just preserve it as is.
if (!mayWriteArg(callee, pointerOperand, argModRef))
if (!mayWriteArg(callee, pointerOperand, argModRef)) {
nonWritableOperandClasses.join(destClasses->getAliasClassesObject());
continue;
}
writableClasses.join(destClasses->getAliasClassesObject());

// If the destination class is unknown, mark all known classes
// pessimistic (alias classes that have not beed analyzed and thus are
// absent from pointsTo are treated as "undefined" at this point).
if (destClasses->isUnknown()) {
pointerOperandClasses.markUnknown();
writableClasses.markUnknown();
changed |= after->markAllPointToUnknown();
break;
}
Expand All @@ -616,11 +631,11 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
// lattice after joining it with `before` are marked as pointing to
// "unknown", except the classes that are associated with operands for
// which we have more specific information. Classes that haven't been
// analyzed, and therefore absent in the `after` lattice, are left
// analyzed, and therefore absent from the `after` lattice, are left
// unmodified and thus assumed to be "undefined". This makes this
// transfer function monotonic as opposed to marking the latter classes
// as "unknown" eagerly, which would require rolling that marking back.
changed |= after->markAllExceptPointToUnknown(pointerOperandClasses);
changed |= after->markAllExceptPointToUnknown(writableClasses);
}

// Pointer-typed results may be pointing to any other pointer. The
Expand All @@ -641,21 +656,49 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
if (!isPointerLike(result.getType()))
continue;

// Result alias classes may contain operand alias classes because
// results may alias with those operands. However, if the operands are
// not writable, they cannot be updated to point to other classes
// even though results can be. To handle this, only update the alias
// classes associated with the results that are not also associated
// with non-writable operands.
//
// This logic is a bit more conservative than the theoretical optimum to
// ensure monotonicity of the transfer function: if additional alias
// classes are discovered for non-writable operands at a later stage
// after these classes have already been associated with the result and
// marked as potentially pointing to some other classes, this marking
// is *not* rolled back. Since points-to-pointer analysis is a may-
// analysis, this is not problematic.
const auto *destClasses =
getOrCreateFor<AliasClassLattice>(call, result);
AliasClassSet resultWithoutNonWritableOperands =
AliasClassSet::getUndefined();
if (destClasses->isUnknown() || nonWritableOperandClasses.isUnknown()) {
resultWithoutNonWritableOperands.markUnknown();
} else if (!destClasses->isUndefined() &&
!nonWritableOperandClasses.isUndefined()) {
DenseSet<DistinctAttr> nonOperandClasses =
llvm::set_difference(destClasses->getAliasClasses(),
nonWritableOperandClasses.getAliasClasses());
resultWithoutNonWritableOperands.insert(nonOperandClasses);
} else {
resultWithoutNonWritableOperands.join(
destClasses->getAliasClassesObject());
}

// If reading from other memory, the results may point to anything.
if (funcMayReadOther) {
propagateIfChanged(after, after->markPointToUnknown(
destClasses->getAliasClassesObject()));
resultWithoutNonWritableOperands));
continue;
}

for (int operandNo : pointerLikeOperands) {
const auto *srcClasses = getOrCreateFor<AliasClassLattice>(
call, call.getArgOperands()[operandNo]);
if (mayReadArg(callee, operandNo, argModRef)) {
changed |= after->addSetsFrom(destClasses->getAliasClassesObject(),
changed |= after->addSetsFrom(resultWithoutNonWritableOperands,
srcClasses->getAliasClassesObject());
}

Expand All @@ -665,7 +708,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
LLVM::LLVMDialect::getNoCaptureAttrName()));
if (isNoCapture)
continue;
changed |= after->insert(destClasses->getAliasClassesObject(),
changed |= after->insert(resultWithoutNonWritableOperands,
srcClasses->getAliasClassesObject());
}
}
Expand Down Expand Up @@ -711,6 +754,9 @@ void enzyme::AliasClassLattice::print(raw_ostream &os) const {
AliasResult
enzyme::AliasClassLattice::alias(const AbstractSparseLattice &other) const {
const auto *rhs = reinterpret_cast<const AliasClassLattice *>(&other);

assert(!isUndefined() && !rhs->isUndefined() && "incomplete alias analysis");

if (getPoint() == rhs->getPoint())
return AliasResult::MustAlias;

Expand Down Expand Up @@ -960,6 +1006,8 @@ void enzyme::AliasAnalysis::visitExternalCall(
// Even if a function is marked as not reading from memory or arguments, it
// may still create pointers "out of the thin air", e.g., by "ptrtoint" from a
// constant or an argument.
// TODO: consider "ptrtoint" here, for now assuming it is covered by
// inaccessible and other mem.
auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
if (!symbol)
return markResultsUnknown();
Expand All @@ -968,20 +1016,84 @@ void enzyme::AliasAnalysis::visitExternalCall(
if (!callee)
return markResultsUnknown();

// Collect alias classes that can be read through the arguments.
std::optional<LLVM::ModRefInfo> argModRef = getFunctionArgModRef(callee);
std::optional<LLVM::ModRefInfo> otherModRef = getFunctionOtherModRef(callee);
std::optional<LLVM::ModRefInfo> inaccessibleModRef =
getFunctionInaccessibleModRef(callee);
auto operandAliasClasses = AliasClassSet::getEmpty();
for (auto [operandNo, operand] : llvm::enumerate(call.getArgOperands())) {
if (!isPointerLike(operand.getType()))
continue;

const AliasClassLattice *srcClasses = operands[operandNo];
operandAliasClasses.join(srcClasses->getAliasClassesObject());

if (!mayReadArg(callee, operandNo, argModRef))
continue;

// If can read from argument, collect the alias classes that can this
// argument may be pointing to.
const auto *pointsToLattice = getOrCreateFor<PointsToSets>(call, call);
srcClasses->getAliasClassesObject().foreachClass(
[&](DistinctAttr srcClass, AliasClassSet::State state) {
// Nothing to do in top/bottom case. In the top case, we have already
// set `operandAliasClasses` to top above.
if (srcClass == nullptr)
return ChangeResult::NoChange;
operandAliasClasses.join(pointsToLattice->getPointsTo(srcClass));
return ChangeResult::NoChange;
});
}

auto debugLabel = call->getAttrOfType<StringAttr>("tag");
DistinctAttr commonResultAttr = nullptr;

// Collect all results that are not marked noalias so we can put them in a
// common alias group.
SmallVector<Value> aliasGroupResults;
for (OpResult result : call->getResults()) {
if (!callee.getResultAttr(result.getResultNumber(),
LLVM::LLVMDialect::getNoAliasAttrName()))
aliasGroupResults.push_back(result);
}

for (OpResult result : call->getResults()) {
AliasClassLattice *resultLattice = results[result.getResultNumber()];
if (callee.getResultAttr(result.getResultNumber(),
LLVM::LLVMDialect::getNoAliasAttrName())) {
Attribute debugLabel = call->getAttrOfType<StringAttr>("tag");
auto individualAlloc =
AliasClassLattice::single(resultLattice->getPoint(),
originalClasses.getOriginalClass(
resultLattice->getPoint(), debugLabel));
if (!llvm::is_contained(aliasGroupResults, result)) {
Attribute individualDebugLabel =
debugLabel
? StringAttr::get(debugLabel.getContext(),
debugLabel.getValue().str() +
std::to_string(result.getResultNumber()))
: nullptr;
auto individualAlloc = AliasClassLattice::single(
resultLattice->getPoint(),
originalClasses.getOriginalClass(resultLattice->getPoint(),
individualDebugLabel));
propagateIfChanged(resultLattice, resultLattice->join(individualAlloc));
// TODO(zinenko): if the function is known not to read other (or
// inaccessible mem), its results may only alias what we know it can read,
// e.g. other arguments (unless they are marked noalias) or anything
// stored in those arguments.
} else if (!modRefMayRef(otherModRef) &&
!modRefMayRef(inaccessibleModRef)) {
// Put results that are not marked as noalias into one common group.
if (!commonResultAttr) {
std::string label = !debugLabel
? "func-result-common"
: debugLabel.getValue().str() + "-common";
commonResultAttr =
originalClasses.getSameOriginalClass(aliasGroupResults, label);
}
AliasClassSet commonClass(commonResultAttr);
ChangeResult changed = resultLattice->join(
AliasClassLattice(resultLattice->getPoint(), std::move(commonClass)));

// If the function is known not to read other (or inaccessible mem), its
// results may only alias what we know it can read, e.g. other arguments
// or anything stored in those arguments.
// FIXME: note the explicit copy, we need to simplify the relation between
// AliasClassSet and AliasClassLattice.
changed |= resultLattice->join(AliasClassLattice(
resultLattice->getPoint(), AliasClassSet(operandAliasClasses)));
propagateIfChanged(resultLattice, changed);
} else {
propagateIfChanged(resultLattice, resultLattice->markUnknown());
}
Expand Down
9 changes: 3 additions & 6 deletions enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,9 @@ class PointsToPointerAnalysis
class AliasClassLattice : public dataflow::AbstractSparseLattice {
public:
using AbstractSparseLattice::AbstractSparseLattice;
AliasClassLattice(Value value, AliasClassSet &&classes)
: dataflow::AbstractSparseLattice(value),
aliasClasses(std::move(classes)) {}

void print(raw_ostream &os) const override;

Expand Down Expand Up @@ -332,12 +335,6 @@ class AliasClassLattice : public dataflow::AbstractSparseLattice {
const AliasClassSet &getAliasClassesObject() const { return aliasClasses; }

private:
struct UndefinedState {};

AliasClassLattice(Value value, AliasClassSet &&classes)
: dataflow::AbstractSparseLattice(value),
aliasClasses(std::move(classes)) {}

AliasClassSet aliasClasses;
};

Expand Down
3 changes: 3 additions & 0 deletions enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ struct PrintAliasAnalysisPass
if (state->isUnknown()) {
op->setAttr("ac",
StringAttr::get(result.getContext(), "<unknown>"));
} else if (state->isUndefined()) {
op->setAttr("ac",
StringAttr::get(result.getContext(), "<undefined>"));
} else {
for (auto aliasClass : state->getAliasClasses()) {
op->setAttr("ac", aliasClass);
Expand Down
Loading

0 comments on commit d9b0cde

Please sign in to comment.