diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h index 2eddf52d7abc52..73deef49c4175d 100644 --- a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h +++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h @@ -22,6 +22,10 @@ namespace mlir { /// implementing `ConvertToLLVMPatternInterface`. std::unique_ptr createConvertToLLVMPass(); +/// Register the extension that will load dependent dialects for LLVM +/// conversion. This is useful to implement a pass similar to "convert-to-llvm". +void registerConvertToLLVMDependentDialectLoading(DialectRegistry ®istry); + } // namespace mlir #endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVM_PASS_H diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp index a90e557b1fdbd9..6135117348a5b8 100644 --- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp +++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp @@ -124,6 +124,11 @@ class ConvertToLLVMPass } // namespace +void mlir::registerConvertToLLVMDependentDialectLoading( + DialectRegistry ®istry) { + registry.addExtensions(); +} + std::unique_ptr mlir::createConvertToLLVMPass() { return std::make_unique(); } diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 2da97c20e9c984..75dee09d2f64fd 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -18,6 +18,8 @@ #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" @@ -38,6 +40,8 @@ #include "llvm/Support/Error.h" #include "llvm/Support/FormatVariadic.h" +#define DEBUG_TYPE "gpu-to-llvm" + namespace mlir { #define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS #include "mlir/Conversion/Passes.h.inc" @@ -48,12 +52,14 @@ using namespace mlir; static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst"; namespace { - class GpuToLLVMConversionPass : public impl::GpuToLLVMConversionPassBase { public: using Base::Base; - + void getDependentDialects(DialectRegistry ®istry) const final { + Base::getDependentDialects(registry); + registerConvertToLLVMDependentDialectLoading(registry); + } // Run the dialect converter on the module. void runOnOperation() override; }; @@ -580,14 +586,24 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp) } // namespace void GpuToLLVMConversionPass::runOnOperation() { - LowerToLLVMOptions options(&getContext()); + MLIRContext *context = &getContext(); + SymbolTable symbolTable = SymbolTable(getOperation()); + LowerToLLVMOptions options(context); options.useBarePtrCallConv = hostBarePtrCallConv; + RewritePatternSet patterns(context); + ConversionTarget target(*context); + target.addLegalDialect(); + LLVMTypeConverter converter(context, options); + + // Populate all patterns from all dialects that implement the + // `ConvertToLLVMPatternInterface` interface. + for (Dialect *dialect : context->getLoadedDialects()) { + auto iface = dyn_cast(dialect); + if (!iface) + continue; + iface->populateConvertToLLVMConversionPatterns(target, converter, patterns); + } - LLVMTypeConverter converter(&getContext(), options); - RewritePatternSet patterns(&getContext()); - LLVMConversionTarget target(getContext()); - - SymbolTable symbolTable = SymbolTable(getOperation()); // Preserve GPU modules if they have target attributes. target.addDynamicallyLegalOp( [](gpu::GPUModuleOp module) -> bool { @@ -605,11 +621,9 @@ void GpuToLLVMConversionPass::runOnOperation() { !module.getTargetsAttr().empty()); }); - mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns); - mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); + // These aren't covered by the ConvertToLLVMPatternInterface right now. populateVectorToLLVMConversionPatterns(converter, patterns); populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns); - populateFuncToLLVMConversionPatterns(converter, patterns); populateAsyncStructuralTypeConversionsAndLegality(converter, patterns, target); populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation,