Skip to content

Commit

Permalink
Add summary processing for sparse backwards annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
pengmai committed Feb 26, 2024
1 parent 99f5c3b commit 01a458e
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 19 deletions.
183 changes: 165 additions & 18 deletions enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using namespace mlir;

static StringRef getActivityAnnotationAttrName() { return "activedeps"; }
static StringRef getPointerSummaryAttrName() { return "p2psummary"; }
static StringRef getReturnOriginsAttrName() { return "returnorigins"; }

template <typename ValueT>
void printSetLattice(const enzyme::SparseSetLattice<ValueT> &setLattice,
Expand Down Expand Up @@ -147,9 +148,77 @@ void enzyme::ForwardActivityAnnotationAnalysis::processMemoryRead(
}
}

void deserializeReturnOrigins(ArrayAttr returnOrigins,
SmallVectorImpl<enzyme::ValueOriginSet> &out) {
for (auto &&[resultIdx, argOrigins] : llvm::enumerate(returnOrigins)) {
enzyme::ValueOriginSet origins;
if (auto strAttr = dyn_cast<StringAttr>(argOrigins)) {
if (strAttr.getValue() == "<unknown>") {
(void)origins.markUnknown();
} else {
// Leave origins undefined
}
} else {
for (enzyme::ArgumentOriginAttr originAttr :
cast<ArrayAttr>(argOrigins)
.getAsRange<enzyme::ArgumentOriginAttr>()) {
(void)origins.insert({originAttr});
}
}

out.push_back(origins);
}
}

void enzyme::ForwardActivityAnnotationAnalysis::visitExternalCall(
CallOpInterface call, ArrayRef<const ForwardOriginsLattice *> operands,
ArrayRef<ForwardOriginsLattice *> results) {}
ArrayRef<ForwardOriginsLattice *> results) {
auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
auto markAllResultsUnknown = [&]() {
for (ForwardOriginsLattice *result : results) {
propagateIfChanged(result, result->markUnknown());
}
};
if (!symbol)
return markAllResultsUnknown();

if (auto callee = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
call, symbol.getLeafReference())) {
if (auto returnOriginsAttr =
callee->getAttrOfType<ArrayAttr>(getReturnOriginsAttrName())) {
SmallVector<ValueOriginSet> returnOrigins;
deserializeReturnOrigins(returnOriginsAttr, returnOrigins);
return processCallToSummarizedFunc(call, returnOrigins, operands,
results);
}
}
}

void enzyme::ForwardActivityAnnotationAnalysis::processCallToSummarizedFunc(
CallOpInterface call, ArrayRef<ValueOriginSet> summary,
ArrayRef<const ForwardOriginsLattice *> operands,
ArrayRef<ForwardOriginsLattice *> results) {
for (const auto &[result, returnOrigin] : llvm::zip(results, summary)) {
// Convert the origins relative to the callee to relative to the caller
ValueOriginSet callerOrigins;
if (returnOrigin.isUndefined())
continue;

if (returnOrigin.isUnknown()) {
(void)callerOrigins.markUnknown();
} else {
(void)returnOrigin.foreachElement(
[&](OriginAttr calleeOrigin, ValueOriginSet::State state) {
assert(state == ValueOriginSet::State::Defined &&
"undefined and unknown must have been handled above");
auto calleeArgOrigin = cast<ArgumentOriginAttr>(calleeOrigin);
return callerOrigins.join(
operands[calleeArgOrigin.getArgNumber()]->getOriginsObject());
});
}
propagateIfChanged(result, result->merge(callerOrigins));
}
}

void enzyme::BackwardActivityAnnotationAnalysis::setToExitState(
BackwardOriginsLattice *lattice) {
Expand Down Expand Up @@ -205,7 +274,51 @@ void enzyme::BackwardActivityAnnotationAnalysis::visitOperation(

void enzyme::BackwardActivityAnnotationAnalysis::visitExternalCall(
CallOpInterface call, ArrayRef<BackwardOriginsLattice *> operands,
ArrayRef<const BackwardOriginsLattice *> results) {}
ArrayRef<const BackwardOriginsLattice *> results) {
auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
auto markAllOperandsUnknown = [&]() {
for (BackwardOriginsLattice *operand : operands) {
propagateIfChanged(operand, operand->markUnknown());
}
};
if (!symbol)
return markAllOperandsUnknown();

if (auto callee = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
call, symbol.getLeafReference())) {
if (auto returnOriginsAttr =
callee->getAttrOfType<ArrayAttr>(getReturnOriginsAttrName())) {
SmallVector<ValueOriginSet> returnOrigins;
deserializeReturnOrigins(returnOriginsAttr, returnOrigins);
return processCallToSummarizedFunc(call, returnOrigins, operands,
results);
}
}
}

void enzyme::BackwardActivityAnnotationAnalysis::processCallToSummarizedFunc(
CallOpInterface call, ArrayRef<ValueOriginSet> summary,
ArrayRef<BackwardOriginsLattice *> operands,
ArrayRef<const BackwardOriginsLattice *> results) {
// collect the result origins, propagate them to the operands.
for (const auto &[result, calleeOrigins] : llvm::zip(results, summary)) {
ValueOriginSet resultOrigins = result->getOriginsObject();
if (calleeOrigins.isUndefined())
continue;
if (calleeOrigins.isUnknown())
(void)resultOrigins.markUnknown();
else {
(void)calleeOrigins.foreachElement(
[&](OriginAttr calleeOrigin, ValueOriginSet::State state) {
auto calleeArgOrigin = cast<ArgumentOriginAttr>(calleeOrigin);
BackwardOriginsLattice *operand =
operands[calleeArgOrigin.getArgNumber()];
propagateIfChanged(operand, operand->merge(resultOrigins));
return ChangeResult::NoChange;
});
}
}
}

template <typename KeyT, typename ElementT>
void printMapOfSetsLattice(
Expand Down Expand Up @@ -300,7 +413,6 @@ void enzyme::DenseActivityAnnotationAnalysis::visitOperation(
memory.getEffects(effects);
for (const auto &effect : effects) {
Value value = effect.getValue();
// TODO: may be too pessimistic
if (!value)
return propagateIfChanged(after, after->markAllOriginsUnknown());

Expand Down Expand Up @@ -628,17 +740,6 @@ void enzyme::DenseBackwardActivityAnnotationAnalysis::
}

for (const auto &[destClass, sourceOrigins] : summary) {
// Get the source alias classes
AliasClassSet callerSourceClasses;
for (Attribute sourceOrigin : sourceOrigins.getElements()) {
unsigned argNumber =
cast<ArgumentOriginAttr>(sourceOrigin).getArgNumber();
traversePointsToSets(argumentClasses[argNumber], *p2sets,
[&](DistinctAttr aliasClass) {
(void)callerSourceClasses.insert({aliasClass});
});
}

// Get the destination origins
ValueOriginSet destOrigins;
if (auto pseudoClass = dyn_cast_if_present<PseudoAliasClassAttr>(
Expand All @@ -649,6 +750,37 @@ void enzyme::DenseBackwardActivityAnnotationAnalysis::
after.getOrigins(aliasClass));
});
}

if (destOrigins.isUndefined())
continue;

// Get the source alias classes
AliasClassSet callerSourceClasses;
for (Attribute sourceOrigin : sourceOrigins.getElements()) {
unsigned argNumber =
cast<ArgumentOriginAttr>(sourceOrigin).getArgNumber();

if (argumentClasses[argNumber].isUndefined()) {
// Not a pointer, do a sparse update
raw_ostream &os = llvm::outs();
os << "sparse update dest origins: ";
destOrigins.print(os);
os << "\n";
auto *backwardLattice = getOrCreate<BackwardOriginsLattice>(
call.getArgOperands()[argNumber]);
if (destOrigins.isUnknown()) {
propagateIfChanged(backwardLattice, backwardLattice->markUnknown());
continue;
}
propagateIfChanged(backwardLattice,
backwardLattice->insert(destOrigins.getElements()));
} else {
traversePointsToSets(argumentClasses[argNumber], *p2sets,
[&](DistinctAttr aliasClass) {
(void)callerSourceClasses.insert({aliasClass});
});
}
}
changed |= before->insert(callerSourceClasses, destOrigins);
}
propagateIfChanged(before, changed);
Expand Down Expand Up @@ -715,8 +847,7 @@ void initializeSparseBackwardActivityAnnotations(FunctionOpInterface func,
solver.getOrCreateState<BackwardOriginsLattice>(returnOperand.get());
auto origin = ReturnOriginAttr::get(FlatSymbolRefAttr::get(func),
returnOperand.getOperandNumber());
(void)lattice->join(
BackwardOriginsLattice::single(returnOperand.get(), origin));
(void)lattice->insert({origin});
}
}
}
Expand Down Expand Up @@ -834,8 +965,24 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) {
return lattice.serialize(ctx);
});
node->setAttr(
"returnorigins",
getReturnOriginsAttrName(),
ArrayAttr::get(node.getContext(), serializedReturnOperandOrigins));
os << "[ata] return origins: " << node->getAttr("returnorigins") << "\n";
os << "[ata] return origins: " << node->getAttr(getReturnOriginsAttrName())
<< "\n";

node.getCallableRegion()->walk([&](Operation *op) {
if (op->hasAttr("tag")) {
for (OpResult result : op->getResults()) {
auto *sources =
solver.getOrCreateState<enzyme::ForwardOriginsLattice>(result);
auto *sinks =
solver.getOrCreateState<enzyme::BackwardOriginsLattice>(result);
os << op->getAttr("tag") << "(#" << result.getResultNumber()
<< ") sources:\n"
<< *sources << "sinks:\n"
<< *sinks << "\n";
}
}
});
}
}
13 changes: 12 additions & 1 deletion enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ class ForwardActivityAnnotationAnalysis
void processMemoryRead(Operation *op, Value address,
ArrayRef<ForwardOriginsLattice *> results);

OriginalClasses originalClasses;
void
processCallToSummarizedFunc(CallOpInterface call,
ArrayRef<ValueOriginSet> summary,
ArrayRef<const ForwardOriginsLattice *> operands,
ArrayRef<ForwardOriginsLattice *> results);
};

class BackwardActivityAnnotationAnalysis
Expand All @@ -112,6 +116,13 @@ class BackwardActivityAnnotationAnalysis
visitExternalCall(CallOpInterface call,
ArrayRef<BackwardOriginsLattice *> operands,
ArrayRef<const BackwardOriginsLattice *> results) override;

private:
void
processCallToSummarizedFunc(CallOpInterface call,
ArrayRef<ValueOriginSet> summary,
ArrayRef<BackwardOriginsLattice *> operands,
ArrayRef<const BackwardOriginsLattice *> results);
};

//===----------------------------------------------------------------------===//
Expand Down
70 changes: 70 additions & 0 deletions enzyme/test/MLIR/ActivityAnalysis/Summaries/basic.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,62 @@
// RUN: %eopt --print-activity-analysis='use-annotations' --split-input-file %s | FileCheck %s

// CHECK-LABEL: processing function @sparse_callee
// CHECK: "fadd"(#0) sources:
// CHECK: size: 1:
// CHECK: #enzyme.argorigin<@sparse_callee(0)>
// CHECK: sinks:
// CHECK: size: 1:
// CHECK: #enzyme.retorigin<@sparse_callee(1)>
func.func @sparse_callee(%arg0: f64) -> (f64, f64) {
%zero = llvm.mlir.constant (0.0) : f64
%0 = llvm.fadd %arg0, %arg0 {tag = "fadd"} : f64
return %zero, %0 : f64, f64
}

// CHECK-LABEL: processing function @sparse_caller
// CHECK: "fmul"(#0) sources:
// CHECK: size: 1:
// CHECK: #enzyme.argorigin<@sparse_caller(1)>
// CHECK: sinks:
// CHECK: size: 1:
// CHECK: #enzyme.retorigin<@sparse_caller(0)>
func.func @sparse_caller(%unused: i64, %arg0: f64) -> f64 {
%0 = llvm.fmul %arg0, %arg0 {tag = "fmul"} : f64
%zero, %1 = call @sparse_callee(%0) : (f64) -> (f64, f64)
return %1 : f64
}

// -----

func.func @aliased_callee(%arg0: !llvm.ptr) -> !llvm.ptr {
%c0 = llvm.mlir.constant (0) : i64
%0 = llvm.getelementptr inbounds %arg0[%c0] : (!llvm.ptr, i64) -> !llvm.ptr, f64
return %arg0 : !llvm.ptr
}

// Test propagation of aliasing through function calls
func.func @loadstore(%arg0: f64) -> f64 {
%c1 = llvm.mlir.constant (1) : i64
%ptr = llvm.alloca %c1 x f64 : (i64) -> !llvm.ptr
%ptr2 = call @aliased_callee(%ptr) : (!llvm.ptr) -> !llvm.ptr
llvm.store %arg0, %ptr2 : f64, !llvm.ptr
%0 = llvm.load %ptr : !llvm.ptr -> f64
return %0 : f64
}

// -----

llvm.func local_unnamed_addr @malloc(i64 {llvm.noundef}) -> (!llvm.ptr {llvm.noalias, llvm.noundef}) attributes {frame_pointer = #llvm.framePointerKind<"non-leaf">, memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = readwrite>, passthrough = ["mustprogress", "nofree", "nounwind", "willreturn", ["allockind", "9"], ["allocsize", "4294967295"], ["alloc-family", "malloc"], ["approx-func-fp-math", "true"], ["no-infs-fp-math", "true"], ["no-nans-fp-math", "true"], ["no-signed-zeros-fp-math", "true"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "apple-m1"], ["unsafe-fp-math", "true"]], sym_visibility = "private", target_cpu = "apple-m1", target_features = #llvm.target_features<["+aes", "+complxnum", "+crc", "+dotprod", "+fp-armv8", "+fp16fml", "+fullfp16", "+jsconv", "+lse", "+neon", "+ras", "+rcpc", "+rdm", "+sha2", "+sha3", "+v8.1a", "+v8.2a", "+v8.3a", "+v8.4a", "+v8.5a", "+v8a", "+zcm", "+zcz"]>}

func.func @returnptr(%arg0: f64) -> !llvm.ptr {
%c8 = llvm.mlir.constant (8) : i64
%ptr = llvm.call @malloc(%c8) {tag = "malloc"} : (i64) -> !llvm.ptr
llvm.store %arg0, %ptr : f64, !llvm.ptr
return %ptr : !llvm.ptr
}

// -----

// CHECK-LABEL: processing function @load_nested
// CHECK: forward value origins:
// CHECK: distinct[0]<#enzyme.pseudoclass<@load_nested(1, 0)>> originates from [#enzyme.argorigin<@load_nested(0)>, #enzyme.argorigin<@load_nested(1)>]
Expand All @@ -26,6 +83,19 @@ func.func @pass_pointer_to(%arg0: f64, %alloc: !llvm.ptr, %out: !llvm.ptr) {

// -----

func.func @callee(%val: f64, %out: !llvm.ptr) {
llvm.store %val, %out : f64, !llvm.ptr
return
}

// Test backward summary propagation to scalar arguments
func.func @caller(%unused: i32, %val: f64, %out: !llvm.ptr) {
call @callee(%val, %out) : (f64, !llvm.ptr) -> ()
return
}

// -----

func.func @load_double_nested(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
%data = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr
%val = llvm.load %data : !llvm.ptr -> f64
Expand Down

0 comments on commit 01a458e

Please sign in to comment.