Skip to content

Commit

Permalink
MLIR handle multiple distinct return activities
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 9, 2024
1 parent 73d862a commit 954a209
Show file tree
Hide file tree
Showing 17 changed files with 314 additions and 342 deletions.
25 changes: 23 additions & 2 deletions enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3073,7 +3073,18 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers(
continue;
}
if (isFunctionReturn(a)) {
if (ActiveReturns == DIFFE_TYPE::CONSTANT) {
bool allConstant = true;
assert(a->getOperands().size() == ActiveReturns.size());
for (auto &&[retval, act] :
llvm::zip(a->getOperands(), ActiveReturns)) {
if (retval != parent)
continue;
if (act == DIFFE_TYPE::CONSTANT)
continue;
allConstant = false;
break;
}
if (allConstant) {
continue;
} else {
return false;
Expand Down Expand Up @@ -3424,7 +3435,17 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned(
return false;
}
if (isFunctionReturn(a)) {
if (ActiveReturns == DIFFE_TYPE::CONSTANT)
bool allConstant = true;
assert(a->getOperands().size() == ActiveReturns.size());
for (auto &&[retval, act] : llvm::zip(a->getOperands(), ActiveReturns)) {
if (retval != val)
continue;
if (act == DIFFE_TYPE::CONSTANT)
continue;
allConstant = false;
break;
}
if (allConstant)
continue;

if (EnzymePrintActivity)
Expand Down
4 changes: 2 additions & 2 deletions enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ActivityAnalyzer {

public:
/// Activity of each return value of the function being analyzed, one entry per operand of the return operation (in operand order)
const DIFFE_TYPE ActiveReturns;
const llvm::ArrayRef<DIFFE_TYPE> ActiveReturns;

private:
/// Direction of current analysis
Expand Down Expand Up @@ -71,7 +71,7 @@ class ActivityAnalyzer {
// llvm::TargetLibraryInfo &TLI_,
const llvm::SmallPtrSetImpl<Value> &ConstantValues,
const llvm::SmallPtrSetImpl<Value> &ActiveValues,
DIFFE_TYPE ActiveReturns)
llvm::ArrayRef<DIFFE_TYPE> ActiveReturns)
: notForAnalysis(notForAnalysis_), ActiveReturns(ActiveReturns),
directions(UP | DOWN),
ConstantValues(ConstantValues.begin(), ConstantValues.end()),
Expand Down
136 changes: 53 additions & 83 deletions enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,73 +12,50 @@ Type getShadowType(Type type, unsigned width) {
exit(1);
}

mlir::FunctionType getFunctionTypeForClone(
mlir::FunctionType FTy, DerivativeMode mode, unsigned width,
mlir::Type additionalArg, llvm::ArrayRef<DIFFE_TYPE> constant_args,
bool diffeReturnArg, ReturnType returnValue, DIFFE_TYPE ReturnType) {
mlir::FunctionType
getFunctionTypeForClone(mlir::FunctionType FTy, DerivativeMode mode,
unsigned width, mlir::Type additionalArg,
const std::vector<bool> &returnPrimals,
const std::vector<bool> &returnShadows,
llvm::ArrayRef<DIFFE_TYPE> ReturnActivity,
llvm::ArrayRef<DIFFE_TYPE> ArgActivity) {

SmallVector<mlir::Type, 4> RetTypes;
if (returnValue == ReturnType::ArgsWithReturn ||
returnValue == ReturnType::Return) {
assert(FTy.getNumResults() >= 1);
for (size_t i = 0; i < FTy.getNumResults(); i++) {
if (ReturnType != DIFFE_TYPE::CONSTANT &&
ReturnType != DIFFE_TYPE::OUT_DIFF) {
RetTypes.push_back(getShadowType(FTy.getResult(i), width));
} else {
RetTypes.push_back(FTy.getResult(i));
}
}
} else if (returnValue == ReturnType::ArgsWithTwoReturns ||
returnValue == ReturnType::TwoReturns) {
assert(FTy.getNumResults() >= 1);
for (size_t i = 0; i < FTy.getNumResults(); i++) {
RetTypes.push_back(FTy.getResult(i));
if (ReturnType != DIFFE_TYPE::CONSTANT &&
ReturnType != DIFFE_TYPE::OUT_DIFF) {
RetTypes.push_back(getShadowType(FTy.getResult(i), width));
} else {
RetTypes.push_back(FTy.getResult(i));
}

for (auto &&[Ty, returnPrimal, returnShadow, activity] : llvm::zip(
FTy.getResults(), returnPrimals, returnShadows, ReturnActivity)) {
if (returnPrimal)
RetTypes.push_back(Ty);
if (returnShadow) {
assert(activity != DIFFE_TYPE::CONSTANT);
assert(activity != DIFFE_TYPE::OUT_DIFF);
RetTypes.push_back(getShadowType(Ty, width));
}
}

SmallVector<mlir::Type, 4> ArgTypes;

// The user might be deleting arguments to the function by specifying them in
// the VMap. If so, we need to not add the arguments to the arg ty vector
unsigned argno = 0;

for (auto I : FTy.getInputs()) {
ArgTypes.push_back(I);
if (constant_args[argno] == DIFFE_TYPE::DUP_ARG ||
constant_args[argno] == DIFFE_TYPE::DUP_NONEED) {
ArgTypes.push_back(getShadowType(I, width));
} else if (constant_args[argno] == DIFFE_TYPE::OUT_DIFF) {
RetTypes.push_back(getShadowType(I, width));
for (auto &&[ITy, act] : llvm::zip(FTy.getInputs(), ArgActivity)) {
ArgTypes.push_back(ITy);
if (act == DIFFE_TYPE::DUP_ARG || act == DIFFE_TYPE::DUP_NONEED) {
ArgTypes.push_back(getShadowType(ITy, width));
} else if (act == DIFFE_TYPE::OUT_DIFF) {
RetTypes.push_back(getShadowType(ITy, width));
}
++argno;
}

// TODO: Expand for multiple returns
if (diffeReturnArg) {
ArgTypes.push_back(getShadowType(FTy.getResult(0), width));
for (auto &&[Ty, activity] : llvm::zip(FTy.getResults(), ReturnActivity)) {
if (activity == DIFFE_TYPE::OUT_DIFF) {
ArgTypes.push_back(getShadowType(Ty, width));
}
}

if (additionalArg) {
ArgTypes.push_back(additionalArg);
}

OpBuilder builder(FTy.getContext());
if (returnValue == ReturnType::TapeAndTwoReturns ||
returnValue == ReturnType::TapeAndReturn) {
RetTypes.insert(RetTypes.begin(),
LLVM::LLVMPointerType::get(FTy.getContext()));
} else if (returnValue == ReturnType::Tape) {
for (auto I : FTy.getInputs()) {
RetTypes.push_back(I);
}
}

// Create a new function type...
OpBuilder builder(FTy.getContext());
return builder.getFunctionType(ArgTypes, RetTypes);
}

Expand Down Expand Up @@ -205,19 +182,20 @@ void cloneInto(Region *src, Region *dest, Region::iterator destPos,

FunctionOpInterface CloneFunctionWithReturns(
DerivativeMode mode, unsigned width, FunctionOpInterface F,
IRMapping &ptrInputs, ArrayRef<DIFFE_TYPE> constant_args,
IRMapping &ptrInputs, ArrayRef<DIFFE_TYPE> ArgActivity,
SmallPtrSetImpl<mlir::Value> &constants,
SmallPtrSetImpl<mlir::Value> &nonconstants,
SmallPtrSetImpl<mlir::Value> &returnvals, ReturnType returnValue,
DIFFE_TYPE DReturnType, Twine name, IRMapping &VMap,
std::map<Operation *, Operation *> &OpMap, bool diffeReturnArg,
SmallPtrSetImpl<mlir::Value> &returnvals,
const std::vector<bool> &returnPrimals,
const std::vector<bool> &returnShadows, ArrayRef<DIFFE_TYPE> RetActivity,
Twine name, IRMapping &VMap, std::map<Operation *, Operation *> &OpMap,
mlir::Type additionalArg) {
assert(!F.getFunctionBody().empty());
// F = preprocessForClone(F, mode);
// llvm::ValueToValueMapTy VMap;
auto FTy = getFunctionTypeForClone(
F.getFunctionType().cast<mlir::FunctionType>(), mode, width,
additionalArg, constant_args, diffeReturnArg, returnValue, DReturnType);
additionalArg, returnPrimals, returnShadows, RetActivity, ArgActivity);

/*
for (Block &BB : F.getFunctionBody().getBlocks()) {
Expand All @@ -244,32 +222,33 @@ FunctionOpInterface CloneFunctionWithReturns(

{
auto &blk = NewF.getFunctionBody().front();
assert(F.getFunctionBody().front().getNumArguments() ==
constant_args.size());
for (ssize_t i = constant_args.size() - 1; i >= 0; i--) {
assert(F.getFunctionBody().front().getNumArguments() == ArgActivity.size());
for (ssize_t i = ArgActivity.size() - 1; i >= 0; i--) {
mlir::Value oval = F.getFunctionBody().front().getArgument(i);
if (constant_args[i] == DIFFE_TYPE::CONSTANT)
if (ArgActivity[i] == DIFFE_TYPE::CONSTANT)
constants.insert(oval);
else if (constant_args[i] == DIFFE_TYPE::OUT_DIFF)
else if (ArgActivity[i] == DIFFE_TYPE::OUT_DIFF)
nonconstants.insert(oval);
else if (constant_args[i] == DIFFE_TYPE::DUP_ARG ||
constant_args[i] == DIFFE_TYPE::DUP_NONEED) {
else if (ArgActivity[i] == DIFFE_TYPE::DUP_ARG ||
ArgActivity[i] == DIFFE_TYPE::DUP_NONEED) {
nonconstants.insert(oval);
mlir::Value val = blk.getArgument(i);
mlir::Value dval;
if (i == constant_args.size() - 1)
if (i == ArgActivity.size() - 1)
dval = blk.addArgument(val.getType(), val.getLoc());
else
dval = blk.insertArgument(blk.args_begin() + i + 1, val.getType(),
val.getLoc());
ptrInputs.map(oval, dval);
}
}
// TODO: Add support for mulitple outputs?
if (diffeReturnArg) {
auto location = blk.getArgument(blk.getNumArguments() - 1).getLoc();
auto val = F.getFunctionType().cast<mlir::FunctionType>().getResult(0);
blk.addArgument(val, location);
for (auto &&[Ty, activity] :
llvm::zip(F.getFunctionType().cast<mlir::FunctionType>().getResults(),
RetActivity)) {
if (activity == DIFFE_TYPE::OUT_DIFF) {
auto location = blk.getArgument(blk.getNumArguments() - 1).getLoc();
blk.addArgument(getShadowType(Ty, width), location);
}
}
}

Expand All @@ -285,15 +264,7 @@ FunctionOpInterface CloneFunctionWithReturns(
size_t oldi = 0;
size_t newi = 0;
while (oldi < F.getNumResults()) {
bool primalReturn = returnValue == ReturnType::ArgsWithReturn ||
returnValue == ReturnType::ArgsWithTwoReturns ||
(returnValue == ReturnType::TapeAndReturn &&
DReturnType == DIFFE_TYPE::CONSTANT) ||
returnValue == ReturnType::TapeAndTwoReturns ||
returnValue == ReturnType::TwoReturns ||
(returnValue == ReturnType::Return &&
DReturnType == DIFFE_TYPE::CONSTANT);
if (primalReturn) {
if (returnPrimals[oldi]) {
for (auto attrName : ToClone) {
auto attrNameS = StringAttr::get(F->getContext(), attrName);
NewF.removeResultAttr(newi, attrNameS);
Expand All @@ -310,8 +281,7 @@ FunctionOpInterface CloneFunctionWithReturns(
}
newi++;
}
if (DReturnType == DIFFE_TYPE::DUP_ARG ||
DReturnType == DIFFE_TYPE::DUP_NONEED) {
if (returnShadows[oldi]) {
for (auto attrName : ToClone) {
auto attrNameS = StringAttr::get(F->getContext(), attrName);
NewF.removeResultAttr(newi, attrNameS);
Expand Down Expand Up @@ -350,8 +320,8 @@ FunctionOpInterface CloneFunctionWithReturns(
}

newi++;
if (constant_args[oldi] == DIFFE_TYPE::DUP_ARG ||
constant_args[oldi] == DIFFE_TYPE::DUP_NONEED) {
if (ArgActivity[oldi] == DIFFE_TYPE::DUP_ARG ||
ArgActivity[oldi] == DIFFE_TYPE::DUP_NONEED) {

for (auto attrName : ToClone) {
NewF.removeArgAttr(newi, attrName);
Expand Down
22 changes: 13 additions & 9 deletions enzyme/Enzyme/MLIR/Interfaces/CloneFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@ using namespace mlir::enzyme;

Type getShadowType(Type type, unsigned width = 1);

mlir::FunctionType getFunctionTypeForClone(
mlir::FunctionType FTy, DerivativeMode mode, unsigned width,
mlir::Type additionalArg, llvm::ArrayRef<DIFFE_TYPE> constant_args,
bool diffeReturnArg, ReturnType returnValue, DIFFE_TYPE ReturnType);
/// Computes the function type of the derivative clone of `FTy`.
///
/// Results (as built by the definition in CloneFunction.cpp): for each
/// original result, in order, its primal type when the matching
/// `returnPrimals` entry is set, followed by its shadow type
/// (`getShadowType(type, width)`) when the matching `returnShadows` entry is
/// set; after those, one shadow type for every input whose `ArgActivity` is
/// OUT_DIFF.
///
/// Inputs: each original input, immediately followed by its shadow type when
/// its activity is DUP_ARG or DUP_NONEED; then one shadow type for every
/// result whose `ReturnActivity` is OUT_DIFF; then `additionalArg` when it is
/// non-null.
///
/// `returnPrimals`, `returnShadows` and `ReturnActivity` are iterated in
/// lockstep with `FTy`'s results, and `ArgActivity` with `FTy`'s inputs.
///
/// The flag lists are `const std::vector<bool> &` to match the definition and
/// the `CloneFunctionWithReturns` declaration: `std::vector<bool>` cannot be
/// converted to `llvm::ArrayRef<bool>`, so declaring them as `ArrayRef<bool>`
/// would declare a different, undefined overload.
mlir::FunctionType
getFunctionTypeForClone(mlir::FunctionType FTy, DerivativeMode mode,
                        unsigned width, mlir::Type additionalArg,
                        const std::vector<bool> &returnPrimals,
                        const std::vector<bool> &returnShadows,
                        llvm::ArrayRef<DIFFE_TYPE> ReturnActivity,
                        llvm::ArrayRef<DIFFE_TYPE> ArgActivity);

void cloneInto(Region *src, Region *dest, Region::iterator destPos,
IRMapping &mapper, std::map<Operation *, Operation *> &opMap);
Expand All @@ -41,10 +44,11 @@ Operation *clone(Operation *src, IRMapping &mapper,

FunctionOpInterface CloneFunctionWithReturns(
DerivativeMode mode, unsigned width, FunctionOpInterface F,
IRMapping &ptrInputs, ArrayRef<DIFFE_TYPE> constant_args,
IRMapping &ptrInputs, ArrayRef<DIFFE_TYPE> ArgActivity,
SmallPtrSetImpl<mlir::Value> &constants,
SmallPtrSetImpl<mlir::Value> &nonconstants,
SmallPtrSetImpl<mlir::Value> &returnvals, ReturnType returnValue,
DIFFE_TYPE ReturnType, Twine name, IRMapping &VMap,
std::map<Operation *, Operation *> &OpMap, bool diffeReturnArg,
mlir::Type additionalArg);
SmallPtrSetImpl<mlir::Value> &returnvals,
const std::vector<bool> &returnPrimals,
const std::vector<bool> &returnShadows, ArrayRef<DIFFE_TYPE> ReturnActivity,
Twine name, IRMapping &VMap, std::map<Operation *, Operation *> &OpMap,
mlir::Type additionalArg);
Loading

0 comments on commit 954a209

Please sign in to comment.