Skip to content

Commit

Permalink
Handle extractions in tablegen
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Apr 1, 2024
1 parent 381ceed commit 611ca37
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 30 deletions.
10 changes: 10 additions & 0 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1789,6 +1789,16 @@ class EnzymeBase {
csts.push_back(ConstantFP::get(e, 1.0));
}
args.push_back(ConstantStruct::get(ST, csts));
} else if (auto AT = dyn_cast<ArrayType>(fn->getReturnType())) {
SmallVector<Constant *, 2> csts(
AT->getNumElements(), ConstantFP::get(AT->getElementType(), 1.0));
args.push_back(ConstantArray::get(AT, csts));
} else {
auto RT = fn->getReturnType();
EmitFailure("EnzymeCallingError", CI->getDebugLoc(), CI,
"Differential return required for call ", *CI,
" but one of type ", *RT, " could not be auto deduced");
return false;
}
}

Expand Down
15 changes: 15 additions & 0 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4907,6 +4907,21 @@ Type *GradientUtils::getShadowType(Type *ty) {
return getShadowType(ty, width);
}

Type *GradientUtils::extractMeta(Type *T, ArrayRef<unsigned> off) {
for (auto idx : off) {
if (auto AT = dyn_cast<ArrayType>(T)) {
T = AT->getElementType();
continue;
}
if (auto ST = dyn_cast<StructType>(T)) {
T = ST->getElementType(idx);
continue;
}
assert(false && "could not sub index into type");
}
return T;
}

Value *GradientUtils::extractMeta(IRBuilder<> &Builder, Value *Agg,
unsigned off, const Twine &name) {
return extractMeta(Builder, Agg, ArrayRef<unsigned>({off}), name);
Expand Down
3 changes: 3 additions & 0 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,9 @@ class GradientUtils : public CacheUtility {
const llvm::Twine &name = "",
bool fallback = true);

//! Helper routine to get the type of an extraction
static llvm::Type *extractMeta(llvm::Type *T, llvm::ArrayRef<unsigned> off);

static llvm::Value *recursiveFAdd(llvm::IRBuilder<> &B, llvm::Value *lhs,
llvm::Value *rhs,
llvm::ArrayRef<unsigned> lhs_off = {},
Expand Down
20 changes: 20 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/cmplx_sqrt.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s

declare [2 x double] @cmplx_sqrt([2 x double] %x)

define [2 x double] @tester([2 x double] %x) {
entry:
%y = call [2 x double] @cmplx_sqrt([2 x double] %x)
ret [2 x double] %y
}

define [2 x double] @test_derivative([2 x double] %x) {
entry:
%0 = tail call [2 x double] (...) @__enzyme_autodiff([2 x double] ([2 x double])* nonnull @tester, metadata !"enzyme_active_return", [2 x double] %x)
ret [2 x double] %0
}

declare [2 x double] @__enzyme_autodiff(...)

; CHECK: define internal { [2 x double] } @diffetester([2 x double] %x, [2 x double] %differeturn)
137 changes: 107 additions & 30 deletions enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,22 +191,27 @@ void initializeNames(const Twine &curIndent, raw_ostream &os, Init *resultTree,
struct VariableSetting {
StringMap<std::string> nameToOrdinal;
StringMap<bool> isVector;
StringMap<std::vector<int>> extractions;

std::pair<std::string, bool> lookup(StringRef name, Record *pattern,
Init *resultRoot) {
std::tuple<std::string, bool, std::vector<int>>
lookup(StringRef name, Record *pattern, Init *resultRoot) {
auto ord = nameToOrdinal.find(name);
if (ord == nameToOrdinal.end())
PrintFatalError(pattern->getLoc(), Twine("unknown named operand '") +
name + "'" +
resultRoot->getAsString());
auto iv = isVector.find(name);
assert(iv != isVector.end());
return std::make_pair(ord->getValue(), iv->getValue());

auto ext = extractions.find(name);
assert(ext != extractions.end());
return std::make_tuple(ord->getValue(), iv->getValue(), ext->getValue());
}

void insert(StringRef name, StringRef value, bool vec) {
void insert(StringRef name, StringRef value, bool vec, std::vector<int> ext) {
nameToOrdinal[name] = value;
isVector[name] = vec;
extractions[name] = ext;
}
};

Expand All @@ -231,11 +236,17 @@ SmallVector<bool, 1> prepareArgs(const Twine &curIndent, raw_ostream &os,
os << curIndent << "auto " << argName << "_" << idx << " = ";
idx++;
if (isa<UnsetInit>(args) && names) {
auto [ord, vecValue] =
auto [ord, vecValue, ext] =
nameToOrdinal.lookup(names->getValue(), pattern, resultRoot);
if (!vecValue && !startsWith(ord, "local")) {

if (ext.size()) {
os << "gutils->extractMeta(" << builder << ", ";
}

if (lookup && intrinsic != MLIRDerivatives)
os << "lookup(";

if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives))
os << "gutils->getNewFromOriginal(";
}
Expand All @@ -250,8 +261,19 @@ SmallVector<bool, 1> prepareArgs(const Twine &curIndent, raw_ostream &os,
if (!vecValue && !startsWith(ord, "local")) {
if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives))
os << ")";

if (lookup && intrinsic != MLIRDerivatives)
os << ", " << builder << ")";

if (ext.size()) {
os << ", ArrayRef<unsigned>({";
for (unsigned i = 0; i < ext.size(); i++) {
if (i != 0)
os << ", ";
os << std::to_string(ext[i]);
}
os << "}))";
}
}
os << ";\n";
vectorValued.push_back(vecValue);
Expand All @@ -263,7 +285,7 @@ SmallVector<bool, 1> prepareArgs(const Twine &curIndent, raw_ostream &os,
os << ";\n";
if (names) {
auto name = names->getAsUnquotedString();
nameToOrdinal.insert(name, "local_" + name, vectorValued.back());
nameToOrdinal.insert(name, "local_" + name, vectorValued.back(), {});
os << curIndent << "local_" << name << " = " << argName << "_"
<< (idx - 1) << ";\n";
}
Expand Down Expand Up @@ -347,8 +369,10 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,

if (resultRoot->getArgName(0)) {
auto name = resultRoot->getArgName(0)->getAsUnquotedString();
auto [ord, isVec] = nameToOrdinal.lookup(name, pattern, resultRoot);
auto [ord, isVec, ext] =
nameToOrdinal.lookup(name, pattern, resultRoot);
assert(!isVec);
assert(ext.size() == 0);
os << ord;
} else
PrintFatalError(pattern->getLoc(),
Expand All @@ -368,8 +392,10 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,

if (isa<UnsetInit>(resultRoot->getArg(0)) && resultRoot->getArgName(0)) {
auto name = resultRoot->getArgName(0)->getAsUnquotedString();
auto [ord, isVec] = nameToOrdinal.lookup(name, pattern, resultRoot);
auto [ord, isVec, ext] =
nameToOrdinal.lookup(name, pattern, resultRoot);
assert(!isVec);
assert(!ext.size());
os << ord;
} else
handle(curIndent + INDENT, argPattern + "_vs", os, pattern,
Expand Down Expand Up @@ -399,8 +425,11 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,

if (isa<UnsetInit>(resultRoot->getArg(0)) && resultRoot->getArgName(0)) {
auto name = resultRoot->getArgName(0)->getAsUnquotedString();
auto [ord, isVec] = nameToOrdinal.lookup(name, pattern, resultRoot);
auto [ord, isVec, ext] =
nameToOrdinal.lookup(name, pattern, resultRoot);
assert(!isVec);
// This assumes that activity of inner extractions are the same as
// outer. assert(!ext.size());
os << ord;
} else
assert("Requires name for arg");
Expand All @@ -413,8 +442,10 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,
if (isa<UnsetInit>(resultRoot->getArg(i)) &&
resultRoot->getArgName(i)) {
auto name = resultRoot->getArgName(i)->getAsUnquotedString();
auto [ord, isVec] = nameToOrdinal.lookup(name, pattern, resultRoot);
auto [ord, isVec, ext] =
nameToOrdinal.lookup(name, pattern, resultRoot);
vector = isVec;
assert(!ext.size());
os << ord;
} else
vector = handle(curIndent + INDENT + INDENT,
Expand Down Expand Up @@ -466,8 +497,10 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,
ord = "op->getResult(0)";
} else {
auto name = resultRoot->getArgName(0)->getAsUnquotedString();
auto [ord1, isVec] = nameToOrdinal.lookup(name, pattern, resultTree);
auto [ord1, isVec, ext] =
nameToOrdinal.lookup(name, pattern, resultTree);
assert(!isVec);
assert(!ext.size());
ord = ord1;
}
os << ord << ".getType(), ";
Expand All @@ -485,30 +518,57 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,
os << "ConstantFP::get(";
if (resultRoot->getArgName(0)) {
auto name = resultRoot->getArgName(0)->getAsUnquotedString();
auto [ord, isVec] = nameToOrdinal.lookup(name, pattern, resultTree);
auto [ord, isVec, ext] =
nameToOrdinal.lookup(name, pattern, resultTree);
assert(!isVec);
os << ord;
if (ext.size())
os << "gutils->extractMeta(";
os << ord << "->getType()";
if (ext.size()) {
os << ", ArrayRef<unsigned>({";
for (unsigned i = 0; i < ext.size(); i++) {
if (i != 0)
os << ", ";
os << std::to_string(ext[i]);
}
os << "}))";
}
} else
PrintFatalError(pattern->getLoc(),
Twine("unknown named operand in constantfp") +
resultTree->getAsString());
os << "->getType(), \"" << value->getValue() << "\")";
os << ", \"" << value->getValue() << "\")";
}
return false;
} else if (opName == "Zero" || Def->isSubClassOf("Zero")) {
if (resultRoot->getNumArgs() != 1)
PrintFatalError(pattern->getLoc(), "only single op Zero supported");
os << "Constant::getNullValue(";
std::vector<int> exto;
if (resultRoot->getArgName(0)) {
auto name = resultRoot->getArgName(0)->getAsUnquotedString();
auto [ord, isVec] = nameToOrdinal.lookup(name, pattern, resultTree);
auto [ord, isVec, ext] =
nameToOrdinal.lookup(name, pattern, resultTree);
assert(!isVec);
exto = std::move(ext);
if (exto.size())
os << "gutils->extractMeta(";
os << ord;
} else
PrintFatalError(pattern->getLoc(),
Twine("unknown named operand in constantfp") +
resultTree->getAsString());
os << "->getType())";
os << "->getType()";
if (exto.size()) {
os << ", ArrayRef<unsigned>({";
for (unsigned i = 0; i < exto.size(); i++) {
if (i != 0)
os << ", ";
os << std::to_string(exto[i]);
}
os << "}))";
}
os << ")";
return false;
} else if (opName == "ConstantCFP" || Def->isSubClassOf("ConstantCFP")) {
if (resultRoot->getNumArgs() != 1)
Expand All @@ -528,8 +588,10 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,
os << curIndent << INDENT << "auto ty = ";
if (resultRoot->getArgName(0)) {
auto name = resultRoot->getArgName(0)->getAsUnquotedString();
auto [ord, isVec] = nameToOrdinal.lookup(name, pattern, resultTree);
auto [ord, isVec, ext] =
nameToOrdinal.lookup(name, pattern, resultTree);
assert(!isVec);
assert(!ext.size());
os << ord;
} else
PrintFatalError(pattern->getLoc(),
Expand Down Expand Up @@ -582,8 +644,10 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,

if (resultRoot->getArgName(0)) {
auto name = resultRoot->getArgName(0)->getAsUnquotedString();
auto [ord, isVec] = nameToOrdinal.lookup(name, pattern, resultTree);
auto [ord, isVec, ext] =
nameToOrdinal.lookup(name, pattern, resultTree);
assert(!isVec);
assert(!ext.size());
os << ord;
} else
PrintFatalError(pattern->getLoc(),
Expand Down Expand Up @@ -622,8 +686,10 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,
os << "UndefValue::get(";
if (resultRoot->getArgName(0)) {
auto name = resultRoot->getArgName(0)->getAsUnquotedString();
auto [ord, isVec] = nameToOrdinal.lookup(name, pattern, resultTree);
auto [ord, isVec, ext] =
nameToOrdinal.lookup(name, pattern, resultTree);
assert(!isVec);
assert(!ext.size());
os << ord;
} else
PrintFatalError(pattern->getLoc(),
Expand All @@ -641,9 +707,24 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,

if (resultRoot->getArgName(0)) {
auto name = resultRoot->getArgName(0)->getAsUnquotedString();
auto [ord, isVec] = nameToOrdinal.lookup(name, pattern, resultTree);
auto [ord, isVec, ext] =
nameToOrdinal.lookup(name, pattern, resultTree);
assert(!isVec);

if (ext.size())
os << "gutils->extractMeta(" << builder << ",";
os << ord;

if (ext.size()) {
os << ", ArrayRef<unsigned>({";
for (unsigned i = 0; i < ext.size(); i++) {
if (i != 0)
os << ", ";
os << std::to_string(ext[i]);
}
os << "}))";
}

} else
PrintFatalError(pattern->getLoc(),
Twine("unknown named operand in shadow") +
Expand Down Expand Up @@ -821,7 +902,7 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,
op = ("local_" + ptree->getArgNameStr(i)).str();
}
nnameToOrdinal.insert(ptree->getArgNameStr(i), op,
vectorValued[next[0]]);
vectorValued[next[0]], {});
}
i++;
}
Expand Down Expand Up @@ -1385,18 +1466,14 @@ static VariableSetting parseVariables(DagInit *tree, ActionType intrinsic,
op = (origName + ".getOperand(" + Twine(next[0]) + ")").str();
else
op = (origName + "->getOperand(" + Twine(next[0]) + ")").str();
std::vector<int> extractions;
if (prev.size() > 0) {
op = "gutils->extractMeta(Builder2, " + op +
", ArrayRef<unsigned>({";
bool first = true;
for (unsigned i = 1; i < next.size(); i++) {
if (!first)
op += ", ";
op += std::to_string(next[i]);
extractions.push_back(next[i]);
}
op += "}))";
}
nameToOrdinal.insert(ptree->getArgNameStr(i), op, false);
nameToOrdinal.insert(ptree->getArgNameStr(i), op, false,
extractions);
}
i++;
}
Expand All @@ -1406,7 +1483,7 @@ static VariableSetting parseVariables(DagInit *tree, ActionType intrinsic,

if (tree->getNameStr().size())
nameToOrdinal.insert(tree->getNameStr(),
(Twine("(&") + origName + ")").str(), false);
(Twine("(&") + origName + ")").str(), false, {});
return nameToOrdinal;
}

Expand Down

0 comments on commit 611ca37

Please sign in to comment.