Skip to content

Commit

Permalink
TypeAnalysis permit passing info as fn parm attrs
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 20, 2024
1 parent 1ea2e0b commit ef18744
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 34 deletions.
53 changes: 29 additions & 24 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -4757,10 +4757,11 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
// call.getParamAttr(i, Attribute::StructRet).getValueAsType()));
#endif
}
if (call.getAttributes().hasParamAttr(i, "enzymejl_returnRoots")) {
structAttrs[args.size()].push_back(
call.getParamAttr(i, "enzymejl_returnRoots"));
}
for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype",
"enzymejl_parmtype_ref", "enzyme_type"})
if (call.getAttributes().hasParamAttr(i, attr)) {
structAttrs[args.size()].push_back(call.getParamAttr(i, attr));
}
for (auto ty : PrimalParamAttrsToPreserve)
if (call.getAttributes().hasParamAttr(i, ty)) {
auto attr = call.getAttributes().getParamAttr(i, ty);
Expand Down Expand Up @@ -4815,15 +4816,16 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
structAttrs[args.size()].push_back(attr);
}

if (call.getAttributes().hasParamAttr(i, "enzymejl_returnRoots")) {
if (gutils->getWidth() == 1) {
structAttrs[args.size()].push_back(
call.getParamAttr(i, "enzymejl_returnRoots"));
} else {
structAttrs[args.size()].push_back(
Attribute::get(call.getContext(), "enzyme_sret_v"));
for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype",
"enzymejl_parmtype_ref", "enzyme_type"})
if (call.getAttributes().hasParamAttr(i, attr)) {
if (gutils->getWidth() == 1) {
structAttrs[args.size()].push_back(call.getParamAttr(i, attr));
} else if (attr == "enzymejl_returnRoots") {
structAttrs[args.size()].push_back(
Attribute::get(call.getContext(), "enzymejl_returnRoots_v"));
}
}
}
if (call.paramHasAttr(i, Attribute::StructRet)) {
if (gutils->getWidth() == 1) {
structAttrs[args.size()].push_back(
Expand Down Expand Up @@ -5050,10 +5052,11 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
if (call.isByValArgument(i)) {
preByVal[pre_args.size()] = call.getParamByValType(i);
}
if (call.getAttributes().hasParamAttr(i, "enzymejl_returnRoots")) {
structAttrs[pre_args.size()].push_back(
call.getParamAttr(i, "enzymejl_returnRoots"));
}
for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype",
"enzymejl_parmtype_ref", "enzyme_type"})
if (call.getAttributes().hasParamAttr(i, attr)) {
structAttrs[pre_args.size()].push_back(call.getParamAttr(i, attr));
}
if (call.paramHasAttr(i, Attribute::StructRet)) {
structAttrs[pre_args.size()].push_back(
#if LLVM_VERSION_MAJOR >= 12
Expand Down Expand Up @@ -5146,15 +5149,17 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
structAttrs[pre_args.size()].push_back(attr);
}

if (call.getAttributes().hasParamAttr(i, "enzymejl_returnRoots")) {
if (gutils->getWidth() == 1) {
structAttrs[pre_args.size()].push_back(
call.getParamAttr(i, "enzymejl_returnRoots"));
} else {
structAttrs[pre_args.size()].push_back(
Attribute::get(call.getContext(), "enzymejl_returnRoots_v"));
for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype",
"enzymejl_parmtype_ref", "enzyme_type"})
if (call.getAttributes().hasParamAttr(i, attr)) {
if (gutils->getWidth() == 1) {
structAttrs[pre_args.size()].push_back(
call.getParamAttr(i, attr));
} else if (attr == "enzymejl_returnRoots") {
structAttrs[pre_args.size()].push_back(
Attribute::get(call.getContext(), "enzymejl_returnRoots_v"));
}
}
}
if (call.paramHasAttr(i, Attribute::StructRet)) {
if (gutils->getWidth() == 1) {
structAttrs[pre_args.size()].push_back(
Expand Down
18 changes: 17 additions & 1 deletion enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@

#if LLVM_VERSION_MAJOR >= 14
#define addAttribute addAttributeAtIndex
#define getAttribute getAttributeAtIndex
#define removeAttribute removeAttributeAtIndex
#endif

Expand Down Expand Up @@ -2784,7 +2785,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
NewF->addParamAttr(attrIndex, Attribute::NoAlias);
}
for (auto name : {"enzyme_sret", "enzyme_sret_v", "enzymejl_returnRoots",
"enzymejl_returnRoots_v"})
"enzymejl_returnRoots_v", "enzymejl_parmtype",
"enzymejl_parmtype_ref", "enzyme_type"})
if (nf->getAttributes().hasParamAttr(attrIndex, name)) {
NewF->addParamAttr(attrIndex,
nf->getAttributes().getParamAttr(attrIndex, name));
Expand All @@ -2796,6 +2798,20 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
++attrIndex;
}

for (auto attr : {"enzyme_ta_norecur"})
if (nf->getAttributes().hasAttribute(AttributeList::FunctionIndex, attr)) {
NewF->addFnAttr(
nf->getAttributes().getAttribute(AttributeList::FunctionIndex, attr));
}

for (auto attr :
{"enzyme_type", "enzymejl_parmtype", "enzymejl_parmtype_ref"})
if (nf->getAttributes().hasAttribute(AttributeList::ReturnIndex, attr)) {
NewF->addAttribute(
AttributeList::ReturnIndex,
nf->getAttributes().getAttribute(AttributeList::ReturnIndex, attr));
}

SmallVector<ReturnInst *, 4> Returns;
#if LLVM_VERSION_MAJOR >= 13
CloneFunctionInto(NewF, nf, VMap, CloneFunctionChangeType::LocalChangesOnly,
Expand Down
22 changes: 20 additions & 2 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2180,8 +2180,24 @@ Function *PreProcessCache::CloneFunctionWithReturns(
VMapO->getMDMap() = VMap.getMDMap();
}

for (auto attr : {"enzyme_ta_norecur"})
if (F->getAttributes().hasAttribute(AttributeList::FunctionIndex, attr)) {
NewF->addAttribute(
AttributeList::FunctionIndex,
F->getAttributes().getAttribute(AttributeList::FunctionIndex, attr));
}

for (auto attr :
{"enzyme_type", "enzymejl_parmtype", "enzymejl_parmtype_ref"})
if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex, attr)) {
NewF->addAttribute(
AttributeList::ReturnIndex,
F->getAttributes().getAttribute(AttributeList::ReturnIndex, attr));
}

bool hasPtrInput = false;
unsigned ii = 0, jj = 0;

for (auto i = F->arg_begin(), j = NewF->arg_begin(); i != F->arg_end();) {
if (F->hasParamAttribute(ii, Attribute::StructRet)) {
NewF->addParamAttr(jj, Attribute::get(F->getContext(), "enzyme_sret"));
Expand All @@ -2204,7 +2220,8 @@ Function *PreProcessCache::CloneFunctionWithReturns(
// Attribute::ElementType));
#endif
}
for (auto attr : {"enzymejl_parmtype", "enzymejl_parmtype_ref"})
for (auto attr :
{"enzymejl_parmtype", "enzymejl_parmtype_ref", "enzyme_type"})
if (F->getAttributes().hasParamAttr(ii, attr)) {
NewF->addParamAttr(jj, F->getAttributes().getParamAttr(ii, attr));
for (auto ty : PrimalParamAttrsToPreserve)
Expand Down Expand Up @@ -2250,7 +2267,8 @@ Function *PreProcessCache::CloneFunctionWithReturns(
NewF->addParamAttr(jj + 1, attr);
}

for (auto attr : {"enzymejl_parmtype", "enzymejl_parmtype_ref"})
for (auto attr :
{"enzymejl_parmtype", "enzymejl_parmtype_ref", "enzyme_type"})
if (F->getAttributes().hasParamAttr(ii, attr)) {
if (width == 1)
NewF->addParamAttr(jj + 1,
Expand Down
67 changes: 61 additions & 6 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@

#include <math.h>

#if LLVM_VERSION_MAJOR >= 14
#define getAttribute getAttributeAtIndex
#define hasAttribute hasAttributeAtIndex
#define addAttribute addAttributeAtIndex
#endif

using namespace llvm;

extern "C" {
Expand Down Expand Up @@ -1207,18 +1213,57 @@ void TypeAnalyzer::considerTBAA() {
}

if (CallBase *call = dyn_cast<CallBase>(&I)) {
#if LLVM_VERSION_MAJOR >= 14
size_t num_args = call->arg_size();
#else
size_t num_args = call->getNumArgOperands();
#endif

if (call->getAttributes().hasAttribute(AttributeList::ReturnIndex,
"enzyme_type")) {
auto attr = call->getAttributes().getAttribute(
AttributeList::ReturnIndex, "enzyme_type");
auto TT =
TypeTree::parse(attr.getValueAsString(), call->getContext());
updateAnalysis(call, TT, call);
}
for (size_t i = 0; i < num_args; i++) {
if (call->getAttributes().hasParamAttr(i, "enzyme_type")) {
auto attr = call->getAttributes().getParamAttr(i, "enzyme_type");
auto TT =
TypeTree::parse(attr.getValueAsString(), call->getContext());
updateAnalysis(call, TT, call);
}
}

Function *F = call->getCalledFunction();

if (F) {
if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex,
"enzyme_type")) {
auto attr = F->getAttributes().getAttribute(
AttributeList::ReturnIndex, "enzyme_type");
auto TT =
TypeTree::parse(attr.getValueAsString(), call->getContext());
updateAnalysis(call, TT, call);
}
size_t f_num_args = F->arg_size();
for (size_t i = 0; i < f_num_args; i++) {
if (F->getAttributes().hasParamAttr(i, "enzyme_type")) {
auto attr = F->getAttributes().getParamAttr(i, "enzyme_type");
auto TT =
TypeTree::parse(attr.getValueAsString(), call->getContext());
updateAnalysis(call, TT, call);
}
}
}

if (auto castinst = dyn_cast<ConstantExpr>(call->getCalledOperand())) {
if (castinst->isCast())
if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
F = fn;
}
}
#if LLVM_VERSION_MAJOR >= 14
size_t num_args = call->arg_size();
#else
size_t num_args = call->getNumArgOperands();
#endif
if (F && F->getName().contains("__enzyme_float")) {
assert(num_args == 1 || num_args == 2);
assert(call->getArgOperand(0)->getType()->isPointerTy());
Expand Down Expand Up @@ -4356,9 +4401,16 @@ void TypeAnalyzer::visitCallBase(CallBase &call) {
}
}

if (call.hasFnAttr("enzyme_ta_norecur"))
return;

Function *ci = getFunctionFromCall(&call);

if (ci) {
if (ci->getAttributes().hasAttribute(AttributeList::FunctionIndex,
"enzyme_ta_norecur"))
return;

StringRef funcName = getFuncNameFromCall(&call);

auto blasMetaData = extractBLAS(funcName);
Expand All @@ -4367,6 +4419,9 @@ void TypeAnalyzer::visitCallBase(CallBase &call) {
#include "BlasTA.inc"
}

// Manual TT specification is non-interprocedural and already handled once
// at the start.

// When compiling Enzyme against standard LLVM, and not Intel's
// modified version of LLVM, the intrinsic `llvm.intel.subscript` is
// not fully understood by LLVM. One of the results of this is that the
Expand Down Expand Up @@ -5596,7 +5651,7 @@ bool TypeAnalyzer::mustRemainInteger(Value *val, bool *returned) {
FnTypeInfo TypeAnalyzer::getCallInfo(CallBase &call, Function &fn) {
FnTypeInfo typeInfo(&fn);

int argnum = 0;
size_t argnum = 0;
for (auto &arg : fn.args()) {
if (argnum >= call.arg_size()) {
typeInfo.Arguments.insert(
Expand Down
87 changes: 86 additions & 1 deletion enzyme/Enzyme/TypeAnalysis/TypeTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,92 @@ class TypeTree : public std::enable_shared_from_this<TypeTree> {
}
}

static TypeTree parse(llvm::StringRef str, llvm::LLVMContext &ctx) {
using namespace llvm;
assert(str[0] == '{');
str = str.substr(1);

TypeTree Result;
while (true) {
while (str[0] == ' ')
str = str.substr(1);
if (str[0] == '}')
break;

assert(str[0] == '[');
str = str.substr(1);

std::vector<int> idxs;
while (true) {
while (str[0] == ' ')
str = str.substr(1);
if (str[0] == ']') {
str = str.substr(1);
break;
}

int idx;
bool failed = str.consumeInteger(10, idx);
(void)failed;
assert(!failed);
idxs.push_back(idx);

while (str[0] == ' ')
str = str.substr(1);

if (str[0] == ',') {
str = str.substr(1);
}
}

while (str[0] == ' ')
str = str.substr(1);

assert(str[0] == ':');
str = str.substr(1);

while (str[0] == ' ')
str = str.substr(1);

auto endval = str.find(',');
auto endval2 = str.find('}');
auto endval3 = str.find(' ');

if (endval2 != StringRef::npos &&
(endval == StringRef::npos || endval2 < endval))
endval = endval2;
if (endval3 != StringRef::npos &&
(endval == StringRef::npos || endval3 < endval))
endval = endval3;
assert(endval != StringRef::npos);

auto tystr = str.substr(0, endval);
str = str.substr(endval);

ConcreteType CT(tystr, ctx);
Result.mapping.emplace(idxs, CT);
if (Result.minIndices.size() < idxs.size()) {
for (size_t i = Result.minIndices.size(), end = idxs.size(); i < end;
++i) {
Result.minIndices.push_back(idxs[i]);
}
}
for (size_t i = 0, end = idxs.size(); i < end; ++i) {
if (idxs[i] < Result.minIndices[i])
Result.minIndices[i] = idxs[i];
}

while (str[0] == ' ')
str = str.substr(1);

if (str[0] == ',') {
str = str.substr(1);
}
}

return Result;
}

/// Utility helper to lookup the mapping
const ConcreteTypeMapType &getMapping() const { return mapping; }

Expand Down Expand Up @@ -1228,7 +1314,6 @@ class TypeTree : public std::enable_shared_from_this<TypeTree> {
if (found != RHS.mapping.end()) {
RightCT = found->second;
}

bool SubLegal = true;
changed |= CT.binopIn(SubLegal, RightCT, Op);
if (!SubLegal) {
Expand Down

0 comments on commit ef18744

Please sign in to comment.