[MLIR][OpenMP] Introduce host_eval clause to omp.target
This patch defines a map-like clause named `host_eval`, used to capture host
values for use inside target regions in restricted cases:
  - As `num_teams` or `thread_limit` of a nested `omp.teams` operation.
  - As `num_threads` of a nested `omp.parallel` operation, or as bounds or steps
    of a nested `omp.loop_nest`, if it is a target SPMD kernel.

This replaces the following `omp.target` arguments: `trip_count`,
`num_threads`, `num_teams_lower`, `num_teams_upper` and `teams_thread_limit`.
skatrak committed Oct 23, 2024
1 parent b29b413 commit ad391a1
Showing 7 changed files with 296 additions and 60 deletions.
58 changes: 57 additions & 1 deletion mlir/docs/Dialects/OpenMPDialect/_index.md
@@ -297,7 +297,8 @@ arguments for the region of that MLIR operation. This enables, for example, the
introduction of private copies of the same underlying variable defined outside
the MLIR operation the clause is attached to. Currently, clauses with this
property can be classified into three main categories:
- Map-like clauses: `map`, `use_device_addr` and `use_device_ptr`.
- Map-like clauses: `host_eval`, `map`, `use_device_addr` and
`use_device_ptr`.
- Reduction-like clauses: `in_reduction`, `reduction` and `task_reduction`.
- Privatization clauses: `private`.

@@ -522,3 +523,58 @@ omp.parallel ... {
omp.terminator
} {omp.composite}
```

## Host-Evaluated Clauses in Target Regions

The `omp.target` operation, which represents the OpenMP `target` construct, is
marked with the `IsolatedFromAbove` trait. This means that, inside of its
region, no MLIR values defined outside of the op itself can be used. This is
consistent with the OpenMP specification of the `target` construct, which
mandates that all host device values used inside of the `target` region must
either be privatized (data-sharing) or mapped (data-mapping).

Normally, clauses applied to a construct are evaluated before entering that
construct. Further, in some cases, the OpenMP specification stipulates that
clauses be evaluated _on the host device_ on entry to a parent `target`
construct. In particular, the `num_teams` and `thread_limit` clauses of the
`teams` construct must be evaluated on the host device if it's nested inside or
combined with a `target` construct.

Additionally, the runtime library targeted by the MLIR to LLVM IR translation of
the OpenMP dialect supports the optimized launch of SPMD kernels (i.e.
`target teams distribute parallel {do,for}` in OpenMP), which requires
specifying in advance what the total trip count of the loop is. Consequently, it
is also beneficial to evaluate the trip count on the host device prior to the
kernel launch.
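
As a rough illustration of the value the host would compute ahead of the kernel launch, the sketch below derives a total trip count from per-dimension bounds and steps. `LoopBounds`, `tripCount` and `totalTripCount` are hypothetical names, not part of MLIR or of the OpenMP runtime, and the sketch assumes an exclusive upper bound and a strictly positive step.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical description of one loop dimension (not an MLIR type).
struct LoopBounds {
  int64_t lb;   // lower bound, inclusive
  int64_t ub;   // upper bound, assumed exclusive in this sketch
  int64_t step; // assumed strictly positive in this sketch
};

// Number of iterations of a single loop dimension.
inline uint64_t tripCount(const LoopBounds &b) {
  assert(b.step > 0 && "sketch only handles positive steps");
  if (b.ub <= b.lb)
    return 0;
  // Ceiling division of the iteration range by the step.
  return static_cast<uint64_t>((b.ub - b.lb + b.step - 1) / b.step);
}

// Total trip count of a collapsed loop nest: the product of the
// per-dimension trip counts. An empty nest yields 1.
inline uint64_t totalTripCount(const std::vector<LoopBounds> &nest) {
  uint64_t total = 1;
  for (const LoopBounds &b : nest)
    total *= tripCount(b);
  return total;
}
```

Because the bounds and steps are only needed to produce this single number before the launch, evaluating them on the host avoids any device-side computation for it.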

These host-evaluated values in MLIR would need to be placed outside of the
`omp.target` region and also attached to the corresponding nested operations,
which is not possible because of the `IsolatedFromAbove` trait. The solution
implemented to address this problem has been to introduce the `host_eval`
argument to the `omp.target` operation. It works similarly to a `map` clause,
but its only intended use is to forward host-evaluated values to their
corresponding operation inside of the region. Any use other than those
described above results in a verifier error.

```mlir
// Initialize %0, %1, %2, %3...
omp.target host_eval(%0 -> %nt, %1 -> %lb, %2 -> %ub, %3 -> %step : i32, i32, i32, i32) {
  omp.teams num_teams(to %nt : i32) {
    omp.parallel {
      omp.distribute {
        omp.wsloop {
          omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
            // ...
            omp.yield
          }
          omp.terminator
        } {omp.composite}
        omp.terminator
      } {omp.composite}
      omp.terminator
    } {omp.composite}
    omp.terminator
  }
  omp.terminator
}
```
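
The restriction on where host-evaluated values may be used can be pictured with a small standalone model. This is only a sketch of the rule as documented in this patch, not the verifier it adds; `UseKind`, `HostEvalUse` and `isLegalHostEvalUse` are hypothetical names, and the model does not check that the `omp.teams` operation is immediately nested.

```cpp
#include <cassert>

// Hypothetical classification of how a host_eval block argument is used.
enum class UseKind {
  NumTeams,    // num_teams of a nested omp.teams
  ThreadLimit, // thread_limit of a nested omp.teams
  NumThreads,  // num_threads of a nested omp.parallel
  LoopBound,   // lower or upper bound of a nested omp.loop_nest
  LoopStep,    // step of a nested omp.loop_nest
  Other        // any other use
};

// Hypothetical record of a single use of a host_eval block argument.
struct HostEvalUse {
  UseKind kind;
  bool inTargetSPMD; // enclosing omp.target is a target SPMD kernel
};

// Returns true if the use falls into one of the documented legal cases.
inline bool isLegalHostEvalUse(const HostEvalUse &use) {
  switch (use.kind) {
  case UseKind::NumTeams:
  case UseKind::ThreadLimit:
    return true; // allowed whether or not the kernel is SPMD
  case UseKind::NumThreads:
  case UseKind::LoopBound:
  case UseKind::LoopStep:
    return use.inTargetSPMD; // only allowed in target SPMD kernels
  case UseKind::Other:
    return false;
  }
  return false;
}
```

In the example above, `%nt` would fall under `NumTeams`, while `%lb`, `%ub` and `%step` would fall under `LoopBound` and `LoopStep` inside a target SPMD kernel.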
38 changes: 38 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -419,6 +419,44 @@ class OpenMP_HintClauseSkip<

def OpenMP_HintClause : OpenMP_HintClauseSkip<>;

//===----------------------------------------------------------------------===//
// Not in the spec: Clause-like structure to hold host-evaluated values.
//===----------------------------------------------------------------------===//

class OpenMP_HostEvalClauseSkip<
    bit traits = false, bit arguments = false, bit assemblyFormat = false,
    bit description = false, bit extraClassDeclaration = false
  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                    extraClassDeclaration> {
  let traits = [
    BlockArgOpenMPOpInterface
  ];

  let arguments = (ins
    Variadic<AnyType>:$host_eval_vars
  );

  let extraClassDeclaration = [{
    unsigned numHostEvalBlockArgs() {
      return getHostEvalVars().size();
    }
  }];

  let description = [{
    The optional `host_eval_vars` holds values defined outside of the region of
    the `IsolatedFromAbove` operation for which a corresponding entry block
    argument is defined. The only legal uses for these captured values are the
    following:
      - `num_teams` or `thread_limit` clause of an immediately nested
      `omp.teams` operation.
      - If the operation is the top-level `omp.target` of a target SPMD kernel:
        - `num_threads` clause of the nested `omp.parallel` operation.
        - Bounds and steps of the nested `omp.loop_nest` operation.
  }];
}

def OpenMP_HostEvalClause : OpenMP_HostEvalClauseSkip<>;

//===----------------------------------------------------------------------===//
// V5.2: [3.4] `if` clause
//===----------------------------------------------------------------------===//
31 changes: 10 additions & 21 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1116,20 +1116,16 @@ def TargetUpdateOp: OpenMP_Op<"target_update", traits = [
// 2.14.5 target construct
//===----------------------------------------------------------------------===//

// TODO: Remove num_threads, teams_thread_limit and trip_count and implement the
// passthrough approach described here:
// https://discourse.llvm.org/t/rfc-openmp-dialect-representation-of-num-teams-thread-limit-and-target-spmd/81106.
def TargetOp : OpenMP_Op<"target", traits = [
AttrSizedOperandSegments, BlockArgOpenMPOpInterface, IsolatedFromAbove,
OutlineableOpenMPOpInterface
], clauses = [
// TODO: Complete clause list (defaultmap, uses_allocators).
OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_DeviceClause,
OpenMP_HasDeviceAddrClause, OpenMP_IfClause, OpenMP_InReductionClause,
OpenMP_IsDevicePtrClause, OpenMP_MapClauseSkip<assemblyFormat = true>,
OpenMP_NowaitClause, OpenMP_NumTeamsClauseSkip<description = true>,
OpenMP_NumThreadsClauseSkip<description = true>, OpenMP_PrivateClause,
OpenMP_ThreadLimitClause
OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause, OpenMP_IfClause,
OpenMP_InReductionClause, OpenMP_IsDevicePtrClause,
OpenMP_MapClauseSkip<assemblyFormat = true>, OpenMP_NowaitClause,
OpenMP_PrivateClause, OpenMP_ThreadLimitClause
], singleRegion = true> {
let summary = "target construct";
let description = [{
@@ -1156,10 +1152,6 @@ def TargetOp : OpenMP_Op<"target", traits = [
an `omp.parallel`.
}] # clausesDescription;

let arguments = !con(clausesArgs,
(ins Optional<AnyInteger>:$trip_count,
Optional<AnyInteger>:$teams_thread_limit));

let builders = [
OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
];
@@ -1184,15 +1176,12 @@ def TargetOp : OpenMP_Op<"target", traits = [
bool isTargetSPMDLoop();
}] # clausesExtraClassDeclaration;

let assemblyFormat = clausesReqAssemblyFormat #
" oilist(" # clausesOptAssemblyFormat # [{
| `trip_count` `(` $trip_count `:` type($trip_count) `)`
| `teams_thread_limit` `(` $teams_thread_limit `:` type($teams_thread_limit) `)`
}] # ")" # [{
custom<InReductionMapPrivateRegion>(
$region, $in_reduction_vars, type($in_reduction_vars),
$in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars),
$private_vars, type($private_vars), $private_syms) attr-dict
let assemblyFormat = clausesAssemblyFormat # [{
custom<HostEvalInReductionMapPrivateRegion>(
$region, $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
$map_vars, type($map_vars), $private_vars, type($private_vars),
$private_syms) attr-dict
}];

let hasVerifier = 1;
27 changes: 22 additions & 5 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -25,6 +25,10 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {

let methods = [
// Default-implemented methods to be overridden by the corresponding clauses.
InterfaceMethod<"Get number of block arguments defined by `host_eval`.",
"unsigned", "numHostEvalBlockArgs", (ins), [{}], [{
return 0;
}]>,
InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
"unsigned", "numInReductionBlockArgs", (ins), [{}], [{
return 0;
@@ -55,9 +59,14 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
}]>,

// Unified access methods for clause-associated entry block arguments.
InterfaceMethod<"Get start index of block arguments defined by `host_eval`.",
"unsigned", "getHostEvalBlockArgsStart", (ins), [{
return 0;
}]>,
InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
"unsigned", "getInReductionBlockArgsStart", (ins), [{
return 0;
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return iface.getHostEvalBlockArgsStart() + $_op.numHostEvalBlockArgs();
}]>,
InterfaceMethod<"Get start index of block arguments defined by `map`.",
"unsigned", "getMapBlockArgsStart", (ins), [{
@@ -91,6 +100,13 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs();
}]>,

InterfaceMethod<"Get block arguments defined by `host_eval`.",
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
"getHostEvalBlockArgs", (ins), [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return $_op->getRegion(0).getArguments().slice(
iface.getHostEvalBlockArgsStart(), $_op.numHostEvalBlockArgs());
}]>,
InterfaceMethod<"Get block arguments defined by `in_reduction`.",
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
"getInReductionBlockArgs", (ins), [{
@@ -147,10 +163,11 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {

let verify = [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
unsigned expectedArgs = iface.numInReductionBlockArgs() +
iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs() +
iface.numUseDeviceAddrBlockArgs() + iface.numUseDevicePtrBlockArgs();
unsigned expectedArgs = iface.numHostEvalBlockArgs() +
iface.numInReductionBlockArgs() + iface.numMapBlockArgs() +
iface.numPrivateBlockArgs() + iface.numReductionBlockArgs() +
iface.numTaskReductionBlockArgs() + iface.numUseDeviceAddrBlockArgs() +
iface.numUseDevicePtrBlockArgs();
if ($_op->getRegion(0).getNumArguments() < expectedArgs)
return $_op->emitOpError() << "expected at least " << expectedArgs
<< " entry block argument(s)";
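
The interface changes above place the `host_eval` block arguments first in the entry block, with each later clause starting where the previous one ends. The standalone sketch below models that layout and the "at least N entry block arguments" check. It is an illustration only: `Clause`, `Counts`, `blockArgsStart`, `expectedBlockArgs` and `hasEnoughBlockArgs` are hypothetical names, and the clause order after `in_reduction` is assumed to follow the order in which the verifier sums the counts.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <numeric>

// Assumed clause order for entry block arguments (hypothetical enum).
enum Clause : std::size_t {
  HostEval = 0,
  InReduction,
  Map,
  Private,
  Reduction,
  TaskReduction,
  UseDeviceAddr,
  UseDevicePtr,
  NumClauses
};

// Number of block arguments defined by each clause, indexed by Clause.
using Counts = std::array<unsigned, NumClauses>;

// Start index of a clause's block arguments: the sum of the counts of
// all clauses that precede it in the layout.
inline unsigned blockArgsStart(const Counts &counts, Clause clause) {
  unsigned start = 0;
  for (std::size_t i = 0; i < clause; ++i)
    start += counts[i];
  return start;
}

// Minimum number of entry block arguments implied by the clauses.
inline unsigned expectedBlockArgs(const Counts &counts) {
  return std::accumulate(counts.begin(), counts.end(), 0u);
}

// Mirrors the shape of the check: the region may have extra arguments,
// but never fewer than the clauses require.
inline bool hasEnoughBlockArgs(const Counts &counts, unsigned numRegionArgs) {
  return numRegionArgs >= expectedBlockArgs(counts);
}
```

With four `host_eval` values, two mapped values and one private value, this model places the `map` arguments at index 4 and the `private` argument at index 6, and requires at least seven entry block arguments.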