Skip to content

Commit

Permalink
Add summary processing to backward dense annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
pengmai committed Feb 26, 2024
1 parent 8d6280d commit 99f5c3b
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 21 deletions.
105 changes: 88 additions & 17 deletions enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
using namespace mlir;

static StringRef getActivityAnnotationAttrName() { return "activedeps"; }
static StringRef getPointerSummaryAttrName() { return "p2psummary"; }

template <typename ValueT>
void printSetLattice(const enzyme::SparseSetLattice<ValueT> &setLattice,
Expand Down Expand Up @@ -514,7 +515,19 @@ void enzyme::DenseBackwardActivityAnnotationAnalysis::
BackwardOriginsMap *before) {
meet(before, after);
if (action == dataflow::CallControlFlowAction::ExternalCallee) {
// TODO: deserialize state, infer transfer
auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
if (!symbol)
return propagateIfChanged(before, before->markAllOriginsUnknown());

if (auto callee = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
call, symbol.getLeafReference())) {
if (auto summaryAttr = callee->getAttrOfType<ArrayAttr>(
getActivityAnnotationAttrName())) {
DenseMap<DistinctAttr, ValueOriginSet> summary;
deserializePointsTo(summaryAttr, summary);
return processCallToSummarizedFunc(call, summary, after, before);
}
}
}
}

Expand Down Expand Up @@ -593,8 +606,70 @@ void enzyme::DenseBackwardActivityAnnotationAnalysis::visitOperation(
// Capturing stores are handled via the points-to relationship in
// setToExitState.
}
} else if (std::optional<Value> copySource = getCopySource(op)) {
processCopy(op, *copySource, value, after, before);
}
}
}

void enzyme::DenseBackwardActivityAnnotationAnalysis::
processCallToSummarizedFunc(
CallOpInterface call,
const DenseMap<DistinctAttr, ValueOriginSet> &summary,
const BackwardOriginsMap &after, BackwardOriginsMap *before) {
ChangeResult changed = ChangeResult::NoChange;
// Unify the value origin summary with the actual lattices of function
// arguments
auto *p2sets = getOrCreateFor<PointsToSets>(call, call);
SmallVector<AliasClassSet> argumentClasses;
for (Value argOperand : call.getArgOperands()) {
auto *argClasses = getOrCreateFor<AliasClassLattice>(call, argOperand);
argumentClasses.push_back(argClasses->getAliasClassesObject());
}

for (const auto &[destClass, sourceOrigins] : summary) {
// Get the source alias classes
AliasClassSet callerSourceClasses;
for (Attribute sourceOrigin : sourceOrigins.getElements()) {
unsigned argNumber =
cast<ArgumentOriginAttr>(sourceOrigin).getArgNumber();
traversePointsToSets(argumentClasses[argNumber], *p2sets,
[&](DistinctAttr aliasClass) {
(void)callerSourceClasses.insert({aliasClass});
});
}

// Get the destination origins
ValueOriginSet destOrigins;
if (auto pseudoClass = dyn_cast_if_present<PseudoAliasClassAttr>(
destClass.getReferencedAttr())) {
traversePointsToSets(argumentClasses[pseudoClass.getArgNumber()], *p2sets,
[&](DistinctAttr aliasClass) {
(void)destOrigins.join(
after.getOrigins(aliasClass));
});
}
changed |= before->insert(callerSourceClasses, destOrigins);
}
propagateIfChanged(before, changed);
}

void enzyme::DenseBackwardActivityAnnotationAnalysis::processCopy(
Operation *op, Value copySource, Value copyDest,
const BackwardOriginsMap &after, BackwardOriginsMap *before) {
auto *dest = getOrCreateFor<AliasClassLattice>(op, copyDest);
ValueOriginSet destOrigins;
if (dest->isUndefined())
return;
if (dest->isUnknown())
(void)destOrigins.markUnknown();

for (DistinctAttr destClass : dest->getAliasClasses())
(void)destOrigins.join(after.getOrigins(destClass));

auto *src = getOrCreateFor<AliasClassLattice>(op, copySource);
propagateIfChanged(before,
before->insert(src->getAliasClassesObject(), destOrigins));
}

namespace {
Expand Down Expand Up @@ -654,7 +729,7 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) {
raw_ostream &os = llvm::outs();

for (CallableOpInterface node : sorted) {
if (!node.getCallableRegion() || node->hasAttr("p2psummary"))
if (!node.getCallableRegion() || node->hasAttr(getPointerSummaryAttrName()))
continue;
auto funcOp = cast<FunctionOpInterface>(node.getOperation());
os << "[ata] processing function @" << funcOp.getName() << "\n";
Expand Down Expand Up @@ -717,13 +792,16 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) {
os << "[debug] return alias class: " << lattice << "\n";
}

node->setAttr("p2psummary", p2sets.serialize(node.getContext()));
node->setAttr(getPointerSummaryAttrName(),
p2sets.serialize(node.getContext()));
os << "[ata] p2p summary:\n";
if (node->getAttrOfType<ArrayAttr>("p2psummary").size() == 0) {
if (node->getAttrOfType<ArrayAttr>(getPointerSummaryAttrName()).size() ==
0) {
os << " <empty>\n";
}
for (ArrayAttr pair :
node->getAttrOfType<ArrayAttr>("p2psummary").getAsRange<ArrayAttr>()) {
node->getAttrOfType<ArrayAttr>(getPointerSummaryAttrName())
.getAsRange<ArrayAttr>()) {
os << " " << pair[0] << " -> " << pair[1] << "\n";
}

Expand All @@ -750,18 +828,11 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) {
MLIRContext *ctx = node.getContext();
SmallVector<Attribute> serializedReturnOperandOrigins(
returnOperandOrigins.size());
llvm::transform(
returnOperandOrigins, serializedReturnOperandOrigins.begin(),
[ctx](enzyme::ForwardOriginsLattice lattice) -> Attribute {
if (lattice.isUndefined())
return StringAttr::get(ctx, "<undefined>");
if (lattice.isUnknown())
return StringAttr::get(ctx, "<unknown>");
SmallVector<Attribute> originsVector(lattice.getOrigins().begin(),
lattice.getOrigins().end());
llvm::sort(originsVector, enzyme::sortAttributes);
return ArrayAttr::get(ctx, originsVector);
});
llvm::transform(returnOperandOrigins,
serializedReturnOperandOrigins.begin(),
[ctx](enzyme::ForwardOriginsLattice lattice) -> Attribute {
return lattice.serialize(ctx);
});
node->setAttr(
"returnorigins",
ArrayAttr::get(node.getContext(), serializedReturnOperandOrigins));
Expand Down
9 changes: 9 additions & 0 deletions enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,15 @@ class DenseBackwardActivityAnnotationAnalysis
BackwardOriginsMap *before) override;

void setToExitState(BackwardOriginsMap *lattice) override;

private:
void processCallToSummarizedFunc(
CallOpInterface call,
const DenseMap<DistinctAttr, ValueOriginSet> &summary,
const BackwardOriginsMap &after, BackwardOriginsMap *before);

void processCopy(Operation *op, Value copySource, Value copyDest,
const BackwardOriginsMap &after, BackwardOriginsMap *before);
};

void runActivityAnnotations(FunctionOpInterface callee);
Expand Down
1 change: 0 additions & 1 deletion enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,6 @@ void enzyme::AliasAnalysis::setToEntryState(AliasClassLattice *lattice) {
}
DistinctAttr argClass =
originalClasses.getOriginalClass(lattice->getPoint(), debugLabel);
funcOp.setArgAttr(arg.getArgNumber(), "enzyme.origin", argClass);
return propagateIfChanged(lattice,
lattice->join(AliasClassLattice::single(
lattice->getPoint(), argClass)));
Expand Down
22 changes: 19 additions & 3 deletions enzyme/Enzyme/MLIR/Analysis/Lattice.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ template <typename ValueT>
const SetLattice<ValueT> SetLattice<ValueT>::undefinedSet =
SetLattice<ValueT>(SetLattice<ValueT>::State::Undefined);

/// Used when serializing to ensure a consistent order.
bool sortAttributes(Attribute a, Attribute b);

//===----------------------------------------------------------------------===//
// SparseSetLattice
//
Expand All @@ -191,6 +194,8 @@ class SparseSetLattice : public dataflow::AbstractSparseLattice {
SparseSetLattice(Value value, SetLattice<ValueT> &&elements)
: dataflow::AbstractSparseLattice(value), elements(std::move(elements)) {}

Attribute serialize(MLIRContext *ctx) { return serializeSetNaive(ctx); }

ChangeResult merge(const SetLattice<ValueT> &other) {
return elements.join(other);
}
Expand All @@ -209,15 +214,26 @@ class SparseSetLattice : public dataflow::AbstractSparseLattice {

protected:
SetLattice<ValueT> elements;

private:
Attribute serializeSetNaive(MLIRContext *ctx) {
if (elements.isUndefined())
return StringAttr::get(ctx, "<undefined>");
if (elements.isUnknown())
return StringAttr::get(ctx, "<unknown>");
SmallVector<Attribute> elementsVec;
for (Attribute element : elements.getElements()) {
elementsVec.push_back(element);
}
llvm::sort(elementsVec, sortAttributes);
return ArrayAttr::get(ctx, elementsVec);
}
};

//===----------------------------------------------------------------------===//
// MapOfSetsLattice
//===----------------------------------------------------------------------===//

/// Used when serializing to ensure a consistent order.
bool sortAttributes(Attribute a, Attribute b);

template <typename KeyT, typename ElementT>
class MapOfSetsLattice : public dataflow::AbstractDenseLattice {
public:
Expand Down
20 changes: 20 additions & 0 deletions enzyme/test/MLIR/ActivityAnalysis/Summaries/bude.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@
// CHECK-NEXT: distinct[0]<"fresh-transform"> -> []
// CHECK: forward value origins:
// CHECK: distinct[0]<#enzyme.pseudoclass<@fasten_main(10, 0)>> originates from [#enzyme.argorigin<@fasten_main(2)>, #enzyme.argorigin<@fasten_main(3)>, #enzyme.argorigin<@fasten_main(4)>, #enzyme.argorigin<@fasten_main(5)>, #enzyme.argorigin<@fasten_main(6)>, #enzyme.argorigin<@fasten_main(7)>, #enzyme.argorigin<@fasten_main(8)>, #enzyme.argorigin<@fasten_main(9)>, #enzyme.argorigin<@fasten_main(10)>, #enzyme.argorigin<@fasten_main(11)>]
// CHECK: backward value origins:
// CHECK: distinct[0]<#enzyme.pseudoclass<@fasten_main(2, 0)>> goes to [#enzyme.argorigin<@fasten_main(2)>, #enzyme.argorigin<@fasten_main(10)>]
// CHECK: distinct[0]<#enzyme.pseudoclass<@fasten_main(3, 0)>> goes to [#enzyme.argorigin<@fasten_main(3)>, #enzyme.argorigin<@fasten_main(10)>]
// CHECK: distinct[0]<#enzyme.pseudoclass<@fasten_main(4, 0)>> goes to [#enzyme.argorigin<@fasten_main(4)>, #enzyme.argorigin<@fasten_main(10)>]
// CHECK: distinct[0]<#enzyme.pseudoclass<@fasten_main(5, 0)>> goes to [#enzyme.argorigin<@fasten_main(5)>, #enzyme.argorigin<@fasten_main(10)>]
// CHECK: distinct[0]<#enzyme.pseudoclass<@fasten_main(6, 0)>> goes to [#enzyme.argorigin<@fasten_main(6)>, #enzyme.argorigin<@fasten_main(10)>]
// CHECK: distinct[0]<#enzyme.pseudoclass<@fasten_main(7, 0)>> goes to [#enzyme.argorigin<@fasten_main(7)>, #enzyme.argorigin<@fasten_main(10)>]
// CHECK: distinct[0]<#enzyme.pseudoclass<@fasten_main(8, 0)>> goes to [#enzyme.argorigin<@fasten_main(8)>, #enzyme.argorigin<@fasten_main(10)>]
// CHECK: distinct[0]<#enzyme.pseudoclass<@fasten_main(9, 0)>> goes to [#enzyme.argorigin<@fasten_main(9)>, #enzyme.argorigin<@fasten_main(10)>]
// CHECK: distinct[0]<#enzyme.pseudoclass<@fasten_main(11, 0)>> goes to [#enzyme.argorigin<@fasten_main(10)>, #enzyme.argorigin<@fasten_main(11)>]
llvm.func local_unnamed_addr @fasten_main(%arg0: i32 {llvm.noundef}, %arg1: i32 {llvm.noundef}, %arg2: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg3: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg4: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg5: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg6: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg7: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg8: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg9: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg10: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.writeonly}, %arg11: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg12: i32 {llvm.noundef}) attributes {frame_pointer = #llvm.framePointerKind<"non-leaf">, memory = #llvm.memory_effects<other = none, argMem = readwrite, inaccessibleMem = readwrite>, passthrough = ["nofree", "norecurse", "nosync", "nounwind", "ssp", ["uwtable", "1"], ["approx-func-fp-math", "true"], ["no-infs-fp-math", "true"], ["no-nans-fp-math", "true"], ["no-signed-zeros-fp-math", "true"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "apple-m1"], ["unsafe-fp-math", "true"]], target_cpu = "apple-m1", target_features = #llvm.target_features<["+aes", "+complxnum", "+crc", "+dotprod", "+fp-armv8", "+fp16fml", "+fullfp16", "+jsconv", "+lse", "+neon", "+ras", "+rcpc", "+rdm", "+sha2", "+sha3", "+v8.1a", "+v8.2a", "+v8.3a", "+v8.4a", "+v8.5a", "+v8a", "+zcm", "+zcz"]>} {
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.mlir.constant(8 : i32) : i32
Expand Down Expand Up @@ -388,6 +398,16 @@ llvm.mlir.global internal unnamed_addr @params() {addr_space = 0 : i32, alignmen
// CHECK-NEXT: <empty>
// CHECK: forward value origins:
// CHECK: distinct[0]<#enzyme.pseudoclass<@onecompute(8, 0)>> originates from [#enzyme.argorigin<@onecompute(0)>, #enzyme.argorigin<@onecompute(1)>, #enzyme.argorigin<@onecompute(2)>, #enzyme.argorigin<@onecompute(3)>, #enzyme.argorigin<@onecompute(4)>, #enzyme.argorigin<@onecompute(5)>, #enzyme.argorigin<@onecompute(6)>, #enzyme.argorigin<@onecompute(7)>, #enzyme.argorigin<@onecompute(8)>, #enzyme.argorigin<@onecompute(9)>]
// CHECK: backward value origins:
// CHECK: distinct[0]<#enzyme.pseudoclass<@onecompute(0, 0)>> goes to [#enzyme.argorigin<@onecompute(0)>, #enzyme.argorigin<@onecompute(8)>]
// CHECK: distinct[0]<#enzyme.pseudoclass<@onecompute(1, 0)>> goes to [#enzyme.argorigin<@onecompute(1)>, #enzyme.argorigin<@onecompute(8)>]
// CHECK: distinct[0]<#enzyme.pseudoclass<@onecompute(2, 0)>> goes to [#enzyme.argorigin<@onecompute(2)>, #enzyme.argorigin<@onecompute(8)>]
// CHECK: distinct[0]<#enzyme.pseudoclass<@onecompute(3, 0)>> goes to [#enzyme.argorigin<@onecompute(3)>, #enzyme.argorigin<@onecompute(8)>]
// CHECK: distinct[0]<#enzyme.pseudoclass<@onecompute(4, 0)>> goes to [#enzyme.argorigin<@onecompute(4)>, #enzyme.argorigin<@onecompute(8)>]
// CHECK: distinct[0]<#enzyme.pseudoclass<@onecompute(5, 0)>> goes to [#enzyme.argorigin<@onecompute(5)>, #enzyme.argorigin<@onecompute(8)>]
// CHECK: distinct[0]<#enzyme.pseudoclass<@onecompute(6, 0)>> goes to [#enzyme.argorigin<@onecompute(6)>, #enzyme.argorigin<@onecompute(8)>]
// CHECK: distinct[0]<#enzyme.pseudoclass<@onecompute(7, 0)>> goes to [#enzyme.argorigin<@onecompute(7)>, #enzyme.argorigin<@onecompute(8)>]
// CHECK: distinct[0]<#enzyme.pseudoclass<@onecompute(9, 0)>> goes to [#enzyme.argorigin<@onecompute(8)>, #enzyme.argorigin<@onecompute(9)>]
llvm.func @onecompute(%arg0: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg1: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg2: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg3: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg4: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg5: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg6: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg7: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}, %arg8: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.writeonly}, %arg9: !llvm.ptr {llvm.noalias, llvm.nocapture, llvm.noundef, llvm.readonly}) attributes {frame_pointer = #llvm.framePointerKind<"non-leaf">, memory = #llvm.memory_effects<other = read, argMem = readwrite, inaccessibleMem = readwrite>, passthrough = ["nofree", "norecurse", "nosync", "nounwind", "ssp", ["uwtable", "1"], ["approx-func-fp-math", "true"], ["no-infs-fp-math", "true"], ["no-nans-fp-math", "true"], ["no-signed-zeros-fp-math", "true"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "apple-m1"], ["unsafe-fp-math", "true"]], target_cpu = "apple-m1", target_features = #llvm.target_features<["+aes", "+complxnum", "+crc", "+dotprod", "+fp-armv8", "+fp16fml", "+fullfp16", "+jsconv", "+lse", "+neon", "+ras", "+rcpc", "+rdm", "+sha2", "+sha3", "+v8.1a", "+v8.2a", "+v8.3a", "+v8.4a", "+v8.5a", "+v8a", "+zcm", "+zcz"]>} {
%0 = llvm.mlir.constant(3 : i32) : i32
%1 = llvm.mlir.constant(0 : i64) : i64
Expand Down
Loading

0 comments on commit 99f5c3b

Please sign in to comment.