-
Notifications
You must be signed in to change notification settings - Fork 108
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
18 changed files
with
4,575 additions
and
337 deletions.
There are no files selected for viewing
1,278 changes: 1,278 additions & 0 deletions
1,278
enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,234 @@ | ||
#ifndef ENZYME_MLIR_ANALYSIS_ACTIVITYANNOTATIONS_H | ||
#define ENZYME_MLIR_ANALYSIS_ACTIVITYANNOTATIONS_H | ||
|
||
#include "AliasAnalysis.h" | ||
#include "Dialect/Ops.h" | ||
#include "Lattice.h" | ||
|
||
#include "mlir/Analysis/DataFlow/DenseAnalysis.h" | ||
#include "mlir/Analysis/DataFlow/SparseAnalysis.h" | ||
#include "mlir/Analysis/DataFlowFramework.h" | ||
|
||
namespace mlir { | ||
class FunctionOpInterface; | ||
|
||
namespace enzyme { | ||
|
||
using ValueOriginSet = SetLattice<OriginAttr>; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ForwardOriginsLattice | ||
//===----------------------------------------------------------------------===// | ||
|
||
// TODO: specialize this to only arguments | ||
class ForwardOriginsLattice : public SparseSetLattice<OriginAttr> { | ||
public: | ||
using SparseSetLattice::SparseSetLattice; | ||
|
||
static ForwardOriginsLattice single(Value point, OriginAttr value) { | ||
return ForwardOriginsLattice(point, SetLattice<OriginAttr>(value)); | ||
} | ||
|
||
void print(raw_ostream &os) const override; | ||
|
||
ChangeResult join(const AbstractSparseLattice &other) override; | ||
|
||
const DenseSet<OriginAttr> &getOrigins() const { | ||
return elements.getElements(); | ||
} | ||
|
||
const SetLattice<OriginAttr> &getOriginsObject() const { return elements; } | ||
}; | ||
|
||
class BackwardOriginsLattice : public SparseSetLattice<OriginAttr> { | ||
public: | ||
using SparseSetLattice::SparseSetLattice; | ||
|
||
static BackwardOriginsLattice single(Value point, OriginAttr value) { | ||
return BackwardOriginsLattice(point, SetLattice<OriginAttr>(value)); | ||
} | ||
|
||
void print(raw_ostream &os) const override; | ||
|
||
ChangeResult meet(const AbstractSparseLattice &other) override { | ||
// MLIR framework again misusing terminology | ||
const auto *otherValueOrigins = | ||
static_cast<const BackwardOriginsLattice *>(&other); | ||
return elements.join(otherValueOrigins->elements); | ||
} | ||
|
||
const DenseSet<OriginAttr> &getOrigins() const { | ||
return elements.getElements(); | ||
} | ||
|
||
const SetLattice<OriginAttr> &getOriginsObject() const { return elements; } | ||
}; | ||
|
||
class ForwardActivityAnnotationAnalysis | ||
: public dataflow::SparseForwardDataFlowAnalysis<ForwardOriginsLattice> { | ||
public: | ||
ForwardActivityAnnotationAnalysis(DataFlowSolver &solver) | ||
: SparseForwardDataFlowAnalysis(solver) { | ||
assert(!solver.getConfig().isInterprocedural()); | ||
} | ||
|
||
void setToEntryState(ForwardOriginsLattice *lattice) override; | ||
|
||
void visitOperation(Operation *op, | ||
ArrayRef<const ForwardOriginsLattice *> operands, | ||
ArrayRef<ForwardOriginsLattice *> results) override; | ||
|
||
void visitExternalCall(CallOpInterface call, | ||
ArrayRef<const ForwardOriginsLattice *> operands, | ||
ArrayRef<ForwardOriginsLattice *> results) override; | ||
|
||
private: | ||
void processMemoryRead(Operation *op, Value address, | ||
ArrayRef<ForwardOriginsLattice *> results); | ||
|
||
void | ||
processCallToSummarizedFunc(CallOpInterface call, | ||
ArrayRef<ValueOriginSet> summary, | ||
ArrayRef<const ForwardOriginsLattice *> operands, | ||
ArrayRef<ForwardOriginsLattice *> results); | ||
}; | ||
|
||
class BackwardActivityAnnotationAnalysis | ||
: public dataflow::SparseBackwardDataFlowAnalysis<BackwardOriginsLattice> { | ||
public: | ||
BackwardActivityAnnotationAnalysis(DataFlowSolver &solver, | ||
SymbolTableCollection &symbolTable) | ||
: SparseBackwardDataFlowAnalysis(solver, symbolTable) { | ||
assert(!solver.getConfig().isInterprocedural()); | ||
} | ||
|
||
void visitBranchOperand(OpOperand &operand) override {} | ||
|
||
void visitCallOperand(OpOperand &operand) override {} | ||
|
||
void setToExitState(BackwardOriginsLattice *lattice) override; | ||
|
||
void | ||
visitOperation(Operation *op, ArrayRef<BackwardOriginsLattice *> operands, | ||
ArrayRef<const BackwardOriginsLattice *> results) override; | ||
|
||
void | ||
visitExternalCall(CallOpInterface call, | ||
ArrayRef<BackwardOriginsLattice *> operands, | ||
ArrayRef<const BackwardOriginsLattice *> results) override; | ||
|
||
private: | ||
void | ||
processCallToSummarizedFunc(CallOpInterface call, | ||
ArrayRef<ValueOriginSet> summary, | ||
ArrayRef<BackwardOriginsLattice *> operands, | ||
ArrayRef<const BackwardOriginsLattice *> results); | ||
}; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ForwardOriginsMap | ||
//===----------------------------------------------------------------------===// | ||
|
||
class ForwardOriginsMap : public MapOfSetsLattice<DistinctAttr, OriginAttr> { | ||
public: | ||
using MapOfSetsLattice::MapOfSetsLattice; | ||
|
||
void print(raw_ostream &os) const override; | ||
|
||
ChangeResult markAllOriginsUnknown() { return markAllUnknown(); } | ||
|
||
const ValueOriginSet &getOrigins(DistinctAttr id) const { return lookup(id); } | ||
}; | ||
|
||
class BackwardOriginsMap : public MapOfSetsLattice<DistinctAttr, OriginAttr> { | ||
public: | ||
using MapOfSetsLattice::MapOfSetsLattice; | ||
|
||
void print(raw_ostream &os) const override; | ||
|
||
ChangeResult markAllOriginsUnknown() { return markAllUnknown(); } | ||
|
||
const ValueOriginSet &getOrigins(DistinctAttr id) const { return lookup(id); } | ||
|
||
ChangeResult meet(const AbstractDenseLattice &other) override { | ||
return join(other); | ||
} | ||
}; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// DenseActivityAnnotationAnalysis | ||
//===----------------------------------------------------------------------===// | ||
|
||
class DenseActivityAnnotationAnalysis | ||
: public dataflow::DenseForwardDataFlowAnalysis<ForwardOriginsMap> { | ||
public: | ||
using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis; | ||
|
||
void setToEntryState(ForwardOriginsMap *lattice) override; | ||
|
||
void visitOperation(Operation *op, const ForwardOriginsMap &before, | ||
ForwardOriginsMap *after) override; | ||
|
||
void visitCallControlFlowTransfer(CallOpInterface call, | ||
dataflow::CallControlFlowAction action, | ||
const ForwardOriginsMap &before, | ||
ForwardOriginsMap *after) override; | ||
|
||
private: | ||
void processCallToSummarizedFunc( | ||
CallOpInterface call, | ||
const DenseMap<DistinctAttr, ValueOriginSet> &summary, | ||
const ForwardOriginsMap &before, ForwardOriginsMap *after); | ||
|
||
void processCopy(Operation *op, Value copySource, Value copyDest, | ||
const ForwardOriginsMap &before, ForwardOriginsMap *after); | ||
|
||
OriginalClasses originalClasses; | ||
}; | ||
|
||
class DenseBackwardActivityAnnotationAnalysis | ||
: public dataflow::DenseBackwardDataFlowAnalysis<BackwardOriginsMap> { | ||
public: | ||
using DenseBackwardDataFlowAnalysis::DenseBackwardDataFlowAnalysis; | ||
|
||
void visitOperation(Operation *op, const BackwardOriginsMap &after, | ||
BackwardOriginsMap *before) override; | ||
|
||
void visitCallControlFlowTransfer(CallOpInterface call, | ||
dataflow::CallControlFlowAction action, | ||
const BackwardOriginsMap &after, | ||
BackwardOriginsMap *before) override; | ||
|
||
void setToExitState(BackwardOriginsMap *lattice) override; | ||
|
||
private: | ||
void processCallToSummarizedFunc( | ||
CallOpInterface call, | ||
const DenseMap<DistinctAttr, ValueOriginSet> &summary, | ||
const BackwardOriginsMap &after, BackwardOriginsMap *before); | ||
|
||
void processCopy(Operation *op, Value copySource, Value copyDest, | ||
const BackwardOriginsMap &after, BackwardOriginsMap *before); | ||
}; | ||
|
||
class ActivityPrinterConfig { | ||
public: | ||
ActivityPrinterConfig() = default; | ||
|
||
/// Output extra information for debugging | ||
bool verbose = false; | ||
/// Annotate the IR with activity information for every operation. Currently | ||
/// only supports the LLVM dialect. | ||
bool annotate = false; | ||
/// Infer the starting argument state from an __enzyme_autodiff call. | ||
bool inferFromAutodiff = false; | ||
}; | ||
|
||
void runActivityAnnotations( | ||
FunctionOpInterface callee, ArrayRef<enzyme::Activity> argActivities = {}, | ||
const ActivityPrinterConfig &config = ActivityPrinterConfig()); | ||
|
||
} // namespace enzyme | ||
} // namespace mlir | ||
|
||
#endif // ENZYME_MLIR_ANALYSIS_ACTIVITYANNOTATIONS_H |
Oops, something went wrong.