Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 9, 2024
1 parent 949bca3 commit 1b34c40
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 15 deletions.
8 changes: 5 additions & 3 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ using namespace mlir;
using namespace mlir::enzyme;

void createTerminator(MGradientUtils *gutils, mlir::Block *oBB,
const std::vector<bool> &returnPrimals,
const std::vector<bool> &returnShadows) {
const ArrayRef<bool> returnPrimals,
const ArrayRef<bool> returnShadows) {
auto inst = oBB->getTerminator();

mlir::Block *nBB = gutils->getNewFromOriginal(inst->getBlock());
Expand Down Expand Up @@ -100,8 +100,10 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff(
for (auto act : RetActivity) {
returnShadows.push_back(act != DIFFE_TYPE::CONSTANT);
}
SmallVector<bool> returnPrimalsP(returnPrimals);
SmallVector<bool> returnShadowsP(returnShadows);
auto gutils = MDiffeGradientUtils::CreateFromClone(
*this, mode, width, fn, TA, type_args, returnPrimals, returnShadows,
*this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP,
RetActivity, ArgActivity, addedType,
/*omp*/ false);
ForwardCachedFunctions[tup] = gutils->newFunc;
Expand Down
5 changes: 4 additions & 1 deletion enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,11 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
llvm_unreachable("Differentiating empty function");
}

SmallVector<bool> returnPrimalsP(returnPrimals);
SmallVector<bool> returnShadowsP(returnShadows);

MGradientUtilsReverse *gutils = MGradientUtilsReverse::CreateFromClone(
*this, mode, width, fn, TA, type_args, returnPrimals, returnShadows,
*this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP,
retType, constants, addedType);

Region &oldRegion = gutils->oldFunc.getFunctionBody();
Expand Down
4 changes: 2 additions & 2 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ using namespace mlir::enzyme;
mlir::enzyme::MGradientUtils::MGradientUtils(
MEnzymeLogic &Logic, FunctionOpInterface newFunc_,
FunctionOpInterface oldFunc_, MTypeAnalysis &TA_, MTypeResults TR_,
IRMapping &invertedPointers_, llvm::ArrayRef<bool> returnPrimals,
llvm::ArrayRef<bool> returnShadows,
IRMapping &invertedPointers_, const llvm::ArrayRef<bool> returnPrimals,
const llvm::ArrayRef<bool> returnShadows,
const SmallPtrSetImpl<mlir::Value> &constantvalues_,
const SmallPtrSetImpl<mlir::Value> &activevals_,
ArrayRef<DIFFE_TYPE> ReturnActivity, ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
Expand Down
19 changes: 10 additions & 9 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class MGradientUtils {
MTypeAnalysis &TA;
MTypeResults TR;
bool omp;
llvm::ArrayRef<bool> returnPrimals;
llvm::ArrayRef<bool> returnShadows;
const llvm::ArrayRef<bool> returnPrimals;
const llvm::ArrayRef<bool> returnShadows;

unsigned width;
ArrayRef<DIFFE_TYPE> ArgDiffeTypes;
Expand All @@ -50,8 +50,8 @@ class MGradientUtils {
MGradientUtils(MEnzymeLogic &Logic, FunctionOpInterface newFunc_,
FunctionOpInterface oldFunc_, MTypeAnalysis &TA_,
MTypeResults TR_, IRMapping &invertedPointers_,
llvm::ArrayRef<bool> returnPrimals,
llvm::ArrayRef<bool> returnShadows,
const llvm::ArrayRef<bool> returnPrimals,
const llvm::ArrayRef<bool> returnShadows,
const SmallPtrSetImpl<mlir::Value> &constantvalues_,
const SmallPtrSetImpl<mlir::Value> &activevals_,
ArrayRef<DIFFE_TYPE> ReturnActivities,
Expand Down Expand Up @@ -106,8 +106,8 @@ class MDiffeGradientUtils : public MGradientUtils {
MDiffeGradientUtils(MEnzymeLogic &Logic, FunctionOpInterface newFunc_,
FunctionOpInterface oldFunc_, MTypeAnalysis &TA,
MTypeResults TR, IRMapping &invertedPointers_,
const std::vector<bool> &returnPrimals,
const std::vector<bool> &returnShadows,
const llvm::ArrayRef<bool> returnPrimals,
const llvm::ArrayRef<bool> returnShadows,
const SmallPtrSetImpl<mlir::Value> &constantvalues_,
const SmallPtrSetImpl<mlir::Value> &activevals_,
ArrayRef<DIFFE_TYPE> RetActivity,
Expand All @@ -124,9 +124,10 @@ class MDiffeGradientUtils : public MGradientUtils {
static MDiffeGradientUtils *CreateFromClone(
MEnzymeLogic &Logic, DerivativeMode mode, unsigned width,
FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo,
const std::vector<bool> &returnPrimals,
const std::vector<bool> &returnShadows, ArrayRef<DIFFE_TYPE> RetActivity,
ArrayRef<DIFFE_TYPE> ArgActivity, mlir::Type additionalArg, bool omp) {
const llvm::ArrayRef<bool> returnPrimals,
const llvm::ArrayRef<bool> returnShadows,
ArrayRef<DIFFE_TYPE> RetActivity, ArrayRef<DIFFE_TYPE> ArgActivity,
mlir::Type additionalArg, bool omp) {
std::string prefix;

switch (mode) {
Expand Down

0 comments on commit 1b34c40

Please sign in to comment.