From baced911d005f0c55def5f5cff55fcd587732a03 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Fri, 5 Jan 2024 17:26:06 +0100 Subject: [PATCH] tblgen (#19) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: mofeing Co-authored-by: jumerckx <31353884+jumerckx@users.noreply.github.com> Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> --- deps/Project.toml | 5 + deps/build.jl | 86 ++++++++ deps/tblgen/LICENSE | 202 +++++++++++++++++ deps/tblgen/jl-generators.cc | 399 ++++++++++++++++++++++++++++++++++ deps/tblgen/mlir-jl-tblgen.cc | 64 ++++++ src/Dialects.jl | 198 +++-------------- src/MLIR.jl | 6 +- src/dialects/.gitignore | 1 + 8 files changed, 796 insertions(+), 165 deletions(-) create mode 100644 deps/Project.toml create mode 100644 deps/build.jl create mode 100644 deps/tblgen/LICENSE create mode 100644 deps/tblgen/jl-generators.cc create mode 100644 deps/tblgen/mlir-jl-tblgen.cc create mode 100644 src/dialects/.gitignore diff --git a/deps/Project.toml b/deps/Project.toml new file mode 100644 index 00000000..02907581 --- /dev/null +++ b/deps/Project.toml @@ -0,0 +1,5 @@ +[deps] +LLVM_full_jll = "a3ccf953-465e-511d-b87f-60a6490c289d" + +[compat] +LLVM_full_jll = "15" diff --git a/deps/build.jl b/deps/build.jl new file mode 100644 index 00000000..6c0f742c --- /dev/null +++ b/deps/build.jl @@ -0,0 +1,86 @@ +using LLVM_full_jll + +println("Environment") +println("- llvm-config = $(LLVM_full_jll.get_llvm_config_path())") +println("- clang = $(LLVM_full_jll.get_clang_path())") + +CXXFLAGS = `$(llvm_config()) --cxxflags` |> readchomp |> split +LDFLAGS = `$(llvm_config()) --ldflags` |> readchomp |> split +println("- CXXFLAGS = $CXXFLAGS") +println("- LDFLAGS = $LDFLAGS") + +INCLUDE_PATH = joinpath(LLVM_full_jll.artifact_dir, "include") +DIALECTS_PATH = joinpath(INCLUDE_PATH, "mlir", "Dialect") +println("- INCLUDE_PATH = $INCLUDE_PATH") +println("- DIALECTS_PATH = $DIALECTS_PATH") + +# compile TableGen generator +println("Compiling TableGen generator...") +files = [joinpath(@__DIR__, "tblgen", "mlir-jl-tblgen.cc"), joinpath(@__DIR__, "tblgen", "jl-generators.cc")] +output = ["-o", "mlir-jl-tblgen"] +libs = ["-lLLVM", "-lMLIR", "-lMLIRTableGen", "-lLLVMTableGen"] + +extra = ["-rpath", joinpath(LLVM_full_jll.artifact_dir, "lib")] +if Base.Sys.isapple() + isysroot = strip(read(`xcrun --show-sdk-path`, String)) + append!(extra, [ + "-isysroot", + isysroot, + "-lc++", + ]) +elseif Base.Sys.islinux() + append!(extra, [ + "-lstdc++", + ]) +end +println("- extra flags = $extra") + +run(`$(clang()) $files $CXXFLAGS $LDFLAGS $extra $libs $output`) + +# generate bindings +println("Generating bindings...") + +target_dialects = [ + ("Builtin.jl", "../IR/BuiltinOps.td"), + ("AMDGPU.jl", "AMDGPU/AMDGPU.td"), + ("AMX.jl", "AMX/AMX.td"), + ("Affine.jl", "Affine/IR/AffineOps.td"), + ("Arithmetic.jl", "Arithmetic/IR/ArithmeticOps.td"), + # ("ArmNeon.jl", "ArmNeon/ArmNeon.td"), + ("ArmSVE.jl", "ArmSVE/ArmSVE.td"), + ("Async.jl", "Async/IR/AsyncOps.td"), + ("Bufferization.jl", "Bufferization/IR/BufferizationOps.td"), + ("Complex.jl", "Complex/IR/ComplexOps.td"), + ("ControlFlow.jl", "ControlFlow/IR/ControlFlowOps.td"), + # ("DLTI.jl", "DLTI/DLTI.td"), + ("EmitC.jl", "EmitC/IR/EmitC.td"), + ("Func.jl", "Func/IR/FuncOps.td"), + # ("GPU.jl", "GPU/IR/GPUOps.td"), + ("Linalg.jl", "Linalg/IR/LinalgOps.td"), + # ("LinalgStructured.jl", "Linalg/IR/LinalgStructuredOps.td"), + ("LLVMIR.jl", "LLVMIR/LLVMOps.td"), + # ("MLProgram.jl", "MLProgram/IR/MLProgramOps.td"), + ("Math.jl", "Math/IR/MathOps.td"), + ("MemRef.jl", "MemRef/IR/MemRefOps.td"), + ("NVGPU.jl", "NVGPU/IR/NVGPU.td"), + # ("OpenACC.jl", "OpenACC/OpenACCOps.td"), + # ("OpenMP.jl", "OpenMP/OpenMPOps.td"), + # ("PDL.jl", "PDL/IR/PDLOps.td"), + # ("PDLInterp.jl", "PDLInterp/IR/PDLInterpOps.td"), + ("Quant.jl", "Quant/QuantOps.td"), + # ("SCF.jl", "SCF/IR/SCFOps.td"), + # ("SPIRV.jl", "SPIRV/IR/SPIRVOps.td"), + ("Shape.jl", "Shape/IR/ShapeOps.td"), + ("SparseTensor.jl", "SparseTensor/IR/SparseTensorOps.td"), + ("Tensor.jl", "Tensor/IR/TensorOps.td"), + # ("Tosa.jl", "Tosa/IR/TosaOps.td"), + ("Transform.jl", "Transform/IR/TransformOps.td"), + ("Vector.jl", "Vector/IR/VectorOps.td"), + # ("X86Vector.jl", "X86Vector/X86Vector.td"), +] + +for (file, path) in target_dialects + output = joinpath(@__DIR__, "..", "src", "dialects", file) + run(`./mlir-jl-tblgen --generator=jl-op-defs $(joinpath(DIALECTS_PATH, path)) -I$INCLUDE_PATH -o $output`) + println("- Generated \"$output\" from \"$path\"") +end diff --git a/deps/tblgen/LICENSE b/deps/tblgen/LICENSE new file mode 100644 index 00000000..d6456956 --- /dev/null +++ b/deps/tblgen/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/deps/tblgen/jl-generators.cc b/deps/tblgen/jl-generators.cc new file mode 100644 index 00000000..54b3e422 --- /dev/null +++ b/deps/tblgen/jl-generators.cc @@ -0,0 +1,399 @@ +// Copyright 2021 Google LLC +// Copyright 2023 Valentin Churavy +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatAdapters.h" +#include "llvm/Support/FormatCommon.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/Signals.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" +#include "mlir/TableGen/Argument.h" +#include "mlir/TableGen/Class.h" +#include "mlir/TableGen/CodeGenHelpers.h" +#include "mlir/TableGen/Format.h" +#include "mlir/TableGen/Interfaces.h" +#include "mlir/TableGen/Operator.h" +#include "mlir/TableGen/Region.h" +#include "mlir/TableGen/SideEffects.h" +#include "mlir/TableGen/Trait.h" + +namespace +{ + + llvm::cl::opt ExplainMissing( + "explain-missing", + llvm::cl::desc("Print the reason for skipping operations from output")); + llvm::cl::opt DialectName( + "dialect-name", llvm::cl::desc("Override the inferred dialect name, used as the name for the generated Julia module."), + llvm::cl::value_desc("dialect")); + + using namespace mlir; + using namespace mlir::tblgen; + + /// Returns true if the SameArgumentAndResultTypes trait can be used to infer + /// result types of the given operation. + static bool hasSameArgumentAndResultTypes(const Operator &op) + { + return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") && + op.getNumVariableLengthResults() == 0; + } + + /// Returns true if the FirstAttrDerivedResultType trait can be used to infer + /// result types of the given operation. + static bool hasFirstAttrDerivedResultTypes(const Operator &op) + { + return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") && + op.getNumVariableLengthResults() == 0; + } + + /// Returns true if the InferTypeOpInterface can be used to infer result types + /// of the given operation. + static bool hasInferTypeInterface(const Operator &op) + { + return op.getTrait("::mlir::InferTypeOpInterface::Trait") && + op.getNumRegions() == 0; + } + + /// Returns true if there is a trait or interface that can be used to infer + /// result types of the given operation. + static bool canInferType(const Operator &op) + { + return hasSameArgumentAndResultTypes(op) || + hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); + } + + std::string formatDescription(mlir::tblgen::Operator op) + { + std::string description; + description = op.getDescription().str(); + size_t pos = 0; + while (description[pos] == '\n') + ++pos; + size_t leading_spaces = 0; + while (description[pos++] == ' ') + ++leading_spaces; + if (leading_spaces) + { + std::string leading_spaces_str; + for (size_t i = 0; i < leading_spaces; ++i) + leading_spaces_str += "[ ]"; + description = std::regex_replace(description, std::regex("\n" + leading_spaces_str), "\n"); + } + description = std::regex_replace(description, std::regex("(['\"$])"), "\\$1"); + description = std::regex_replace(description, std::regex("(^|\n)(Example|Syntax):"), "$1# $2"); + + // remove trailing whitespaces and newlines + while (std::isspace(description.back())) { + description.pop_back(); + } + return description; + } + + std::string getDialectName(llvm::ArrayRef op_defs) { + mlir::tblgen::Operator any_op(op_defs.front()); + assert( + std::all_of(op_defs.begin(), op_defs.end(), [&any_op](llvm::Record* op) { + return mlir::tblgen::Operator(op).getDialectName() == + any_op.getDialectName(); + })); + std::string dialect_name; + if (DialectName.empty()) { + dialect_name = any_op.getDialectName().str(); + } else { + dialect_name = DialectName; + } + return dialect_name; + } + +} // namespace + +bool emitOpTableDefs(const llvm::RecordKeeper &recordKeeper, + llvm::raw_ostream &os) +{ + std::vector opdefs = recordKeeper.getAllDerivedDefinitionsIfDefined("Op"); + + const char *moduleTemplate = R"(module {0} + +import ...IR: NamedAttribute, MLIRType, Value, Location, Block, Region, Attribute, create_operation, context, IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + +{1} +end # {0} +)"; + + const char *functiontemplate = R"( +{3} +function {0}({1}location=Location()) + {2} +end +)"; // 0: functionname, 1: functionarguments, 2: functionbody + const char *functionbodytemplate = R"(results = MLIRType[{0}] + operands = Value[{1}] + owned_regions = Region[{2}] + successors = Block[{3}] + attributes = NamedAttribute[{4}] + {5} + create_operation( + "{6}", location; + operands, owned_regions, successors, attributes, + results={7}, + result_inference={8} + ))"; // 0: results, 1: operands, 2: owned_regions, 3: successors, 4: attributes, 5: optionals, 6: opname, 7: results expression, 8: result_inference + + std::string modulecontents = ""; + + std::string modulename; + if (!DialectName.empty()) + { + modulename = DialectName; + } else { + modulename = getDialectName(opdefs); + } + + for (const auto *def : opdefs) + { + mlir::tblgen::Operator op(*def); + + std::string operandarguments = ""; + std::string operandcontainer = ""; + std::string optionals = ""; + + auto opname = op.getOperationName(); + auto functionname = opname.substr(op.getDialectName().str().length() + 1); // get rid of "dialect." prefix. + // check if functionname colides with Julia keywords (for, return, if, continue, break, ...): + std::vector reservedKeywords = {modulename, "for", "return", "if", "continue", "break", "module", "global"}; + if (std::find(reservedKeywords.begin(), reservedKeywords.end(), functionname) != reservedKeywords.end()) + { + functionname = functionname + "_"; + } + auto i = functionname.find('.'); + if (i!=std::string::npos) + { + functionname.replace(i, 1, "_"); // replace . with _ + } + + std::string description = ""; + if (op.hasDescription()) + { + description = "\"\"\"\n`"+functionname+"`\n"+formatDescription(op)+"\n\"\"\""; + } + bool inferrable = canInferType(op); + + for (size_t i = 0; i < op.getNumOperands(); i++) + { + const auto &named_operand = op.getOperand(i); + std::string defaultvalue = ""; + std::string operandname = named_operand.name.str(); + if (operandname.empty()) + { + operandname = "operand_" + std::to_string(i); + } + + // auto type = named_operand.constraint.getPredicate().getCondition(); + std::string type = "Value"; + + bool optional = named_operand.isOptional(); + bool variadic = named_operand.isVariadic(); + + if (variadic) + { + type = "Vector{" + type + "}"; + } + + if (optional) + { + optionals += llvm::formatv(R"(({0} != nothing) && push!(operands, {0}{1}) + )", + operandname, (variadic ? "..." : "")); + type = "Union{Nothing, " + type + "}"; + defaultvalue = "=nothing"; + } + else + { + operandcontainer += operandname + (variadic ? "..." : "") + ", "; + } + operandarguments += operandname + defaultvalue + "::" + type + (i == op.getNumOperands() - 1 ? "" : ", "); + } + + if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) + { + std::string operandsegmentsizes = ""; + for (size_t i = 0; i < op.getNumOperands(); i++) + { + const auto &named_operand = op.getOperand(i); + std::string operandname = named_operand.name.str(); + if (operandname.empty()) + { + operandname = "operand_" + std::to_string(i); + } + if (named_operand.isOptional()) + { + operandsegmentsizes += "(" + operandname + "==nothing) ? 0 : 1"; + continue; + } + operandsegmentsizes += named_operand.isVariadic() ? "length(" + operandname + "), " : "1, "; + } + optionals += llvm::formatv(R"(push!(attributes, operandsegmentsizes([{0}])) + )", + operandsegmentsizes); + } + + std::string resultarguments = ""; + std::string resultcontainer = ""; + for (size_t i = 0; i < op.getNumResults(); i++) + { + const auto &named_result = op.getResult(i); + std::string defaultvalue = ""; + std::string resultname = named_result.name.str(); + if (resultname.empty()) + { + resultname = "result_" + std::to_string(i); + } + std::string type = "MLIRType"; + + bool optional = named_result.isOptional() || inferrable; + bool variadic = named_result.isVariadic(); + + if (variadic) + { + type = "Vector{" + type + "}"; + } + + if (optional) + { + optionals += llvm::formatv(R"(({0} != nothing) && push!(results, {0}{1}) + )", + resultname, (variadic ? "..." : "")); + type = "Union{Nothing, " + type + "}"; + defaultvalue = "=nothing"; + } + else + { + resultcontainer += resultname + (variadic ? "..." : "") + ", "; + } + resultarguments += resultname + defaultvalue + "::" + type + ", "; + } + + std::string resultsexpression = (inferrable ? "(length(results) == 0 ? nothing : results)" : "results"); + std::string resultinference = (inferrable ? "(length(results) == 0 ? true : false)" : "false"); + + std::string attributearguments = ""; + std::string attributecontainer = ""; + for (size_t i = 0; i < op.getNumAttributes(); i++) + { + const auto &named_attr = op.getAttribute(i); + + // Derived attributes are never materialized and don't have to be + // specified. + if (named_attr.attr.isDerivedAttr()) + continue; + + std::string defaultvalue = ""; + std::string attributename = named_attr.name.str(); + if (attributename.empty()) + { + attributename = "attribute_" + std::to_string(i); + } + + bool optional = named_attr.attr.isOptional() || named_attr.attr.hasDefaultValue(); + + if (optional) + { + optionals += llvm::formatv(R"(({0} != nothing) && push!(attributes, namedattribute("{0}", {0})) + )", + attributename); + defaultvalue = "=nothing"; + } + else + { + attributecontainer += "namedattribute(\"" + attributename + "\", " + attributename + "), "; + } + attributearguments += attributename + defaultvalue + ", "; + } + + std::string regionarguments = ""; + std::string regioncontainer = ""; + for (size_t i = 0; i < op.getNumRegions(); i++) + { + const auto &named_region = op.getRegion(i); + std::string defaultvalue = ""; + std::string regionname = named_region.name.str(); + if (regionname.empty()) + { + regionname = "region_" + std::to_string(i); + } + std::string type = "Region"; + + bool variadic = named_region.isVariadic(); + + if (variadic) + { + type = "Vector{" + type + "}"; + } + + regioncontainer += regionname + (variadic ? "..." : "") + ", "; + regionarguments += regionname + defaultvalue + "::" + type + ", "; + } + + std::string successorarguments = ""; + std::string successorcontainer = ""; + for (size_t i = 0; i < op.getNumSuccessors(); i++) + { + const auto &named_successor = op.getSuccessor(i); + std::string defaultvalue = ""; + std::string successorname = named_successor.name.str(); + if (successorname.empty()) + { + successorname = "successor_" + std::to_string(i); + } + std::string type = "Block"; + + bool variadic = named_successor.isVariadic(); + if (variadic) + { + type = "Vector{" + type + "}"; + } + + successorcontainer += successorname + (variadic ? "..." : "") + ", "; + successorarguments += successorname + defaultvalue + "::" + type + ", "; + } + + std::string arguments = operandarguments + "; " + resultarguments + attributearguments + regionarguments + successorarguments; + std::string functionbody = llvm::formatv(functionbodytemplate, resultcontainer, operandcontainer, regioncontainer, successorcontainer, attributecontainer, optionals, opname, resultsexpression, resultinference); + + modulecontents += llvm::formatv(functiontemplate, functionname, arguments, functionbody, description); + } + + os << llvm::formatv(moduleTemplate, modulename, modulecontents); + + return false; +} diff --git a/deps/tblgen/mlir-jl-tblgen.cc b/deps/tblgen/mlir-jl-tblgen.cc new file mode 100644 index 00000000..3f1632e5 --- /dev/null +++ b/deps/tblgen/mlir-jl-tblgen.cc @@ -0,0 +1,64 @@ +// Copyright 2021 Google LLC +// Copyright 2023 Valentin Churavy +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/ManagedStatic.h" +#include "llvm/Support/Signals.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Main.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +using namespace llvm; + +using generator_function = bool(const llvm::RecordKeeper& recordKeeper, + llvm::raw_ostream& os); + +struct GeneratorInfo { + const char* name; + generator_function* generator; +}; + +extern generator_function emitOpTableDefs; +extern generator_function emitTestTableDefs; + +static std::array generators {{ + {"jl-op-defs", emitOpTableDefs}, +}}; + +generator_function* generator; + +int main(int argc, char **argv) { + llvm::InitLLVM y(argc, argv); + llvm::cl::opt generatorOpt("generator", llvm::cl::desc("Generator to run"), cl::Required); + cl::ParseCommandLineOptions(argc, argv); + for (const auto& spec : generators) { + if (generatorOpt == spec.name) { + generator = spec.generator; + break; + } + } + if (!generator) { + llvm::errs() << "Invalid generator type\n"; + abort(); + } + + return TableGenMain(argv[0], [](raw_ostream& os, RecordKeeper &records) { + return generator(records, os); + }); +} \ No newline at end of file diff --git a/src/Dialects.jl b/src/Dialects.jl index cd6f4244..046fffec 100644 --- a/src/Dialects.jl +++ b/src/Dialects.jl @@ -1,166 +1,40 @@ module Dialects -module arith - -using ...IR - -for (f, t) in Iterators.product( - (:add, :sub, :mul), - (:i, :f), -) - fname = Symbol(f, t) - @eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location()) - IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) - end -end - -for fname in (:xori, :andi, :ori) - @eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location()) - IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) - end -end - -for (f, t) in Iterators.product( - (:div, :max, :min), - (:si, :ui, :f), -) - fname = Symbol(f, t) - @eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location()) - IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) - end -end - -# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithindex_cast-mlirarithindexcastop -for f in (:index_cast, :index_castui) - @eval function $f(operand; loc=Location()) - IR.create_operation( - $(string("arith.", f)), - loc; - operands=[operand], - results=[IR.IndexType()], - ) - end -end - -# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithextf-mlirarithextfop -function extf(operand, type; loc=Location()) - IR.create_operation("arith.exf", loc; operands=[operand], results=[type]) -end - -# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithconstant-mlirarithconstantop -function constant(value, type=MLIRType(typeof(value)); loc=Location()) - IR.create_operation( - "arith.constant", - loc; - results=[type], - attributes=[ - IR.NamedAttribute("value", - Attribute(value, type)), - ], - ) -end - -module Predicates - const eq = 0 - const ne = 1 - const slt = 2 - const sle = 3 - const sgt = 4 - const sge = 5 - const ult = 6 - const ule = 7 - const ugt = 8 - const uge = 9 -end - -function cmpi(predicate, operands; loc=Location()) - IR.create_operation( - "arith.cmpi", - loc; - operands, - results=[MLIRType(Bool)], - attributes=[ - IR.NamedAttribute("predicate", - Attribute(predicate)) - ], - ) -end - -end # module arith - -module std -# for llvm 14 - -using ...IR - -function return_(operands; loc=Location()) - IR.create_operation("std.return", loc; operands, result_inference=false) -end - -function br(dest, operands; loc=Location()) - IR.create_operation("std.br", loc; operands, successors=[dest], result_inference=false) -end - -function cond_br( - cond, - true_dest, false_dest, - true_dest_operands, - false_dest_operands; - loc=Location(), -) - IR.create_operation( - "std.cond_br", - loc; - successors=[true_dest, false_dest], - operands=[cond, true_dest_operands..., false_dest_operands...], - attributes=[ - IR.NamedAttribute("operand_segment_sizes", - IR.Attribute(Int32[1, length(true_dest_operands), length(false_dest_operands)])) - ], - result_inference=false, - ) -end - -end # module std - -module func -# https://mlir.llvm.org/docs/Dialects/Func/ - -using ...IR - -function return_(operands; loc=Location()) - IR.create_operation("func.return", loc; operands, result_inference=false) -end - -end # module func - -module cf - -using ...IR - -function br(dest, operands; loc=Location()) - IR.create_operation("cf.br", loc; operands, successors=[dest], result_inference=false) -end - -function cond_br( - cond, - true_dest, false_dest, - true_dest_operands, - false_dest_operands; - loc=Location(), -) - IR.create_operation( - "cf.cond_br", loc; - operands=[cond, true_dest_operands..., false_dest_operands...], - successors=[true_dest, false_dest], - attributes=[ - IR.NamedAttribute("operand_segment_sizes", - IR.Attribute(Int32[1, length(true_dest_operands), length(false_dest_operands)])) - ], - result_inference=false, - ) -end - -end # module cf +import ..IR: Attribute, NamedAttribute, context +import ..API + +namedattribute(name, val) = namedattribute(name, Attribute(val)) +namedattribute(name, val::Attribute) = NamedAttribute(name, val) +function namedattribute(name, val::NamedAttribute) + @assert true # TODO(jm): check whether name of attribute is correct, getting the name might need to be added to IR.jl? + return val +end + +operandsegmentsizes(segments) = namedattribute( + "operand_segment_sizes", + Attribute(API.mlirDenseI32ArrayGet( + context().context, + length(segments), + Int32.(segments) + ))) + +include.(filter(contains(r".jl$"), readdir(joinpath(@__DIR__, "dialects"); join=true))) + +# module arith + +# module Predicates +# const eq = 0 +# const ne = 1 +# const slt = 2 +# const sle = 3 +# const sgt = 4 +# const sge = 5 +# const ult = 6 +# const ule = 7 +# const ugt = 8 +# const uge = 9 +# end + +# end # module arith end # module Dialects diff --git a/src/MLIR.jl b/src/MLIR.jl index 4d200805..49d275c8 100644 --- a/src/MLIR.jl +++ b/src/MLIR.jl @@ -38,11 +38,11 @@ end module IR import ..API: API - include("./IR/IR.jl") - include("./IR/state.jl") + include("IR/IR.jl") + include("IR/state.jl") end # module IR -include("./Dialects.jl") +include("Dialects.jl") end # module MLIR diff --git a/src/dialects/.gitignore b/src/dialects/.gitignore new file mode 100644 index 00000000..978d85b3 --- /dev/null +++ b/src/dialects/.gitignore @@ -0,0 +1 @@ +*.jl \ No newline at end of file