diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h index b987f481702f3c..38f053caa93b3f 100644 --- a/flang/include/flang/Optimizer/Transforms/Passes.h +++ b/flang/include/flang/Optimizer/Transforms/Passes.h @@ -75,6 +75,7 @@ std::unique_ptr createAlgebraicSimplificationPass(); std::unique_ptr createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config); +std::unique_ptr createOMPGlobalFilteringPass(); std::unique_ptr createVScaleAttrPass(); std::unique_ptr createVScaleAttrPass(std::pair vscaleAttr); diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td index 4df560b8a15812..d9521bd34cf7f5 100644 --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -349,6 +349,16 @@ def OMPFunctionFiltering : Pass<"omp-function-filtering"> { ]; } +def OMPGlobalFiltering : Pass<"omp-global-filtering"> { + let summary = "Filters out globals intended for the host when compiling " + "for the target device."; + let constructor = "::fir::createOMPGlobalFilteringPass()"; + let dependentDialects = [ + "mlir::func::FuncDialect", + "fir::FIROpsDialect" + ]; +} + def VScaleAttr : Pass<"vscale-attr", "mlir::func::FuncOp"> { let summary = "Add vscale_range attribute to functions"; let description = [{ diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc index 9ceef3ee679657..24de8ba8191c2c 100644 --- a/flang/include/flang/Tools/CLOptions.inc +++ b/flang/include/flang/Tools/CLOptions.inc @@ -354,8 +354,10 @@ inline void createOpenMPFIRPassPipeline(mlir::PassManager &pm, pm, fir::createOMPMapInfoFinalizationPass); pm.addPass(fir::createOMPMarkDeclareTargetPass()); - if (isTargetDevice) + if (isTargetDevice) { pm.addPass(fir::createOMPFunctionFiltering()); + pm.addPass(fir::createOMPGlobalFilteringPass()); + } } #if !defined(FLANG_EXCLUDE_CODEGEN) diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt index 217037d9600919..1921df7a2b7540 100644 --- a/flang/lib/Optimizer/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt @@ -19,6 +19,7 @@ add_flang_library(FIRTransforms PolymorphicOpConversion.cpp LoopVersioning.cpp OMPFunctionFiltering.cpp + OMPGlobalFiltering.cpp OMPMapInfoFinalization.cpp OMPMarkDeclareTarget.cpp VScaleAttr.cpp diff --git a/flang/lib/Optimizer/Transforms/OMPGlobalFiltering.cpp b/flang/lib/Optimizer/Transforms/OMPGlobalFiltering.cpp new file mode 100644 index 00000000000000..b2e475c629cdd9 --- /dev/null +++ b/flang/lib/Optimizer/Transforms/OMPGlobalFiltering.cpp @@ -0,0 +1,64 @@ +//===- OMPFunctionFiltering.cpp -------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements transforms to filter out functions intended for the host +// when compiling for the device and vice versa. +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROpsSupport.h" +#include "flang/Optimizer/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" +#include "mlir/IR/BuiltinOps.h" +#include "llvm/ADT/SmallVector.h" + +namespace fir { +#define GEN_PASS_DEF_OMPGLOBALFILTERING +#include "flang/Optimizer/Transforms/Passes.h.inc" +} // namespace fir + +using namespace mlir; + +namespace { +class OMPGlobalFilteringPass + : public fir::impl::OMPGlobalFilteringBase { +public: + OMPGlobalFilteringPass() = default; + + void runOnOperation() override { + auto op = dyn_cast(getOperation()); + if (!op || !op.getIsTargetDevice()) + return; + + op->walk([&](fir::GlobalOp globalOp) { + bool symbolUnused = true; + SymbolTable::UseRange globalUses = *globalOp.getSymbolUses(op); + for (SymbolTable::SymbolUse use : globalUses) { + if (use.getUser() == globalOp) + continue; + symbolUnused = false; + break; + } + + // Remove unused host symbols with external linkage + // TODO: Add support for declare target global variables + if (symbolUnused && !globalOp.getLinkName()) + globalOp.erase(); + return WalkResult::advance(); + }); + } +}; +} // namespace + +std::unique_ptr fir::createOMPGlobalFilteringPass() { + return std::make_unique(); +}