From eddc3092acfcd5e9fa85cdb4d7327fb7ba116804 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 9 Sep 2024 20:02:25 -0400 Subject: [PATCH] Integrate LLVM at llvm/llvm-project@ede40da1f8c1 (#2073) * Integrate LLVM at llvm/llvm-project@ede40da1f8c1 Updates LLVM usage to match [ede40da1f8c1](https://github.com/llvm/llvm-project/commit/ede40da1f8c1) PiperOrigin-RevId: 671944195 * Update BUILD * Update BUILD --------- Co-authored-by: Jorge Gorbe Moya --- enzyme/BUILD | 15 ++++--- enzyme/test/BUILD | 1 - enzyme/tools/enzyme-tblgen/blas-tblgen.cpp | 4 +- enzyme/tools/enzyme-tblgen/blasDeclUpdater.h | 2 +- enzyme/tools/enzyme-tblgen/datastructures.cpp | 2 +- enzyme/tools/enzyme-tblgen/datastructures.h | 4 +- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 41 +++++++++---------- 7 files changed, 33 insertions(+), 36 deletions(-) diff --git a/enzyme/BUILD b/enzyme/BUILD index 03503c0e01f5..9c70d7f54762 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -1,6 +1,6 @@ -load("@bazel_skylib//rules:expand_template.bzl", "expand_template") load("@llvm-project//llvm:tblgen.bzl", "gentbl") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") licenses(["notice"]) @@ -22,13 +22,13 @@ cc_library( cc_binary( name = "enzyme-tblgen", srcs = glob(["tools/enzyme-tblgen/*.cpp"]), + visibility = ["//visibility:public"], deps = [ ":enzyme-tblgen-hdrs", "@llvm-project//llvm:Support", "@llvm-project//llvm:TableGen", "@llvm-project//llvm:config", ], - visibility = ["//visibility:public"], ) gentbl( @@ -144,12 +144,12 @@ cc_library( data = ["@llvm-project//clang:builtin_headers_gen"], visibility = ["//visibility:public"], deps = [ - ":bundled-includes", ":binop-derivatives", ":blas-attributor", ":blas-derivatives", ":blas-diffuseanalysis", ":blas-typeanalysis", + ":bundled-includes", ":call-derivatives", ":inst-derivatives", ":intr-derivatives", @@ -187,6 +187,7 @@ expand_template( substitutions = {"@TOOL_NAME@": "clang"}, template = "@llvm-project//llvm:cmake/modules/llvm-driver-template.cpp.in", ) + cc_binary( name = "enzyme-clang", srcs = ["enzyme-clang-driver.cpp"], @@ -206,10 +207,9 @@ genrule( name = "bundled-includes", srcs = glob(["include/**"]) + ["scripts/bundle-includes.sh"], outs = ["bundled_includes.h"], - cmd = "$(location :scripts/bundle-includes.sh) $(location :include/enzyme/enzyme) $@" + cmd = "$(location :scripts/bundle-includes.sh) $(location :include/enzyme/enzyme) $@", ) - genrule( name = "gen_enzyme-clang++", srcs = [":enzyme-clang"], @@ -567,8 +567,8 @@ cc_library( ":arith-derivatives", ":cf-derivatives", ":complex-derivatives", - ":llvm-derivatives", ":func-derivatives", + ":llvm-derivatives", ":math-derivatives", ":memref-derivatives", ":nvvm-derivatives", @@ -612,8 +612,8 @@ cc_library( "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", "@llvm-project//mlir:ViewLikeInterface", ], ) @@ -649,4 +649,3 @@ cc_binary( ) exports_files(["run_lit.sh"]) - diff --git a/enzyme/test/BUILD b/enzyme/test/BUILD index 47143966810d..5b1d2998508d 100644 --- a/enzyme/test/BUILD +++ b/enzyme/test/BUILD @@ -29,4 +29,3 @@ exports_files( ["lit.cfg.py"], visibility = [":__subpackages__"], ) - diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index 043fe845c94e..ab8f5c5668e1 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -80,7 +80,7 @@ bool hasAdjoint(const TGPattern &pattern, Init *resultTree, StringRef argName) { } static void checkBlasCallsInDag(const RecordKeeper &RK, - ArrayRef blasPatterns, + ArrayRef blasPatterns, StringRef blasName, const DagInit *toSearch) { // For nested FAdd, ... rules which don't directly call a blass fnc @@ -95,7 +95,7 @@ static void checkBlasCallsInDag(const RecordKeeper &RK, /// blas function will use the correct amount of args /// Later we might check for "types" too. static void checkBlasCalls(const RecordKeeper &RK, - ArrayRef blasPatterns) { + ArrayRef blasPatterns) { for (auto &&pattern : blasPatterns) { ListInit *argOps = pattern->getValueAsListInit("ArgDerivatives"); // for each possibly active parameter diff --git a/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h b/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h index ca71afa228b1..ae41b2b95ea7 100644 --- a/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h +++ b/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h @@ -225,7 +225,7 @@ void emitBlasDeclUpdater(const RecordKeeper &RK, raw_ostream &os) { os << " }\n"; { const auto &patterns = RK.getAllDerivedDefinitions("CallPattern"); - for (Record *pattern : patterns) { + for (const Record *pattern : patterns) { DagInit *tree = pattern->getValueAsDag("PatternToMatch"); os << " if (("; bool prev = false; diff --git a/enzyme/tools/enzyme-tblgen/datastructures.cpp b/enzyme/tools/enzyme-tblgen/datastructures.cpp index 1ef33abeea4d..6666e8a89324 100644 --- a/enzyme/tools/enzyme-tblgen/datastructures.cpp +++ b/enzyme/tools/enzyme-tblgen/datastructures.cpp @@ -429,7 +429,7 @@ void fillArgUserMap(ArrayRef rules, ArrayRef nameVec, ArrayRef TGPattern::getLoc() const { return record->getLoc(); } -TGPattern::TGPattern(Record *r) +TGPattern::TGPattern(const Record *r) : record(r), blasName(r->getNameInitAsString()) { fillArgs(r, args, argNameToPos); fillArgTypes(r, argTypes); diff --git a/enzyme/tools/enzyme-tblgen/datastructures.h b/enzyme/tools/enzyme-tblgen/datastructures.h index eb5c73d2d6dc..f1be13fd2fe0 100644 --- a/enzyme/tools/enzyme-tblgen/datastructures.h +++ b/enzyme/tools/enzyme-tblgen/datastructures.h @@ -94,7 +94,7 @@ void fillArgUserMap(ArrayRef rules, ArrayRef nameVec, /// A single Blas function, including replacement rules. E.g. scal, axpy, ... class TGPattern { private: - Record *record; + const Record *record; std::string blasName; bool BLASLevel2or3; @@ -123,7 +123,7 @@ class TGPattern { DenseMap> relatedLengths; public: - TGPattern(Record *r); + TGPattern(const Record *r); SmallVector getRelatedLengthArgs(size_t arg, bool hideuplo = false) const; bool isBLASLevel2or3() const; diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 6bed9401cf7c..2e5ee993f988 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -196,7 +196,7 @@ struct VariableSetting { StringMap> extractions; std::tuple> - lookup(StringRef name, Record *pattern, Init *resultRoot) { + lookup(StringRef name, const Record *pattern, Init *resultRoot) { auto ord = nameToOrdinal.find(name); if (ord == nameToOrdinal.end()) PrintFatalError(pattern->getLoc(), Twine("unknown named operand '") + @@ -219,13 +219,13 @@ struct VariableSetting { #define INDENT " " bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, - Record *pattern, Init *resultTree, StringRef builder, + const Record *pattern, Init *resultTree, StringRef builder, VariableSetting &nameToOrdinal, bool lookup, ArrayRef retidx, StringRef origName, bool newFromOriginal, ActionType intrinsic); SmallVector prepareArgs(const Twine &curIndent, raw_ostream &os, - const Twine &argName, Record *pattern, + const Twine &argName, const Record *pattern, DagInit *resultRoot, StringRef builder, VariableSetting &nameToOrdinal, bool lookup, ArrayRef retidx, StringRef origName, @@ -300,7 +300,7 @@ SmallVector prepareArgs(const Twine &curIndent, raw_ostream &os, // Returns whether value generated is a vector value or not. bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, - Record *pattern, Init *resultTree, StringRef builder, + const Record *pattern, Init *resultTree, StringRef builder, VariableSetting &nameToOrdinal, bool lookup, ArrayRef retidx, StringRef origName, bool newFromOriginal, ActionType intrinsic) { @@ -1405,10 +1405,9 @@ void printDiffUse( } } -static void emitMLIRReverse(raw_ostream &os, Record *pattern, DagInit *tree, - ActionType intrinsic, StringRef origName, - ListInit *argOps) { - +static void emitMLIRReverse(raw_ostream &os, const Record *pattern, + DagInit *tree, ActionType intrinsic, + StringRef origName, ListInit *argOps) { auto opName = pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect"); os << "struct " << opName << "RevDerivative : \n"; @@ -1522,9 +1521,9 @@ static VariableSetting parseVariables(DagInit *tree, ActionType intrinsic, return nameToOrdinal; } -static void emitReverseCommon(raw_ostream &os, Record *pattern, DagInit *tree, - ActionType intrinsic, StringRef origName, - ListInit *argOps) { +static void emitReverseCommon(raw_ostream &os, const Record *pattern, + DagInit *tree, ActionType intrinsic, + StringRef origName, ListInit *argOps) { auto nameToOrdinal = parseVariables(tree, intrinsic, origName); bool seen = false; @@ -1746,7 +1745,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } const auto &patterns = recordKeeper.getAllDerivedDefinitions(patternNames); - for (Record *pattern : patterns) { + for (const Record *pattern : patterns) { DagInit *tree = pattern->getValueAsDag("PatternToMatch"); DagInit *duals = pattern->getValueAsDag("ArgDuals"); @@ -2497,7 +2496,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, recordKeeper.getAllDerivedDefinitions("AllocationOp"); os << "void registerInterfaces(MLIRContext* context) {\n"; - for (Record *pattern : patterns) { + for (const Record *pattern : patterns) { auto opName = pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect"); os << " " << dialect << "::" << opName << "::attachInterface<" << opName @@ -2505,13 +2504,13 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " " << dialect << "::" << opName << "::attachInterface<" << opName << "RevDerivative>(*context);\n"; } - for (Record *pattern : actpatterns) { + for (const Record *pattern : actpatterns) { auto opName = pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect"); os << " " << dialect << "::" << opName << "::attachInterface<" << opName << "Activity>(*context);\n"; } - for (Record *pattern : cfpatterns) { + for (const Record *pattern : cfpatterns) { auto opName = pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect"); os << " " << dialect << "::" << opName << "::attachInterface<" << opName @@ -2519,7 +2518,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " registerAutoDiffUsingControlFlowInterface<" << dialect << "::" << opName << ">(*context);\n"; } - for (Record *pattern : mempatterns) { + for (const Record *pattern : mempatterns) { auto opName = pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect"); os << " " << dialect << "::" << opName << "::attachInterface<" << opName @@ -2535,25 +2534,25 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, << opName << "RevDerivative>(*context);\n"; } } - for (Record *pattern : brpatterns) { + for (const Record *pattern : brpatterns) { auto opName = pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect"); os << " registerAutoDiffUsingBranchInterface<" << dialect << "::" << opName << ">(*context);\n"; } - for (Record *pattern : regtpatterns) { + for (const Record *pattern : regtpatterns) { auto opName = pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect"); os << " registerAutoDiffUsingRegionTerminatorInterface<" << dialect << "::" << opName << ">(*context);\n"; } - for (Record *pattern : retpatterns) { + for (const Record *pattern : retpatterns) { auto opName = pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect"); os << " registerAutoDiffUsingReturnInterface<" << dialect << "::" << opName << ">(*context);\n"; } - for (Record *pattern : allocpatterns) { + for (const Record *pattern : allocpatterns) { auto opName = pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect"); os << " registerAutoDiffUsingAllocationInterface<" << dialect @@ -2589,7 +2588,7 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os, } const auto &patterns = recordKeeper.getAllDerivedDefinitions(patternNames); - for (Record *pattern : patterns) { + for (const Record *pattern : patterns) { DagInit *tree = pattern->getValueAsDag("PatternToMatch"); // Emit RewritePattern for Pattern.