Skip to content

Commit

Permalink
Preserve primal shadow info
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 9, 2024
1 parent 152441a commit 46bb5c1
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
4 changes: 3 additions & 1 deletion enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ using namespace mlir::enzyme;
mlir::enzyme::MGradientUtils::MGradientUtils(
MEnzymeLogic &Logic, FunctionOpInterface newFunc_,
FunctionOpInterface oldFunc_, MTypeAnalysis &TA_, MTypeResults TR_,
IRMapping &invertedPointers_,
IRMapping &invertedPointers_, llvm::ArrayRef<bool> returnPrimals,
llvm::ArrayRef<bool> returnShadows,
const SmallPtrSetImpl<mlir::Value> &constantvalues_,
const SmallPtrSetImpl<mlir::Value> &activevals_,
ArrayRef<DIFFE_TYPE> ReturnActivity, ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
Expand All @@ -40,6 +41,7 @@ mlir::enzyme::MGradientUtils::MGradientUtils(
: newFunc(newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_),
invertedPointers(invertedPointers_), originalToNewFn(originalToNewFn_),
originalToNewFnOps(originalToNewFnOps_), blocksNotForAnalysis(),
returnPrimals(returnPrimals), returnShadows(returnShadows),
activityAnalyzer(std::make_unique<enzyme::ActivityAnalyzer>(
blocksNotForAnalysis, constantvalues_, activevals_, ReturnActivity)),
TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_),
Expand Down
18 changes: 12 additions & 6 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ class MGradientUtils {
MTypeAnalysis &TA;
MTypeResults TR;
bool omp;
llvm::ArrayRef<bool> returnPrimals, llvm::ArrayRef<bool> returnShadows,

unsigned width;
unsigned width;
ArrayRef<DIFFE_TYPE> ArgDiffeTypes;
ArrayRef<DIFFE_TYPE> RetDiffeTypes;

Expand All @@ -48,6 +49,8 @@ class MGradientUtils {
MGradientUtils(MEnzymeLogic &Logic, FunctionOpInterface newFunc_,
FunctionOpInterface oldFunc_, MTypeAnalysis &TA_,
MTypeResults TR_, IRMapping &invertedPointers_,
llvm::ArrayRef<bool> returnPrimals,
llvm::ArrayRef<bool> returnShadows,
const SmallPtrSetImpl<mlir::Value> &constantvalues_,
const SmallPtrSetImpl<mlir::Value> &activevals_,
ArrayRef<DIFFE_TYPE> ReturnActivities,
Expand Down Expand Up @@ -102,15 +105,18 @@ class MDiffeGradientUtils : public MGradientUtils {
MDiffeGradientUtils(MEnzymeLogic &Logic, FunctionOpInterface newFunc_,
FunctionOpInterface oldFunc_, MTypeAnalysis &TA,
MTypeResults TR, IRMapping &invertedPointers_,
const std::vector<bool> &returnPrimals,
const std::vector<bool> &returnShadows,
const SmallPtrSetImpl<mlir::Value> &constantvalues_,
const SmallPtrSetImpl<mlir::Value> &activevals_,
ArrayRef<DIFFE_TYPE> RetActivity,
ArrayRef<DIFFE_TYPE> ArgActivity, IRMapping &origToNew_,
std::map<Operation *, Operation *> &origToNewOps_,
DerivativeMode mode, unsigned width, bool omp)
: MGradientUtils(Logic, newFunc_, oldFunc_, TA, TR, invertedPointers_,
constantvalues_, activevals_, RetActivity, ArgActivity,
origToNew_, origToNewOps_, mode, width, omp),
returnPrimals, returnShadows, constantvalues_,
activevals_, RetActivity, ArgActivity, origToNew_,
origToNewOps_, mode, width, omp),
initializationBlock(&*(newFunc.getFunctionBody().begin())) {}

// Technically diffe constructor
Expand Down Expand Up @@ -153,9 +159,9 @@ class MDiffeGradientUtils : public MGradientUtils {
additionalArg);
MTypeResults TR; // TODO
return new MDiffeGradientUtils(
Logic, newFunc, todiff, TA, TR, invertedPointers, constant_values,
nonconstant_values, RetActivity, ArgActivity, originalToNew,
originalToNewOps, mode, width, omp);
Logic, newFunc, todiff, TA, TR, invertedPointers, returnPrimals,
returnShadows, constant_values, nonconstant_values, RetActivity,
ArgActivity, originalToNew, originalToNewOps, mode, width, omp);
}
};

Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ using namespace mlir::enzyme;
using namespace enzyme;

std::vector<DIFFE_TYPE> parseActivityString(StringRef inp) {
if (inp.size() == 0)
return {};
std::vector<DIFFE_TYPE> ArgActivity;
SmallVector<StringRef, 1> split;
StringRef(inp.data(), inp.size()).split(split, ',');
Expand Down

0 comments on commit 46bb5c1

Please sign in to comment.