Skip to content

Commit

Permalink
[mlir] Adopt ConvertToLLVMPatternInterface GpuToLLVMConversionPass …
Browse files Browse the repository at this point in the history
…to align with `convert-to-llvm` (llvm#73761)

This is a follow-up to the introduction of `convert-to-llvm`: it is
supposed to be a unifying pass through the
`ConvertToLLVMPatternInterface`, but some specific conversion (like the
GPU target) aren't vanilla LLVM target. Instead they need extra
customizations that are specific to LLVM-on-GPUs and our custom runtime
wrappers.
This change make the GpuToLLVMConversionPass just as pluggable as the
`convert-to-llvm` by using the same mechanism.
  • Loading branch information
joker-eph committed Nov 29, 2023
1 parent 14028ec commit 9e7b6f4
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 11 deletions.
4 changes: 4 additions & 0 deletions mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ namespace mlir {
/// implementing `ConvertToLLVMPatternInterface`.
std::unique_ptr<Pass> createConvertToLLVMPass();

/// Register the extension that will load dependent dialects for LLVM
/// conversion. This is useful to implement a pass similar to "convert-to-llvm".
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry);

} // namespace mlir

#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVM_PASS_H
5 changes: 5 additions & 0 deletions mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ class ConvertToLLVMPass

} // namespace

void mlir::registerConvertToLLVMDependentDialectLoading(
DialectRegistry &registry) {
registry.addExtensions<LoadDependentDialectExtension>();
}

std::unique_ptr<Pass> mlir::createConvertToLLVMPass() {
return std::make_unique<ConvertToLLVMPass>();
}
36 changes: 25 additions & 11 deletions mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
Expand All @@ -38,6 +40,8 @@
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "gpu-to-llvm"

namespace mlir {
#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
Expand All @@ -48,12 +52,14 @@ using namespace mlir;
static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";

namespace {

class GpuToLLVMConversionPass
: public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
using Base::Base;

void getDependentDialects(DialectRegistry &registry) const final {
Base::getDependentDialects(registry);
registerConvertToLLVMDependentDialectLoading(registry);
}
// Run the dialect converter on the module.
void runOnOperation() override;
};
Expand Down Expand Up @@ -580,14 +586,24 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
} // namespace

void GpuToLLVMConversionPass::runOnOperation() {
LowerToLLVMOptions options(&getContext());
MLIRContext *context = &getContext();
SymbolTable symbolTable = SymbolTable(getOperation());
LowerToLLVMOptions options(context);
options.useBarePtrCallConv = hostBarePtrCallConv;
RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<LLVM::LLVMDialect>();
LLVMTypeConverter converter(context, options);

// Populate all patterns from all dialects that implement the
// `ConvertToLLVMPatternInterface` interface.
for (Dialect *dialect : context->getLoadedDialects()) {
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
if (!iface)
continue;
iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
}

LLVMTypeConverter converter(&getContext(), options);
RewritePatternSet patterns(&getContext());
LLVMConversionTarget target(getContext());

SymbolTable symbolTable = SymbolTable(getOperation());
// Preserve GPU modules if they have target attributes.
target.addDynamicallyLegalOp<gpu::GPUModuleOp>(
[](gpu::GPUModuleOp module) -> bool {
Expand All @@ -605,11 +621,9 @@ void GpuToLLVMConversionPass::runOnOperation() {
!module.getTargetsAttr().empty());
});

mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns);
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
// These aren't covered by the ConvertToLLVMPatternInterface right now.
populateVectorToLLVMConversionPatterns(converter, patterns);
populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
populateFuncToLLVMConversionPatterns(converter, patterns);
populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
target);
populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation,
Expand Down

0 comments on commit 9e7b6f4

Please sign in to comment.