From 9a95922fbdb98440390a342f95b44bb619394c08 Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Wed, 7 Aug 2024 11:47:42 +0100 Subject: [PATCH] [Flang][OpenMP] Update DO CONCURRENT conversion for the device This patch makes changes to the `DoConcurrentConversion` pass to follow the "hoisted omp.parallel" representation when converting `do concurrent` constructs into `target teams distribute parallel do`. --- .../Transforms/DoConcurrentConversion.cpp | 23 +++++++------------ .../Transforms/DoConcurrent/basic_device.f90 | 3 +-- .../multiple_iteration_ranges.f90 | 5 ++-- .../DoConcurrent/not_perfectly_nested.f90 | 4 +++- .../DoConcurrent/skip_all_nested_loops.f90 | 4 +++- 5 files changed, 17 insertions(+), 22 deletions(-) diff --git a/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp b/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp index a5de824c6cc0f7..912d33e0e38e9d 100644 --- a/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp +++ b/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp @@ -599,9 +599,7 @@ class DoConcurrentConversion : public mlir::OpConversionPattern { targetOp = genTargetOp(doLoop.getLoc(), rewriter, mapper, outermostLoopLives, targetClauseOps); - genTeamsOp(doLoop.getLoc(), rewriter, loopNest, mapper, - loopNestClauseOps); - genDistributeOp(doLoop.getLoc(), rewriter); + genTeamsOp(doLoop.getLoc(), rewriter); } mlir::omp::ParallelOp parallelOp = genParallelOp( @@ -611,6 +609,9 @@ class DoConcurrentConversion : public mlir::OpConversionPattern { looputils::localizeLoopLocalValue(local, parallelOp.getRegion(), rewriter); + if (mapToDevice) + genDistributeOp(doLoop.getLoc(), rewriter); + mlir::omp::LoopNestOp ompLoopNest = genWsLoopOp(rewriter, loopNest.back().first, mapper, loopNestClauseOps); @@ -800,18 +801,14 @@ class DoConcurrentConversion : public mlir::OpConversionPattern { } mlir::omp::TeamsOp - genTeamsOp(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, - looputils::LoopNestToIndVarMap &loopNest, mlir::IRMapping &mapper, - mlir::omp::LoopNestOperands &loopNestClauseOps) const { + genTeamsOp(mlir::Location loc, + mlir::ConversionPatternRewriter &rewriter) const { auto teamsOp = rewriter.create( loc, /*clauses=*/mlir::omp::TeamsOperands{}); rewriter.createBlock(&teamsOp.getRegion()); rewriter.setInsertionPoint(rewriter.create(loc)); - genLoopNestIndVarAllocs(rewriter, loopNest, mapper); - genLoopNestClauseOps(loc, rewriter, loopNest, mapper, loopNestClauseOps); - return teamsOp; } @@ -905,12 +902,8 @@ class DoConcurrentConversion : public mlir::OpConversionPattern { rewriter.createBlock(¶llelOp.getRegion()); rewriter.setInsertionPoint(rewriter.create(loc)); - // If mapping to host, the local induction variable and loop bounds need to - // be emitted as part of the `omp.parallel` op. - if (!mapToDevice) { - genLoopNestIndVarAllocs(rewriter, loopNest, mapper); - genLoopNestClauseOps(loc, rewriter, loopNest, mapper, loopNestClauseOps); - } + genLoopNestIndVarAllocs(rewriter, loopNest, mapper); + genLoopNestClauseOps(loc, rewriter, loopNest, mapper, loopNestClauseOps); return parallelOp; } diff --git a/flang/test/Transforms/DoConcurrent/basic_device.f90 b/flang/test/Transforms/DoConcurrent/basic_device.f90 index b3d0f91ddd3e18..7873fa4f88db6d 100644 --- a/flang/test/Transforms/DoConcurrent/basic_device.f90 +++ b/flang/test/Transforms/DoConcurrent/basic_device.f90 @@ -43,6 +43,7 @@ program do_concurrent_basic ! CHECK: %[[A_DEV_DECL:.*]]:2 = hlfir.declare %[[A_ARG]] ! CHECK: omp.teams { + ! CHECK-NEXT: omp.parallel { ! CHECK-NEXT: %[[ITER_VAR:.*]] = fir.alloca i32 {bindc_name = "i"} ! CHECK-NEXT: %[[BINDING:.*]]:2 = hlfir.declare %[[ITER_VAR]] {uniq_name = "_QFEi"} : (!fir.ref) -> (!fir.ref, !fir.ref) @@ -54,8 +55,6 @@ program do_concurrent_basic ! CHECK: %[[STEP:.*]] = arith.constant 1 : index ! CHECK-NEXT: omp.distribute { - ! CHECK-NEXT: omp.parallel { - ! CHECK-NEXT: omp.wsloop { ! CHECK-NEXT: omp.loop_nest (%[[ARG0:.*]]) : index = (%[[LB]]) to (%[[UB]]) inclusive step (%[[STEP]]) { diff --git a/flang/test/Transforms/DoConcurrent/multiple_iteration_ranges.f90 b/flang/test/Transforms/DoConcurrent/multiple_iteration_ranges.f90 index a0364612976bcb..17cf27a9b70b27 100644 --- a/flang/test/Transforms/DoConcurrent/multiple_iteration_ranges.f90 +++ b/flang/test/Transforms/DoConcurrent/multiple_iteration_ranges.f90 @@ -65,7 +65,7 @@ program main ! DEVICE: omp.target ! DEVICE: omp.teams -! HOST: omp.parallel { +! COMMON: omp.parallel { ! COMMON-NEXT: %[[ITER_VAR_I:.*]] = fir.alloca i32 {bindc_name = "i"} ! COMMON-NEXT: %[[BINDING_I:.*]]:2 = hlfir.declare %[[ITER_VAR_I]] {uniq_name = "_QFEi"} @@ -94,8 +94,7 @@ program main ! COMMON: %[[UB_K:.*]] = fir.convert %[[C30]] : (i32) -> index ! COMMON: %[[STEP_K:.*]] = arith.constant 1 : index -! DEVICE: omp.distribute -! DEVICE-NEXT: omp.parallel +! DEVICE: omp.distribute ! COMMON: omp.wsloop { ! COMMON-NEXT: omp.loop_nest diff --git a/flang/test/Transforms/DoConcurrent/not_perfectly_nested.f90 b/flang/test/Transforms/DoConcurrent/not_perfectly_nested.f90 index 559d26c39cba55..f3f2e78f5b3183 100644 --- a/flang/test/Transforms/DoConcurrent/not_perfectly_nested.f90 +++ b/flang/test/Transforms/DoConcurrent/not_perfectly_nested.f90 @@ -39,9 +39,11 @@ program main ! DEVICE: %[[TARGET_K_DECL:.*]]:2 = hlfir.declare %[[K_ARG]] {uniq_name = "_QFEk"} ! DEVICE: omp.teams -! DEVICE: omp.distribute ! COMMON: omp.parallel { + +! DEVICE: omp.distribute + ! COMMON: omp.wsloop { ! COMMON: omp.loop_nest ({{[^[:space:]]+}}) {{.*}} { ! COMMON: fir.do_loop %[[J_IV:.*]] = {{.*}} { diff --git a/flang/test/Transforms/DoConcurrent/skip_all_nested_loops.f90 b/flang/test/Transforms/DoConcurrent/skip_all_nested_loops.f90 index 362b8685cfd15b..429500cead1073 100644 --- a/flang/test/Transforms/DoConcurrent/skip_all_nested_loops.f90 +++ b/flang/test/Transforms/DoConcurrent/skip_all_nested_loops.f90 @@ -39,9 +39,11 @@ program main ! DEVICE: %[[TARGET_K_DECL:.*]]:2 = hlfir.declare %[[K_ARG]] {uniq_name = "_QFEk"} ! DEVICE: omp.teams -! DEVICE: omp.distribute ! COMMON: omp.parallel { + +! DEVICE: omp.distribute + ! COMMON: omp.wsloop { ! COMMON: omp.loop_nest ({{[^[:space:]]+}}) {{.*}} { ! COMMON: fir.do_loop {{.*}} iter_args(%[[J_IV:.*]] = {{.*}}) -> {{.*}} {