Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
* Integrate LLVM at llvm/llvm-project@ede40da1f8c1

Updates LLVM usage to match
[ede40da1f8c1](llvm/llvm-project@ede40da1f8c1)

PiperOrigin-RevId: 671944195

* Update BUILD

* Update BUILD

---------

Co-authored-by: Jorge Gorbe Moya <jgorbe@google.com>
  • Loading branch information
wsmoses and slackito authored Sep 10, 2024
1 parent 4b451bd commit eddc309
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 36 deletions.
15 changes: 7 additions & 8 deletions enzyme/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
load("@llvm-project//llvm:tblgen.bzl", "gentbl")
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")

licenses(["notice"])

Expand All @@ -22,13 +22,13 @@ cc_library(
cc_binary(
name = "enzyme-tblgen",
srcs = glob(["tools/enzyme-tblgen/*.cpp"]),
visibility = ["//visibility:public"],
deps = [
":enzyme-tblgen-hdrs",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//llvm:config",
],
visibility = ["//visibility:public"],
)

gentbl(
Expand Down Expand Up @@ -144,12 +144,12 @@ cc_library(
data = ["@llvm-project//clang:builtin_headers_gen"],
visibility = ["//visibility:public"],
deps = [
":bundled-includes",
":binop-derivatives",
":blas-attributor",
":blas-derivatives",
":blas-diffuseanalysis",
":blas-typeanalysis",
":bundled-includes",
":call-derivatives",
":inst-derivatives",
":intr-derivatives",
Expand Down Expand Up @@ -187,6 +187,7 @@ expand_template(
substitutions = {"@TOOL_NAME@": "clang"},
template = "@llvm-project//llvm:cmake/modules/llvm-driver-template.cpp.in",
)

cc_binary(
name = "enzyme-clang",
srcs = ["enzyme-clang-driver.cpp"],
Expand All @@ -206,10 +207,9 @@ genrule(
name = "bundled-includes",
srcs = glob(["include/**"]) + ["scripts/bundle-includes.sh"],
outs = ["bundled_includes.h"],
cmd = "$(location :scripts/bundle-includes.sh) $(location :include/enzyme/enzyme) $@"
cmd = "$(location :scripts/bundle-includes.sh) $(location :include/enzyme/enzyme) $@",
)


genrule(
name = "gen_enzyme-clang++",
srcs = [":enzyme-clang"],
Expand Down Expand Up @@ -567,8 +567,8 @@ cc_library(
":arith-derivatives",
":cf-derivatives",
":complex-derivatives",
":llvm-derivatives",
":func-derivatives",
":llvm-derivatives",
":math-derivatives",
":memref-derivatives",
":nvvm-derivatives",
Expand Down Expand Up @@ -612,8 +612,8 @@ cc_library(
"@llvm-project//mlir:ShapeDialect",
"@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:ViewLikeInterface",
],
)
Expand Down Expand Up @@ -649,4 +649,3 @@ cc_binary(
)

exports_files(["run_lit.sh"])

1 change: 0 additions & 1 deletion enzyme/test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,3 @@ exports_files(
["lit.cfg.py"],
visibility = [":__subpackages__"],
)

4 changes: 2 additions & 2 deletions enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ bool hasAdjoint(const TGPattern &pattern, Init *resultTree, StringRef argName) {
}

static void checkBlasCallsInDag(const RecordKeeper &RK,
ArrayRef<Record *> blasPatterns,
ArrayRef<const Record *> blasPatterns,
StringRef blasName, const DagInit *toSearch) {

// For nested FAdd, ... rules which don't directly call a blass fnc
Expand All @@ -95,7 +95,7 @@ static void checkBlasCallsInDag(const RecordKeeper &RK,
/// blas function will use the correct amount of args
/// Later we might check for "types" too.
static void checkBlasCalls(const RecordKeeper &RK,
ArrayRef<Record *> blasPatterns) {
ArrayRef<const Record *> blasPatterns) {
for (auto &&pattern : blasPatterns) {
ListInit *argOps = pattern->getValueAsListInit("ArgDerivatives");
// for each possibly active parameter
Expand Down
2 changes: 1 addition & 1 deletion enzyme/tools/enzyme-tblgen/blasDeclUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ void emitBlasDeclUpdater(const RecordKeeper &RK, raw_ostream &os) {
os << " }\n";
{
const auto &patterns = RK.getAllDerivedDefinitions("CallPattern");
for (Record *pattern : patterns) {
for (const Record *pattern : patterns) {
DagInit *tree = pattern->getValueAsDag("PatternToMatch");
os << " if ((";
bool prev = false;
Expand Down
2 changes: 1 addition & 1 deletion enzyme/tools/enzyme-tblgen/datastructures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ void fillArgUserMap(ArrayRef<Rule> rules, ArrayRef<std::string> nameVec,

ArrayRef<SMLoc> TGPattern::getLoc() const { return record->getLoc(); }

TGPattern::TGPattern(Record *r)
TGPattern::TGPattern(const Record *r)
: record(r), blasName(r->getNameInitAsString()) {
fillArgs(r, args, argNameToPos);
fillArgTypes(r, argTypes);
Expand Down
4 changes: 2 additions & 2 deletions enzyme/tools/enzyme-tblgen/datastructures.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ void fillArgUserMap(ArrayRef<Rule> rules, ArrayRef<std::string> nameVec,
/// A single Blas function, including replacement rules. E.g. scal, axpy, ...
class TGPattern {
private:
Record *record;
const Record *record;
std::string blasName;
bool BLASLevel2or3;

Expand Down Expand Up @@ -123,7 +123,7 @@ class TGPattern {
DenseMap<size_t, SmallVector<size_t, 3>> relatedLengths;

public:
TGPattern(Record *r);
TGPattern(const Record *r);
SmallVector<size_t, 3> getRelatedLengthArgs(size_t arg,
bool hideuplo = false) const;
bool isBLASLevel2or3() const;
Expand Down
41 changes: 20 additions & 21 deletions enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ struct VariableSetting {
StringMap<std::vector<int>> extractions;

std::tuple<std::string, bool, std::vector<int>>
lookup(StringRef name, Record *pattern, Init *resultRoot) {
lookup(StringRef name, const Record *pattern, Init *resultRoot) {
auto ord = nameToOrdinal.find(name);
if (ord == nameToOrdinal.end())
PrintFatalError(pattern->getLoc(), Twine("unknown named operand '") +
Expand All @@ -219,13 +219,13 @@ struct VariableSetting {

#define INDENT " "
bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,
Record *pattern, Init *resultTree, StringRef builder,
const Record *pattern, Init *resultTree, StringRef builder,
VariableSetting &nameToOrdinal, bool lookup,
ArrayRef<unsigned> retidx, StringRef origName, bool newFromOriginal,
ActionType intrinsic);

SmallVector<bool, 1> prepareArgs(const Twine &curIndent, raw_ostream &os,
const Twine &argName, Record *pattern,
const Twine &argName, const Record *pattern,
DagInit *resultRoot, StringRef builder,
VariableSetting &nameToOrdinal, bool lookup,
ArrayRef<unsigned> retidx, StringRef origName,
Expand Down Expand Up @@ -300,7 +300,7 @@ SmallVector<bool, 1> prepareArgs(const Twine &curIndent, raw_ostream &os,

// Returns whether value generated is a vector value or not.
bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,
Record *pattern, Init *resultTree, StringRef builder,
const Record *pattern, Init *resultTree, StringRef builder,
VariableSetting &nameToOrdinal, bool lookup,
ArrayRef<unsigned> retidx, StringRef origName, bool newFromOriginal,
ActionType intrinsic) {
Expand Down Expand Up @@ -1405,10 +1405,9 @@ void printDiffUse(
}
}

static void emitMLIRReverse(raw_ostream &os, Record *pattern, DagInit *tree,
ActionType intrinsic, StringRef origName,
ListInit *argOps) {

static void emitMLIRReverse(raw_ostream &os, const Record *pattern,
DagInit *tree, ActionType intrinsic,
StringRef origName, ListInit *argOps) {
auto opName = pattern->getValueAsString("opName");
auto dialect = pattern->getValueAsString("dialect");
os << "struct " << opName << "RevDerivative : \n";
Expand Down Expand Up @@ -1522,9 +1521,9 @@ static VariableSetting parseVariables(DagInit *tree, ActionType intrinsic,
return nameToOrdinal;
}

static void emitReverseCommon(raw_ostream &os, Record *pattern, DagInit *tree,
ActionType intrinsic, StringRef origName,
ListInit *argOps) {
static void emitReverseCommon(raw_ostream &os, const Record *pattern,
DagInit *tree, ActionType intrinsic,
StringRef origName, ListInit *argOps) {
auto nameToOrdinal = parseVariables(tree, intrinsic, origName);

bool seen = false;
Expand Down Expand Up @@ -1746,7 +1745,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os,
}
const auto &patterns = recordKeeper.getAllDerivedDefinitions(patternNames);

for (Record *pattern : patterns) {
for (const Record *pattern : patterns) {
DagInit *tree = pattern->getValueAsDag("PatternToMatch");

DagInit *duals = pattern->getValueAsDag("ArgDuals");
Expand Down Expand Up @@ -2497,29 +2496,29 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os,
recordKeeper.getAllDerivedDefinitions("AllocationOp");

os << "void registerInterfaces(MLIRContext* context) {\n";
for (Record *pattern : patterns) {
for (const Record *pattern : patterns) {
auto opName = pattern->getValueAsString("opName");
auto dialect = pattern->getValueAsString("dialect");
os << " " << dialect << "::" << opName << "::attachInterface<" << opName
<< "FwdDerivative>(*context);\n";
os << " " << dialect << "::" << opName << "::attachInterface<" << opName
<< "RevDerivative>(*context);\n";
}
for (Record *pattern : actpatterns) {
for (const Record *pattern : actpatterns) {
auto opName = pattern->getValueAsString("opName");
auto dialect = pattern->getValueAsString("dialect");
os << " " << dialect << "::" << opName << "::attachInterface<" << opName
<< "Activity>(*context);\n";
}
for (Record *pattern : cfpatterns) {
for (const Record *pattern : cfpatterns) {
auto opName = pattern->getValueAsString("opName");
auto dialect = pattern->getValueAsString("dialect");
os << " " << dialect << "::" << opName << "::attachInterface<" << opName
<< "CF>(*context);\n";
os << " registerAutoDiffUsingControlFlowInterface<" << dialect
<< "::" << opName << ">(*context);\n";
}
for (Record *pattern : mempatterns) {
for (const Record *pattern : mempatterns) {
auto opName = pattern->getValueAsString("opName");
auto dialect = pattern->getValueAsString("dialect");
os << " " << dialect << "::" << opName << "::attachInterface<" << opName
Expand All @@ -2535,25 +2534,25 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os,
<< opName << "RevDerivative>(*context);\n";
}
}
for (Record *pattern : brpatterns) {
for (const Record *pattern : brpatterns) {
auto opName = pattern->getValueAsString("opName");
auto dialect = pattern->getValueAsString("dialect");
os << " registerAutoDiffUsingBranchInterface<" << dialect
<< "::" << opName << ">(*context);\n";
}
for (Record *pattern : regtpatterns) {
for (const Record *pattern : regtpatterns) {
auto opName = pattern->getValueAsString("opName");
auto dialect = pattern->getValueAsString("dialect");
os << " registerAutoDiffUsingRegionTerminatorInterface<" << dialect
<< "::" << opName << ">(*context);\n";
}
for (Record *pattern : retpatterns) {
for (const Record *pattern : retpatterns) {
auto opName = pattern->getValueAsString("opName");
auto dialect = pattern->getValueAsString("dialect");
os << " registerAutoDiffUsingReturnInterface<" << dialect
<< "::" << opName << ">(*context);\n";
}
for (Record *pattern : allocpatterns) {
for (const Record *pattern : allocpatterns) {
auto opName = pattern->getValueAsString("opName");
auto dialect = pattern->getValueAsString("dialect");
os << " registerAutoDiffUsingAllocationInterface<" << dialect
Expand Down Expand Up @@ -2589,7 +2588,7 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os,
}
const auto &patterns = recordKeeper.getAllDerivedDefinitions(patternNames);

for (Record *pattern : patterns) {
for (const Record *pattern : patterns) {
DagInit *tree = pattern->getValueAsDag("PatternToMatch");

// Emit RewritePattern for Pattern.
Expand Down

0 comments on commit eddc309

Please sign in to comment.