From 01a458e1dba8cb28d00d8ff42ebbaaf71f23c913 Mon Sep 17 00:00:00 2001 From: Jacob Peng Date: Mon, 26 Feb 2024 17:50:42 -0500 Subject: [PATCH] Add summary processing for sparse backwards annotations --- .../MLIR/Analysis/ActivityAnnotations.cpp | 183 ++++++++++++++++-- .../MLIR/Analysis/ActivityAnnotations.h | 13 +- .../ActivityAnalysis/Summaries/basic.mlir | 70 +++++++ 3 files changed, 247 insertions(+), 19 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp index 59b596d270fd..dd860cf02ec3 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp @@ -19,6 +19,7 @@ using namespace mlir; static StringRef getActivityAnnotationAttrName() { return "activedeps"; } static StringRef getPointerSummaryAttrName() { return "p2psummary"; } +static StringRef getReturnOriginsAttrName() { return "returnorigins"; } template void printSetLattice(const enzyme::SparseSetLattice &setLattice, @@ -147,9 +148,77 @@ void enzyme::ForwardActivityAnnotationAnalysis::processMemoryRead( } } +void deserializeReturnOrigins(ArrayAttr returnOrigins, + SmallVectorImpl &out) { + for (auto &&[resultIdx, argOrigins] : llvm::enumerate(returnOrigins)) { + enzyme::ValueOriginSet origins; + if (auto strAttr = dyn_cast(argOrigins)) { + if (strAttr.getValue() == "") { + (void)origins.markUnknown(); + } else { + // Leave origins undefined + } + } else { + for (enzyme::ArgumentOriginAttr originAttr : + cast(argOrigins) + .getAsRange()) { + (void)origins.insert({originAttr}); + } + } + + out.push_back(origins); + } +} + void enzyme::ForwardActivityAnnotationAnalysis::visitExternalCall( CallOpInterface call, ArrayRef operands, - ArrayRef results) {} + ArrayRef results) { + auto symbol = dyn_cast(call.getCallableForCallee()); + auto markAllResultsUnknown = [&]() { + for (ForwardOriginsLattice *result : results) { + propagateIfChanged(result, result->markUnknown()); + } + }; + if (!symbol) + return markAllResultsUnknown(); + + if (auto callee = SymbolTable::lookupNearestSymbolFrom( + call, symbol.getLeafReference())) { + if (auto returnOriginsAttr = + callee->getAttrOfType(getReturnOriginsAttrName())) { + SmallVector returnOrigins; + deserializeReturnOrigins(returnOriginsAttr, returnOrigins); + return processCallToSummarizedFunc(call, returnOrigins, operands, + results); + } + } +} + +void enzyme::ForwardActivityAnnotationAnalysis::processCallToSummarizedFunc( + CallOpInterface call, ArrayRef summary, + ArrayRef operands, + ArrayRef results) { + for (const auto &[result, returnOrigin] : llvm::zip(results, summary)) { + // Convert the origins relative to the callee to relative to the caller + ValueOriginSet callerOrigins; + if (returnOrigin.isUndefined()) + continue; + + if (returnOrigin.isUnknown()) { + (void)callerOrigins.markUnknown(); + } else { + (void)returnOrigin.foreachElement( + [&](OriginAttr calleeOrigin, ValueOriginSet::State state) { + assert(state == ValueOriginSet::State::Defined && + "undefined and unknown must have been handled above"); + auto calleeArgOrigin = cast(calleeOrigin); + return callerOrigins.join( + operands[calleeArgOrigin.getArgNumber()]->getOriginsObject()); + }); + } + propagateIfChanged(result, result->merge(callerOrigins)); + } +} void enzyme::BackwardActivityAnnotationAnalysis::setToExitState( BackwardOriginsLattice *lattice) { @@ -205,7 +274,51 @@ void enzyme::BackwardActivityAnnotationAnalysis::visitOperation( void enzyme::BackwardActivityAnnotationAnalysis::visitExternalCall( CallOpInterface call, ArrayRef operands, - ArrayRef results) {} + ArrayRef results) { + auto symbol = dyn_cast(call.getCallableForCallee()); + auto markAllOperandsUnknown = [&]() { + for (BackwardOriginsLattice *operand : operands) { + propagateIfChanged(operand, operand->markUnknown()); + } + }; + if (!symbol) + return markAllOperandsUnknown(); + + if (auto callee = SymbolTable::lookupNearestSymbolFrom( + call, symbol.getLeafReference())) { + if (auto returnOriginsAttr = + callee->getAttrOfType(getReturnOriginsAttrName())) { + SmallVector returnOrigins; + deserializeReturnOrigins(returnOriginsAttr, returnOrigins); + return processCallToSummarizedFunc(call, returnOrigins, operands, + results); + } + } +} + +void enzyme::BackwardActivityAnnotationAnalysis::processCallToSummarizedFunc( + CallOpInterface call, ArrayRef summary, + ArrayRef operands, + ArrayRef results) { + // collect the result origins, propagate them to the operands. + for (const auto &[result, calleeOrigins] : llvm::zip(results, summary)) { + ValueOriginSet resultOrigins = result->getOriginsObject(); + if (calleeOrigins.isUndefined()) + continue; + if (calleeOrigins.isUnknown()) + (void)resultOrigins.markUnknown(); + else { + (void)calleeOrigins.foreachElement( + [&](OriginAttr calleeOrigin, ValueOriginSet::State state) { + auto calleeArgOrigin = cast(calleeOrigin); + BackwardOriginsLattice *operand = + operands[calleeArgOrigin.getArgNumber()]; + propagateIfChanged(operand, operand->merge(resultOrigins)); + return ChangeResult::NoChange; + }); + } + } +} template void printMapOfSetsLattice( @@ -300,7 +413,6 @@ void enzyme::DenseActivityAnnotationAnalysis::visitOperation( memory.getEffects(effects); for (const auto &effect : effects) { Value value = effect.getValue(); - // TODO: may be too pessimistic if (!value) return propagateIfChanged(after, after->markAllOriginsUnknown()); @@ -628,17 +740,6 @@ void enzyme::DenseBackwardActivityAnnotationAnalysis:: } for (const auto &[destClass, sourceOrigins] : summary) { - // Get the source alias classes - AliasClassSet callerSourceClasses; - for (Attribute sourceOrigin : sourceOrigins.getElements()) { - unsigned argNumber = - cast(sourceOrigin).getArgNumber(); - traversePointsToSets(argumentClasses[argNumber], *p2sets, - [&](DistinctAttr aliasClass) { - (void)callerSourceClasses.insert({aliasClass}); - }); - } - // Get the destination origins ValueOriginSet destOrigins; if (auto pseudoClass = dyn_cast_if_present( @@ -649,6 +750,37 @@ void enzyme::DenseBackwardActivityAnnotationAnalysis:: after.getOrigins(aliasClass)); }); } + + if (destOrigins.isUndefined()) + continue; + + // Get the source alias classes + AliasClassSet callerSourceClasses; + for (Attribute sourceOrigin : sourceOrigins.getElements()) { + unsigned argNumber = + cast(sourceOrigin).getArgNumber(); + + if (argumentClasses[argNumber].isUndefined()) { + // Not a pointer, do a sparse update + raw_ostream &os = llvm::outs(); + os << "sparse update dest origins: "; + destOrigins.print(os); + os << "\n"; + auto *backwardLattice = getOrCreate( + call.getArgOperands()[argNumber]); + if (destOrigins.isUnknown()) { + propagateIfChanged(backwardLattice, backwardLattice->markUnknown()); + continue; + } + propagateIfChanged(backwardLattice, + backwardLattice->insert(destOrigins.getElements())); + } else { + traversePointsToSets(argumentClasses[argNumber], *p2sets, + [&](DistinctAttr aliasClass) { + (void)callerSourceClasses.insert({aliasClass}); + }); + } + } changed |= before->insert(callerSourceClasses, destOrigins); } propagateIfChanged(before, changed); @@ -715,8 +847,7 @@ void initializeSparseBackwardActivityAnnotations(FunctionOpInterface func, solver.getOrCreateState(returnOperand.get()); auto origin = ReturnOriginAttr::get(FlatSymbolRefAttr::get(func), returnOperand.getOperandNumber()); - (void)lattice->join( - BackwardOriginsLattice::single(returnOperand.get(), origin)); + (void)lattice->insert({origin}); } } } @@ -834,8 +965,24 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) { return lattice.serialize(ctx); }); node->setAttr( - "returnorigins", + getReturnOriginsAttrName(), ArrayAttr::get(node.getContext(), serializedReturnOperandOrigins)); - os << "[ata] return origins: " << node->getAttr("returnorigins") << "\n"; + os << "[ata] return origins: " << node->getAttr(getReturnOriginsAttrName()) + << "\n"; + + node.getCallableRegion()->walk([&](Operation *op) { + if (op->hasAttr("tag")) { + for (OpResult result : op->getResults()) { + auto *sources = + solver.getOrCreateState(result); + auto *sinks = + solver.getOrCreateState(result); + os << op->getAttr("tag") << "(#" << result.getResultNumber() + << ") sources:\n" + << *sources << "sinks:\n" + << *sinks << "\n"; + } + } + }); } } diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.h b/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.h index 25acc1af7106..933137a96f02 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.h +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.h @@ -86,7 +86,11 @@ class ForwardActivityAnnotationAnalysis void processMemoryRead(Operation *op, Value address, ArrayRef results); - OriginalClasses originalClasses; + void + processCallToSummarizedFunc(CallOpInterface call, + ArrayRef summary, + ArrayRef operands, + ArrayRef results); }; class BackwardActivityAnnotationAnalysis @@ -112,6 +116,13 @@ class BackwardActivityAnnotationAnalysis visitExternalCall(CallOpInterface call, ArrayRef operands, ArrayRef results) override; + +private: + void + processCallToSummarizedFunc(CallOpInterface call, + ArrayRef summary, + ArrayRef operands, + ArrayRef results); }; //===----------------------------------------------------------------------===// diff --git a/enzyme/test/MLIR/ActivityAnalysis/Summaries/basic.mlir b/enzyme/test/MLIR/ActivityAnalysis/Summaries/basic.mlir index 1f26ad481d7c..231a77a2f07b 100644 --- a/enzyme/test/MLIR/ActivityAnalysis/Summaries/basic.mlir +++ b/enzyme/test/MLIR/ActivityAnalysis/Summaries/basic.mlir @@ -1,5 +1,62 @@ // RUN: %eopt --print-activity-analysis='use-annotations' --split-input-file %s | FileCheck %s +// CHECK-LABEL: processing function @sparse_callee +// CHECK: "fadd"(#0) sources: +// CHECK: size: 1: +// CHECK: #enzyme.argorigin<@sparse_callee(0)> +// CHECK: sinks: +// CHECK: size: 1: +// CHECK: #enzyme.retorigin<@sparse_callee(1)> +func.func @sparse_callee(%arg0: f64) -> (f64, f64) { + %zero = llvm.mlir.constant (0.0) : f64 + %0 = llvm.fadd %arg0, %arg0 {tag = "fadd"} : f64 + return %zero, %0 : f64, f64 +} + +// CHECK-LABEL: processing function @sparse_caller +// CHECK: "fmul"(#0) sources: +// CHECK: size: 1: +// CHECK: #enzyme.argorigin<@sparse_caller(1)> +// CHECK: sinks: +// CHECK: size: 1: +// CHECK: #enzyme.retorigin<@sparse_caller(0)> +func.func @sparse_caller(%unused: i64, %arg0: f64) -> f64 { + %0 = llvm.fmul %arg0, %arg0 {tag = "fmul"} : f64 + %zero, %1 = call @sparse_callee(%0) : (f64) -> (f64, f64) + return %1 : f64 +} + +// ----- + +func.func @aliased_callee(%arg0: !llvm.ptr) -> !llvm.ptr { + %c0 = llvm.mlir.constant (0) : i64 + %0 = llvm.getelementptr inbounds %arg0[%c0] : (!llvm.ptr, i64) -> !llvm.ptr, f64 + return %arg0 : !llvm.ptr +} + +// Test propagation of aliasing through function calls +func.func @loadstore(%arg0: f64) -> f64 { + %c1 = llvm.mlir.constant (1) : i64 + %ptr = llvm.alloca %c1 x f64 : (i64) -> !llvm.ptr + %ptr2 = call @aliased_callee(%ptr) : (!llvm.ptr) -> !llvm.ptr + llvm.store %arg0, %ptr2 : f64, !llvm.ptr + %0 = llvm.load %ptr : !llvm.ptr -> f64 + return %0 : f64 +} + +// ----- + +llvm.func local_unnamed_addr @malloc(i64 {llvm.noundef}) -> (!llvm.ptr {llvm.noalias, llvm.noundef}) attributes {frame_pointer = #llvm.framePointerKind<"non-leaf">, memory = #llvm.memory_effects, passthrough = ["mustprogress", "nofree", "nounwind", "willreturn", ["allockind", "9"], ["allocsize", "4294967295"], ["alloc-family", "malloc"], ["approx-func-fp-math", "true"], ["no-infs-fp-math", "true"], ["no-nans-fp-math", "true"], ["no-signed-zeros-fp-math", "true"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "apple-m1"], ["unsafe-fp-math", "true"]], sym_visibility = "private", target_cpu = "apple-m1", target_features = #llvm.target_features<["+aes", "+complxnum", "+crc", "+dotprod", "+fp-armv8", "+fp16fml", "+fullfp16", "+jsconv", "+lse", "+neon", "+ras", "+rcpc", "+rdm", "+sha2", "+sha3", "+v8.1a", "+v8.2a", "+v8.3a", "+v8.4a", "+v8.5a", "+v8a", "+zcm", "+zcz"]>} + +func.func @returnptr(%arg0: f64) -> !llvm.ptr { + %c8 = llvm.mlir.constant (8) : i64 + %ptr = llvm.call @malloc(%c8) {tag = "malloc"} : (i64) -> !llvm.ptr + llvm.store %arg0, %ptr : f64, !llvm.ptr + return %ptr : !llvm.ptr +} + +// ----- + // CHECK-LABEL: processing function @load_nested // CHECK: forward value origins: // CHECK: distinct[0]<#enzyme.pseudoclass<@load_nested(1, 0)>> originates from [#enzyme.argorigin<@load_nested(0)>, #enzyme.argorigin<@load_nested(1)>] @@ -26,6 +83,19 @@ func.func @pass_pointer_to(%arg0: f64, %alloc: !llvm.ptr, %out: !llvm.ptr) { // ----- +func.func @callee(%val: f64, %out: !llvm.ptr) { + llvm.store %val, %out : f64, !llvm.ptr + return +} + +// Test backward summary propagation to scalar arguments +func.func @caller(%unused: i32, %val: f64, %out: !llvm.ptr) { + call @callee(%val, %out) : (f64, !llvm.ptr) -> () + return +} + +// ----- + func.func @load_double_nested(%arg0: !llvm.ptr, %arg1: !llvm.ptr) { %data = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr %val = llvm.load %data : !llvm.ptr -> f64