Skip to content

Commit

Permalink
[OpenMP][flang][do-conc] Map values depending on values defined outsi…
Browse files Browse the repository at this point in the history
…de target region.

Adds support for `do concurrent` mapping when mapped value(s) depend on
values defined outside the target region; e.g. the size of the array is
dynamic. This needs to be handled by localizing these region outsiders:
either by cloning them into the region or, when that is not possible, by
mapping them and using the mapped values.
  • Loading branch information
ergawy committed Aug 20, 2024
1 parent 5c0cc18 commit dde321c
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 8 deletions.
88 changes: 82 additions & 6 deletions flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,67 @@ mlir::Value calculateTripCount(fir::FirOpBuilder &builder, mlir::Location loc,

return tripCount;
}

/// Localizes every value that is used inside the region of \p targetOp but is
/// defined outside of it (a "region outsider").
///
/// For each outsider:
///   * If its defining op is MemoryEffectFree, the op is cloned to the front
///     of the region's first block and the uses located in that block are
///     redirected to the matching result of the clone.
///   * Otherwise, the value is stored into a new temporary created right after
///     its defining op, the temporary is mapped (implicit, by-copy) through a
///     map-info op that is appended to the map vars of \p targetOp, a region
///     argument of the temporary's type is added, and the uses located in the
///     region's first block are redirected to a load of that argument.
///
/// The set of outsiders is recomputed after every sweep and the process is
/// repeated until it is empty, since a cloned op still refers to its original
/// operands, which may themselves be defined outside the region.
///
/// \param builder  builder used to create the temporaries, map-info ops and
///                 loads. Its insertion point is saved and restored around
///                 every mapped value.
/// \param targetOp the `omp.target` op whose region is localized.
///
/// TODO: similar to the above functions, this is copied from OpenMP lowering
/// (in this case, from `genBodyOfTargetOp`). Once we move to a common lib for
/// these utils this will move as well.
void cloneOrMapRegionOutsiders(fir::FirOpBuilder &builder,
                               mlir::omp::TargetOp targetOp) {
  mlir::Region &region = targetOp.getRegion();
  mlir::Block *regionBlock = &region.getBlocks().front();
  llvm::SetVector<mlir::Value> valuesDefinedAbove;
  mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);

  // Only uses owned by ops placed directly in the region's first block are
  // rewritten.
  // NOTE(review): uses in nested regions/blocks are not matched by this
  // filter while `getUsedValuesDefinedAbove` may still report them - confirm
  // this cannot happen for the inputs this utility is given.
  auto isUseInRegionBlock = [regionBlock](mlir::OpOperand &use) {
    return use.getOwner()->getBlock() == regionBlock;
  };

  while (!valuesDefinedAbove.empty()) {
    for (mlir::Value val : valuesDefinedAbove) {
      mlir::Operation *valOp = val.getDefiningOp();
      // Block arguments have no defining op; everything below relies on one
      // (cloning it, or inserting right after it).
      assert(valOp &&
             "expected region outsider to be defined by an operation");

      if (mlir::isMemoryEffectFree(valOp)) {
        mlir::Operation *clonedOp = valOp->clone();
        regionBlock->push_front(clonedOp);
        // `val` is not necessarily the first result of its defining op (e.g.
        // the second result of a multi-result op), so pick the clone's result
        // with the same index rather than always result 0.
        unsigned resultIdx = mlir::cast<mlir::OpResult>(val).getResultNumber();
        val.replaceUsesWithIf(clonedOp->getResult(resultIdx),
                              isUseInRegionBlock);
      } else {
        // Restores the builder's insertion block *and* point when leaving
        // this scope, whichever block the builder was pointing into.
        mlir::OpBuilder::InsertionGuard guard(builder);

        // Copy the value to a temporary right after its definition.
        builder.setInsertionPointAfter(valOp);
        auto copyVal = builder.createTemporary(val.getLoc(), val.getType());
        builder.createStoreWithConvert(copyVal.getLoc(), val, copyVal);

        // Map the temporary; the map entry has no bounds and an empty name.
        llvm::SmallVector<mlir::Value> bounds;
        builder.setInsertionPoint(targetOp);
        mlir::Value mapOp = createMapInfoOp(
            builder, copyVal.getLoc(), copyVal,
            /*varPtrPtr=*/mlir::Value{}, /*name=*/std::string{}, bounds,
            /*members=*/llvm::SmallVector<mlir::Value>{},
            /*membersIndex=*/mlir::DenseIntElementsAttr{},
            static_cast<
                std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
                llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
            mlir::omp::VariableCaptureKind::ByCopy, copyVal.getType());
        targetOp.getMapVarsMutable().append(mapOp);

        // Expose the mapped temporary as a region argument and load it at
        // the start of the region's first block.
        mlir::Value clonedValArg =
            region.addArgument(copyVal.getType(), copyVal.getLoc());
        builder.setInsertionPointToStart(regionBlock);
        auto loadOp =
            builder.create<fir::LoadOp>(clonedValArg.getLoc(), clonedValArg);
        val.replaceUsesWithIf(loadOp->getResult(0), isUseInRegionBlock);
      }
    }

    // Localizing a value can expose new outsiders; recompute and iterate.
    valuesDefinedAbove.clear();
    mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
  }
}
} // namespace internal
} // namespace omp
} // namespace lower
Expand Down Expand Up @@ -717,14 +778,19 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
llvm::SmallVector<mlir::Value> boundsOps;
genBoundsOps(rewriter, liveIn.getLoc(), declareOp, boundsOps);

// Use the raw address to avoid unboxing `fir.box` values whenever possible.
// Put differently, if we have access to the direct value memory
// reference/address, we use it.
mlir::Value rawAddr = declareOp.getOriginalBase();
return Fortran::lower::omp ::internal::createMapInfoOp(
rewriter, liveIn.getLoc(), declareOp.getBase(), /*varPtrPtr=*/{},
declareOp.getUniqName().str(), boundsOps, /*members=*/{},
rewriter, liveIn.getLoc(), rawAddr,
/*varPtrPtr=*/{}, declareOp.getUniqName().str(), boundsOps,
/*members=*/{},
/*membersIndex=*/mlir::DenseIntElementsAttr{},
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
mapFlag),
captureKind, liveInType);
captureKind, rawAddr.getType());
}

mlir::omp::TargetOp genTargetOp(mlir::Location loc,
Expand All @@ -751,14 +817,24 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
auto miOp = mlir::cast<mlir::omp::MapInfoOp>(mapInfoOp.getDefiningOp());
hlfir::DeclareOp liveInDeclare = genLiveInDeclare(rewriter, arg, miOp);
mlir::Value miOperand = miOp.getVariableOperand(0);
mapper.map(miOperand, liveInDeclare.getBase());

// TODO If `miOperand.getDefiningOp()` is a `fir::BoxAddrOp`, we probably
// need to "unpack" the box by getting the defining op of its value.
// However, we did not hit this case in reality yet so leaving it as a
// todo for now.

mapper.map(miOperand, liveInDeclare.getOriginalBase());

if (auto origDeclareOp = mlir::dyn_cast_if_present<hlfir::DeclareOp>(
miOperand.getDefiningOp()))
mapper.map(origDeclareOp.getOriginalBase(),
liveInDeclare.getOriginalBase());
mapper.map(origDeclareOp.getBase(), liveInDeclare.getBase());
}

fir::FirOpBuilder firBuilder(
rewriter,
fir::getKindMapping(targetOp->getParentOfType<mlir::ModuleOp>()));
Fortran::lower::omp::internal::cloneOrMapRegionOutsiders(firBuilder,
targetOp);
rewriter.setInsertionPoint(
rewriter.create<mlir::omp::TerminatorOp>(targetOp.getLoc()));

Expand Down
4 changes: 2 additions & 2 deletions flang/test/Transforms/DoConcurrent/basic_device.f90
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ program do_concurrent_basic

! CHECK-NOT: fir.do_loop

! CHECK-DAG: %[[I_MAP_INFO:.*]] = omp.map.info var_ptr(%[[I_ORIG_DECL]]#0
! CHECK-DAG: %[[I_MAP_INFO:.*]] = omp.map.info var_ptr(%[[I_ORIG_DECL]]#1
! CHECK: %[[C0:.*]] = arith.constant 0 : index
! CHECK: %[[UPPER_BOUND:.*]] = arith.subi %[[A_EXTENT]], %[[C0]] : index

! CHECK: %[[A_BOUNDS:.*]] = omp.map.bounds lower_bound(%[[C0]] : index)
! CHECK-SAME: upper_bound(%[[UPPER_BOUND]] : index)
! CHECK-SAME: extent(%[[A_EXTENT]] : index)

! CHECK-DAG: %[[A_MAP_INFO:.*]] = omp.map.info var_ptr(%[[A_ORIG_DECL]]#0 : {{[^(]+}})
! CHECK-DAG: %[[A_MAP_INFO:.*]] = omp.map.info var_ptr(%[[A_ORIG_DECL]]#1 : {{[^(]+}})
! CHECK-SAME: map_clauses(implicit, tofrom) capture(ByRef) bounds(%[[A_BOUNDS]])

! CHECK: %[[TRIP_COUNT:.*]] = arith.muli %{{.*}}, %{{.*}} : i64
Expand Down
42 changes: 42 additions & 0 deletions flang/test/Transforms/DoConcurrent/runtime_sized_array.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
! Tests `do concurrent` mapping when mapped value(s) depend on values defined
! outside the target region; e.g. the size of the array is dynamic. This needs
! to be handled by localizing these region outsiders by either cloning them in
! the region or in case we cannot do that, map them and use the mapped values.

! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=device %s -o - \
! RUN: | FileCheck %s

subroutine foo(n)
implicit none
integer :: n
integer :: i
integer, dimension(n) :: a

do concurrent(i=1:10)
a(i) = i
end do
end subroutine

! CHECK-DAG: %[[I_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFfooEi"}
! CHECK-DAG: %[[A_DECL:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFfooEa"}
! CHECK-DAG: %[[N_ALLOC:.*]] = fir.alloca i32

! CHECK-DAG: %[[I_MAP:.*]] = omp.map.info var_ptr(%[[I_DECL]]#1 : {{.*}})
! CHECK-DAG: %[[A_MAP:.*]] = omp.map.info var_ptr(%[[A_DECL]]#1 : {{.*}})
! CHECK-DAG: %[[N_MAP:.*]] = omp.map.info var_ptr(%[[N_ALLOC]] : {{.*}})

! CHECK: omp.target
! CHECK-SAME: map_entries(%[[I_MAP]] -> %[[I_ARG:arg[0-9]*]],
! CHECK-SAME: %[[A_MAP]] -> %[[A_ARG:arg[0-9]*]],
! CHECK-SAME: %[[N_MAP]] -> %[[N_ARG:arg[0-9]*]] : {{.*}})
! CHECK-SAME: {{.*}} {

! CHECK-DAG: %{{.*}} = hlfir.declare %[[I_ARG]]
! CHECK-DAG: %{{.*}} = hlfir.declare %[[A_ARG]]
! CHECK-DAG: %{{.*}} = fir.load %[[N_ARG]]

! CHECK: omp.terminator
! CHECK: }



0 comments on commit dde321c

Please sign in to comment.