From 4d857ec127466cc06d93d4f6ccf427c4e0666f4b Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Tue, 13 Aug 2024 11:26:15 +0100 Subject: [PATCH] [MLIR][OpenMP] Remove omp.parallel from loop wrapper operations (#134) This patch removes the `LoopWrapperInterface` from `omp.parallel` and updates the semantics of the interface to make loop wrapper restrictions mandatory to operations that have it, rather than a role they might optionally take. MLIR operation verifiers are updated to expect the "hoisted omp.parallel" representation for `distribute parallel do`, to be later implemented in place of a loop wrapper `omp.parallel`. --- .../lib/Lower/OpenMP/DataSharingProcessor.cpp | 4 +- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 1 - .../Dialect/OpenMP/OpenMPOpsInterfaces.td | 26 +++--- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 56 +++++-------- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 3 +- mlir/test/Dialect/OpenMP/invalid.mlir | 79 +++++-------------- mlir/test/Dialect/OpenMP/ops.mlir | 9 ++- 7 files changed, 62 insertions(+), 116 deletions(-) diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp index ddb5801f3cbc0e..0ae78dc5da07af 100644 --- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp +++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp @@ -235,9 +235,7 @@ void DataSharingProcessor::insertBarrier() { void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) { mlir::omp::LoopNestOp loopOp; if (auto wrapper = mlir::dyn_cast(op)) - loopOp = wrapper.isWrapper() - ? mlir::cast(wrapper.getWrappedLoop()) - : nullptr; + loopOp = mlir::cast(wrapper.getWrappedLoop()); bool cmpCreated = false; mlir::OpBuilder::InsertionGuard guard(firOpBuilder); diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index d05543f8fbd929..9fbb122170e798 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -128,7 +128,6 @@ def PrivateClauseOp : OpenMP_Op<"private", [IsolatedFromAbove, RecipeInterface]> def ParallelOp : OpenMP_Op<"parallel", traits = [ AttrSizedOperandSegments, AutomaticAllocationScope, - DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, RecursiveMemoryEffects ], clauses = [ diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td index 2d1de37239c82a..ace671e04d7876 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td @@ -71,10 +71,10 @@ def ReductionClauseInterface : OpInterface<"ReductionClauseInterface"> { def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> { let description = [{ - OpenMP operations that can wrap a single loop nest. When taking a wrapper - role, these operations must only contain a single region with a single block - in which there's a single operation and a terminator. That nested operation - must be another loop wrapper or an `omp.loop_nest`. + OpenMP operations that wrap a single loop nest. They must only contain a + single region with a single block in which there's a single operation and a + terminator. That nested operation must be another loop wrapper or an + `omp.loop_nest`. }]; let cppNamespace = "::mlir::omp"; @@ -82,13 +82,13 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> { let methods = [ InterfaceMethod< /*description=*/[{ - Tell whether the operation could be taking the role of a loop wrapper. - That is, it has a single region with a single block in which there are - two operations: another wrapper (also taking a loop wrapper role) or + Check whether the operation is a valid loop wrapper. That is, it has a + single region with a single block in which there are two operations: + another loop wrapper (also taking a loop wrapper role) or `omp.loop_nest` operation and a terminator. }], /*retTy=*/"bool", - /*methodName=*/"isWrapper", + /*methodName=*/"isValidWrapper", (ins ), [{}], [{ if ($_op->getNumRegions() != 1) return false; @@ -107,7 +107,7 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> { return false; if (auto wrapper = ::llvm::dyn_cast(firstOp)) - return wrapper.isWrapper(); + return wrapper.isValidWrapper(); return ::llvm::isa(firstOp); }] @@ -115,12 +115,12 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> { InterfaceMethod< /*description=*/[{ If there is another loop wrapper immediately nested inside, return that - operation. Assumes this operation is taking a loop wrapper role. + operation. Assumes this operation is a valid loop wrapper. }], /*retTy=*/"::mlir::omp::LoopWrapperInterface", /*methodName=*/"getNestedWrapper", (ins), [{}], [{ - assert($_op.isWrapper() && "Unexpected non-wrapper op"); + assert($_op.isValidWrapper() && "Unexpected non-wrapper op"); Operation *nested = &*$_op->getRegion(0).op_begin(); return ::llvm::dyn_cast(nested); }] @@ -128,12 +128,12 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> { InterfaceMethod< /*description=*/[{ Return the loop nest nested directly or indirectly inside of this loop - wrapper. Assumes this operation is taking a loop wrapper role. + wrapper. Assumes this operation is a valid loop wrapper. }], /*retTy=*/"::mlir::Operation *", /*methodName=*/"getWrappedLoop", (ins), [{}], [{ - assert($_op.isWrapper() && "Unexpected non-wrapper op"); + assert($_op.isValidWrapper() && "Unexpected non-wrapper op"); if (LoopWrapperInterface nested = $_op.getNestedWrapper()) return nested.getWrappedLoop(); return &*$_op->getRegion(0).op_begin(); diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 5f6007561967e4..82efe62525ce12 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1604,15 +1604,15 @@ bool TargetOp::isTargetSPMDLoop() { if (!isa_and_present(workshareOp)) return false; - Operation *parallelOp = workshareOp->getParentOp(); - if (!isa_and_present(parallelOp)) + Operation *distributeOp = workshareOp->getParentOp(); + if (!isa_and_present(distributeOp)) return false; - Operation *distributeOp = parallelOp->getParentOp(); - if (!isa_and_present(distributeOp)) + Operation *parallelOp = distributeOp->getParentOp(); + if (!isa_and_present(parallelOp)) return false; - Operation *teamsOp = distributeOp->getParentOp(); + Operation *teamsOp = parallelOp->getParentOp(); if (!isa_and_present(teamsOp)) return false; @@ -1690,22 +1690,6 @@ static LogicalResult verifyPrivateVarList(OpType &op) { } LogicalResult ParallelOp::verify() { - // Check that it is a valid loop wrapper if it's taking that role. - if (isa((*this)->getParentOp())) { - if (!isWrapper()) - return emitOpError() << "must take a loop wrapper role if nested inside " - "of 'omp.distribute'"; - - if (LoopWrapperInterface nested = getNestedWrapper()) { - // Check for the allowed leaf constructs that may appear in a composite - // construct directly after PARALLEL. - if (!isa(nested)) - return emitError() << "only supported nested wrapper is 'omp.wsloop'"; - } else { - return emitOpError() << "must not wrap an 'omp.loop_nest' directly"; - } - } - if (getAllocateVars().size() != getAllocatorVars().size()) return emitError( "expected equal sizes for allocate and allocator variables"); @@ -1894,8 +1878,8 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state, } LogicalResult WsloopOp::verify() { - if (!isWrapper()) - return emitOpError() << "must be a loop wrapper"; + if (!isValidWrapper()) + return emitOpError() << "must be a valid loop wrapper"; if (LoopWrapperInterface nested = getNestedWrapper()) { // Check for the allowed leaf constructs that may appear in a composite @@ -1939,8 +1923,8 @@ LogicalResult SimdOp::verify() { if (verifyNontemporalClause(*this, getNontemporalVars()).failed()) return failure(); - if (!isWrapper()) - return emitOpError() << "must be a loop wrapper"; + if (!isValidWrapper()) + return emitOpError() << "must be a valid loop wrapper"; if (getNestedWrapper()) return emitOpError() << "must wrap an 'omp.loop_nest' directly"; @@ -1970,15 +1954,19 @@ LogicalResult DistributeOp::verify() { return emitError( "expected equal sizes for allocate and allocator variables"); - if (!isWrapper()) - return emitOpError() << "must be a loop wrapper"; + if (!isValidWrapper()) + return emitOpError() << "must be a valid loop wrapper"; if (LoopWrapperInterface nested = getNestedWrapper()) { // Check for the allowed leaf constructs that may appear in a composite // construct directly after DISTRIBUTE. - if (!isa(nested)) - return emitError() << "only supported nested wrappers are 'omp.parallel' " - "and 'omp.simd'"; + if (isa(nested)) { + if (!llvm::dyn_cast_if_present((*this)->getParentOp())) + return emitError() << "an 'omp.wsloop' nested wrapper is only allowed " + "when 'omp.parallel' is the direct parent"; + } else if (!isa(nested)) + return emitError() << "only supported nested wrappers are 'omp.simd' and " + "'omp.wsloop'"; } return success(); @@ -2176,8 +2164,8 @@ LogicalResult TaskloopOp::verify() { "may not appear on the same taskloop directive"); } - if (!isWrapper()) - return emitOpError() << "must be a loop wrapper"; + if (!isValidWrapper()) + return emitOpError() << "must be a valid loop wrapper"; if (LoopWrapperInterface nested = getNestedWrapper()) { // Check for the allowed leaf constructs that may appear in a composite @@ -2269,7 +2257,7 @@ LogicalResult LoopNestOp::verify() { auto wrapper = llvm::dyn_cast_if_present((*this)->getParentOp()); - if (!wrapper || !wrapper.isWrapper()) + if (!wrapper || !wrapper.isValidWrapper()) return emitOpError() << "expects parent op to be a valid loop wrapper"; return success(); @@ -2280,8 +2268,6 @@ void LoopNestOp::gatherWrappers( Operation *parent = (*this)->getParentOp(); while (auto wrapper = llvm::dyn_cast_if_present(parent)) { - if (!wrapper.isWrapper()) - break; wrappers.push_back(wrapper); parent = parent->getParentOp(); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 8c72411b7b2e17..a5b289bee84b19 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3190,8 +3190,7 @@ static LogicalResult convertOmpDistribute( llvm::OpenMPIRBuilder::InsertPointTy *redAllocaIP, SmallVector &reductionInfos) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - // FIXME: This ignores any other nested wrappers (e.g. omp.parallel + - // omp.wsloop, omp.simd). + // FIXME: This ignores any other nested wrappers (e.g. omp.wsloop, omp.simd). auto distributeOp = cast(opInst); auto loopOp = cast(distributeOp.getWrappedLoop()); diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 3bed809bb6dc0d..856415655490f8 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -10,58 +10,6 @@ func.func @unknown_clause() { // ----- -func.func @not_wrapper() { - // expected-error@+1 {{op must be a loop wrapper}} - omp.distribute { - omp.parallel { - %0 = arith.constant 0 : i32 - omp.terminator - } - omp.terminator - } - - return -} - -// ----- - -func.func @invalid_nested_wrapper(%lb : index, %ub : index, %step : index) { - omp.distribute { - // expected-error@+1 {{only supported nested wrapper is 'omp.wsloop'}} - omp.parallel { - omp.simd { - omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) { - omp.yield - } - omp.terminator - } - omp.terminator - } - omp.terminator - } - - return -} - -// ----- - -func.func @no_nested_wrapper(%lb : index, %ub : index, %step : index) { - omp.distribute { - // expected-error@+1 {{op must not wrap an 'omp.loop_nest' directly}} - omp.parallel { - omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) { - omp.yield - } - omp.terminator - } - omp.terminator - } - - return -} - -// ----- - func.func @if_once(%n : i1) { // expected-error@+1 {{`if` clause can appear at most once in the expansion of the oilist directive}} omp.parallel if(%n) if(%n) { @@ -188,7 +136,7 @@ func.func @iv_number_mismatch(%lb : index, %ub : index, %step : index) { // ----- func.func @no_wrapper(%lb : index, %ub : index, %step : index) { - // expected-error @below {{op must be a loop wrapper}} + // expected-error @below {{op must be a valid loop wrapper}} omp.wsloop { %0 = arith.constant 0 : i32 omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) { @@ -374,7 +322,7 @@ llvm.func @test_omp_wsloop_dynamic_wrong_modifier3(%lb : i64, %ub : i64, %step : // ----- func.func @omp_simd() -> () { - // expected-error @below {{op must be a loop wrapper}} + // expected-error @below {{op must be a valid loop wrapper}} omp.simd { omp.terminator } @@ -1939,7 +1887,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) { // ----- func.func @taskloop(%lb: i32, %ub: i32, %step: i32) { - // expected-error @below {{op must be a loop wrapper}} + // expected-error @below {{op must be a valid loop wrapper}} omp.taskloop { %0 = arith.constant 0 : i32 omp.terminator @@ -2148,7 +2096,7 @@ func.func @omp_distribute_allocate(%data_var : memref) -> () { // ----- func.func @omp_distribute_wrapper() -> () { - // expected-error @below {{op must be a loop wrapper}} + // expected-error @below {{op must be a valid loop wrapper}} omp.distribute { %0 = arith.constant 0 : i32 "omp.terminator"() : () -> () @@ -2157,8 +2105,8 @@ func.func @omp_distribute_wrapper() -> () { // ----- -func.func @omp_distribute_nested_wrapper(%lb: index, %ub: index, %step: index) -> () { - // expected-error @below {{only supported nested wrappers are 'omp.parallel' and 'omp.simd'}} +func.func @omp_distribute_nested_wrapper1(%lb: index, %ub: index, %step: index) -> () { + // expected-error @below {{an 'omp.wsloop' nested wrapper is only allowed when 'omp.parallel' is the direct parent}} omp.distribute { "omp.wsloop"() ({ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) { @@ -2172,6 +2120,21 @@ func.func @omp_distribute_nested_wrapper(%lb: index, %ub: index, %step: index) - // ----- +func.func @omp_distribute_nested_wrapper2(%lb: index, %ub: index, %step: index) -> () { + // expected-error @below {{only supported nested wrappers are 'omp.simd' and 'omp.wsloop'}} + omp.distribute { + "omp.taskloop"() ({ + omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) { + "omp.yield"() : () -> () + } + "omp.terminator"() : () -> () + }) : () -> () + "omp.terminator"() : () -> () + } +} + +// ----- + func.func @omp_distribute_order() -> () { // expected-error @below {{invalid clause value: 'default'}} omp.distribute order(default) { diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index e49731c0d68301..92e42d31e90333 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -99,10 +99,11 @@ func.func @omp_parallel(%data_var : memref, %if_cond : i1, %num_threads : i omp.terminator }) {operandSegmentSizes = array} : (memref, memref) -> () - // CHECK: omp.distribute - omp.distribute { - // CHECK-NEXT: omp.parallel - omp.parallel { + // CHECK: omp.parallel + omp.parallel { + // CHECK-NOT: omp.terminator + // CHECK: omp.distribute + omp.distribute { // CHECK-NEXT: omp.wsloop omp.wsloop { // CHECK-NEXT: omp.loop_nest