Skip to content

Commit

Permalink
[MLIR][OpenMP] Remove omp.parallel from loop wrapper operations
Browse files Browse the repository at this point in the history
This patch removes the `LoopWrapperInterface` from `omp.parallel` and updates
the semantics of the interface to make loop wrapper restrictions mandatory to
operations that have it, rather than a role they might optionally take.

MLIR operation verifiers are updated to expect the "hoisted omp.parallel"
representation for `distribute parallel do`, to be later implemented in place
of a loop wrapper `omp.parallel`.
  • Loading branch information
skatrak committed Aug 13, 2024
1 parent 4f5f002 commit 3abba41
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 116 deletions.
4 changes: 1 addition & 3 deletions flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,7 @@ void DataSharingProcessor::insertBarrier() {
void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
mlir::omp::LoopNestOp loopOp;
if (auto wrapper = mlir::dyn_cast<mlir::omp::LoopWrapperInterface>(op))
loopOp = wrapper.isWrapper()
? mlir::cast<mlir::omp::LoopNestOp>(wrapper.getWrappedLoop())
: nullptr;
loopOp = mlir::cast<mlir::omp::LoopNestOp>(wrapper.getWrappedLoop());

bool cmpCreated = false;
mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
Expand Down
1 change: 0 additions & 1 deletion mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ def PrivateClauseOp : OpenMP_Op<"private", [IsolatedFromAbove, RecipeInterface]>

def ParallelOp : OpenMP_Op<"parallel", traits = [
AttrSizedOperandSegments, AutomaticAllocationScope,
DeclareOpInterfaceMethods<LoopWrapperInterface>,
DeclareOpInterfaceMethods<OutlineableOpenMPOpInterface>,
RecursiveMemoryEffects
], clauses = [
Expand Down
26 changes: 13 additions & 13 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,24 @@ def ReductionClauseInterface : OpInterface<"ReductionClauseInterface"> {

def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
let description = [{
OpenMP operations that can wrap a single loop nest. When taking a wrapper
role, these operations must only contain a single region with a single block
in which there's a single operation and a terminator. That nested operation
must be another loop wrapper or an `omp.loop_nest`.
OpenMP operations that wrap a single loop nest. They must only contain a
single region with a single block in which there's a single operation and a
terminator. That nested operation must be another loop wrapper or an
`omp.loop_nest`.
}];

let cppNamespace = "::mlir::omp";

let methods = [
InterfaceMethod<
/*description=*/[{
Tell whether the operation could be taking the role of a loop wrapper.
That is, it has a single region with a single block in which there are
two operations: another wrapper (also taking a loop wrapper role) or
Check whether the operation is a valid loop wrapper. That is, it has a
single region with a single block in which there are two operations:
another loop wrapper (also taking a loop wrapper role) or
`omp.loop_nest` operation and a terminator.
}],
/*retTy=*/"bool",
/*methodName=*/"isWrapper",
/*methodName=*/"isValidWrapper",
(ins ), [{}], [{
if ($_op->getNumRegions() != 1)
return false;
Expand All @@ -107,33 +107,33 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
return false;

if (auto wrapper = ::llvm::dyn_cast<LoopWrapperInterface>(firstOp))
return wrapper.isWrapper();
return wrapper.isValidWrapper();

return ::llvm::isa<LoopNestOp>(firstOp);
}]
>,
InterfaceMethod<
/*description=*/[{
If there is another loop wrapper immediately nested inside, return that
operation. Assumes this operation is taking a loop wrapper role.
operation. Assumes this operation is a valid loop wrapper.
}],
/*retTy=*/"::mlir::omp::LoopWrapperInterface",
/*methodName=*/"getNestedWrapper",
(ins), [{}], [{
assert($_op.isWrapper() && "Unexpected non-wrapper op");
assert($_op.isValidWrapper() && "Unexpected non-wrapper op");
Operation *nested = &*$_op->getRegion(0).op_begin();
return ::llvm::dyn_cast<LoopWrapperInterface>(nested);
}]
>,
InterfaceMethod<
/*description=*/[{
Return the loop nest nested directly or indirectly inside of this loop
wrapper. Assumes this operation is taking a loop wrapper role.
wrapper. Assumes this operation is a valid loop wrapper.
}],
/*retTy=*/"::mlir::Operation *",
/*methodName=*/"getWrappedLoop",
(ins), [{}], [{
assert($_op.isWrapper() && "Unexpected non-wrapper op");
assert($_op.isValidWrapper() && "Unexpected non-wrapper op");
if (LoopWrapperInterface nested = $_op.getNestedWrapper())
return nested.getWrappedLoop();
return &*$_op->getRegion(0).op_begin();
Expand Down
56 changes: 21 additions & 35 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1604,15 +1604,15 @@ bool TargetOp::isTargetSPMDLoop() {
if (!isa_and_present<WsloopOp>(workshareOp))
return false;

Operation *parallelOp = workshareOp->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
Operation *distributeOp = workshareOp->getParentOp();
if (!isa_and_present<DistributeOp>(distributeOp))
return false;

Operation *distributeOp = parallelOp->getParentOp();
if (!isa_and_present<DistributeOp>(distributeOp))
Operation *parallelOp = distributeOp->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
return false;

Operation *teamsOp = distributeOp->getParentOp();
Operation *teamsOp = parallelOp->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
return false;

Expand Down Expand Up @@ -1690,22 +1690,6 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
}

LogicalResult ParallelOp::verify() {
// Check that it is a valid loop wrapper if it's taking that role.
if (isa<DistributeOp>((*this)->getParentOp())) {
if (!isWrapper())
return emitOpError() << "must take a loop wrapper role if nested inside "
"of 'omp.distribute'";

if (LoopWrapperInterface nested = getNestedWrapper()) {
// Check for the allowed leaf constructs that may appear in a composite
// construct directly after PARALLEL.
if (!isa<WsloopOp>(nested))
return emitError() << "only supported nested wrapper is 'omp.wsloop'";
} else {
return emitOpError() << "must not wrap an 'omp.loop_nest' directly";
}
}

if (getAllocateVars().size() != getAllocatorVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");
Expand Down Expand Up @@ -1894,8 +1878,8 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
}

LogicalResult WsloopOp::verify() {
if (!isWrapper())
return emitOpError() << "must be a loop wrapper";
if (!isValidWrapper())
return emitOpError() << "must be a valid loop wrapper";

if (LoopWrapperInterface nested = getNestedWrapper()) {
// Check for the allowed leaf constructs that may appear in a composite
Expand Down Expand Up @@ -1939,8 +1923,8 @@ LogicalResult SimdOp::verify() {
if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
return failure();

if (!isWrapper())
return emitOpError() << "must be a loop wrapper";
if (!isValidWrapper())
return emitOpError() << "must be a valid loop wrapper";

if (getNestedWrapper())
return emitOpError() << "must wrap an 'omp.loop_nest' directly";
Expand Down Expand Up @@ -1970,15 +1954,19 @@ LogicalResult DistributeOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");

if (!isWrapper())
return emitOpError() << "must be a loop wrapper";
if (!isValidWrapper())
return emitOpError() << "must be a valid loop wrapper";

if (LoopWrapperInterface nested = getNestedWrapper()) {
// Check for the allowed leaf constructs that may appear in a composite
// construct directly after DISTRIBUTE.
if (!isa<ParallelOp, SimdOp>(nested))
return emitError() << "only supported nested wrappers are 'omp.parallel' "
"and 'omp.simd'";
if (isa<WsloopOp>(nested)) {
if (!llvm::dyn_cast_if_present<ParallelOp>((*this)->getParentOp()))
return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
"when 'omp.parallel' is the direct parent";
} else if (!isa<SimdOp>(nested))
return emitError() << "only supported nested wrappers are 'omp.simd' and "
"'omp.wsloop'";
}

return success();
Expand Down Expand Up @@ -2176,8 +2164,8 @@ LogicalResult TaskloopOp::verify() {
"may not appear on the same taskloop directive");
}

if (!isWrapper())
return emitOpError() << "must be a loop wrapper";
if (!isValidWrapper())
return emitOpError() << "must be a valid loop wrapper";

if (LoopWrapperInterface nested = getNestedWrapper()) {
// Check for the allowed leaf constructs that may appear in a composite
Expand Down Expand Up @@ -2269,7 +2257,7 @@ LogicalResult LoopNestOp::verify() {
auto wrapper =
llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());

if (!wrapper || !wrapper.isWrapper())
if (!wrapper || !wrapper.isValidWrapper())
return emitOpError() << "expects parent op to be a valid loop wrapper";

return success();
Expand All @@ -2280,8 +2268,6 @@ void LoopNestOp::gatherWrappers(
Operation *parent = (*this)->getParentOp();
while (auto wrapper =
llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
if (!wrapper.isWrapper())
break;
wrappers.push_back(wrapper);
parent = parent->getParentOp();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3190,8 +3190,7 @@ static LogicalResult convertOmpDistribute(
llvm::OpenMPIRBuilder::InsertPointTy *redAllocaIP,
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> &reductionInfos) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
// FIXME: This ignores any other nested wrappers (e.g. omp.parallel +
// omp.wsloop, omp.simd).
// FIXME: This ignores any other nested wrappers (e.g. omp.wsloop, omp.simd).
auto distributeOp = cast<omp::DistributeOp>(opInst);
auto loopOp = cast<omp::LoopNestOp>(distributeOp.getWrappedLoop());

Expand Down
79 changes: 21 additions & 58 deletions mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,58 +10,6 @@ func.func @unknown_clause() {

// -----

func.func @not_wrapper() {
// expected-error@+1 {{op must be a loop wrapper}}
omp.distribute {
omp.parallel {
%0 = arith.constant 0 : i32
omp.terminator
}
omp.terminator
}

return
}

// -----

func.func @invalid_nested_wrapper(%lb : index, %ub : index, %step : index) {
omp.distribute {
// expected-error@+1 {{only supported nested wrapper is 'omp.wsloop'}}
omp.parallel {
omp.simd {
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
omp.yield
}
omp.terminator
}
omp.terminator
}
omp.terminator
}

return
}

// -----

func.func @no_nested_wrapper(%lb : index, %ub : index, %step : index) {
omp.distribute {
// expected-error@+1 {{op must not wrap an 'omp.loop_nest' directly}}
omp.parallel {
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
omp.yield
}
omp.terminator
}
omp.terminator
}

return
}

// -----

func.func @if_once(%n : i1) {
// expected-error@+1 {{`if` clause can appear at most once in the expansion of the oilist directive}}
omp.parallel if(%n) if(%n) {
Expand Down Expand Up @@ -188,7 +136,7 @@ func.func @iv_number_mismatch(%lb : index, %ub : index, %step : index) {
// -----

func.func @no_wrapper(%lb : index, %ub : index, %step : index) {
// expected-error @below {{op must be a loop wrapper}}
// expected-error @below {{op must be a valid loop wrapper}}
omp.wsloop {
%0 = arith.constant 0 : i32
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
Expand Down Expand Up @@ -374,7 +322,7 @@ llvm.func @test_omp_wsloop_dynamic_wrong_modifier3(%lb : i64, %ub : i64, %step :
// -----

func.func @omp_simd() -> () {
// expected-error @below {{op must be a loop wrapper}}
// expected-error @below {{op must be a valid loop wrapper}}
omp.simd {
omp.terminator
}
Expand Down Expand Up @@ -1939,7 +1887,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
// -----

func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
// expected-error @below {{op must be a loop wrapper}}
// expected-error @below {{op must be a valid loop wrapper}}
omp.taskloop {
%0 = arith.constant 0 : i32
omp.terminator
Expand Down Expand Up @@ -2148,7 +2096,7 @@ func.func @omp_distribute_allocate(%data_var : memref<i32>) -> () {
// -----

func.func @omp_distribute_wrapper() -> () {
// expected-error @below {{op must be a loop wrapper}}
// expected-error @below {{op must be a valid loop wrapper}}
omp.distribute {
%0 = arith.constant 0 : i32
"omp.terminator"() : () -> ()
Expand All @@ -2157,8 +2105,8 @@ func.func @omp_distribute_wrapper() -> () {

// -----

func.func @omp_distribute_nested_wrapper(%lb: index, %ub: index, %step: index) -> () {
// expected-error @below {{only supported nested wrappers are 'omp.parallel' and 'omp.simd'}}
func.func @omp_distribute_nested_wrapper1(%lb: index, %ub: index, %step: index) -> () {
// expected-error @below {{an 'omp.wsloop' nested wrapper is only allowed when 'omp.parallel' is the direct parent}}
omp.distribute {
"omp.wsloop"() ({
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
Expand All @@ -2172,6 +2120,21 @@ func.func @omp_distribute_nested_wrapper(%lb: index, %ub: index, %step: index) -

// -----

func.func @omp_distribute_nested_wrapper2(%lb: index, %ub: index, %step: index) -> () {
// expected-error @below {{only supported nested wrappers are 'omp.simd' and 'omp.wsloop'}}
omp.distribute {
"omp.taskloop"() ({
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
"omp.yield"() : () -> ()
}
"omp.terminator"() : () -> ()
}) : () -> ()
"omp.terminator"() : () -> ()
}
}

// -----

func.func @omp_distribute_order() -> () {
// expected-error @below {{invalid clause value: 'default'}}
omp.distribute order(default) {
Expand Down
9 changes: 5 additions & 4 deletions mlir/test/Dialect/OpenMP/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,11 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
omp.terminator
}) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()

// CHECK: omp.distribute
omp.distribute {
// CHECK-NEXT: omp.parallel
omp.parallel {
// CHECK: omp.parallel
omp.parallel {
// CHECK-NOT: omp.terminator
// CHECK: omp.distribute
omp.distribute {
// CHECK-NEXT: omp.wsloop
omp.wsloop {
// CHECK-NEXT: omp.loop_nest
Expand Down

0 comments on commit 3abba41

Please sign in to comment.