Commit
Move GPU ukernel selection to KernelConfig (#19440)
This moves the logic deciding whether an op should be a ukernel out of the GPULowerToUKernels pass and into KernelConfig. KernelConfig now decides whether the op should be a ukernel and encodes that decision into the resulting `lowering_config`, in a new parameter holding a new attribute, UKernelSpecAttr. That attribute is modeled directly on the equivalent C++ data structure that we have had in the LowerToUKernels passes, `FnNameAndDefAttrs`, which it replaces. If the attribute is present, the op was selected for ukernel lowering, and its fields give the ukernel name and some function definition attributes (to import any dependencies, such as the `rocm` module for runtime support symbols).

All the details about supplying the ukernel bitcode in a `hal.executable.object` are also moved there, becoming a side effect of `KernelConfig`.

The GPULowerToUKernels pass becomes much simpler, since all the decision-making has already been done for it. It just looks at the `LoweringConfigAttr` and, if the ukernel attribute is there, performs the requested lowering.

The motivation for this split is that we need to know in KernelConfig whether the op is going to be a ukernel, because ops that will get lowered to a ukernel require a different configuration. The important example for us is `multi_mma`, which in the ukernel case needs to avoid tiling the reduction dimension to 1 so that the ukernel gets to see the reduction loop.

A few simplifications already arise in the current argmax ukernel logic, confirming that this was the right design choice. The old ukernel matching logic checked that the distribution tile sizes matched what the ukernel could handle. Now that is turned upside down: the ukernel matching happens as a helper within KernelConfig, where we know we are setting the appropriate tile sizes on purpose.
Another nice improvement is that this puts just enough distance between ukernel selection (which creates the `hal.executable.object`) and ukernel lowering that we are able to insert `HoistExecutableObjectsPass` in between. That simplifies the ukernel lowering, which no longer needs to worry about preserving the `hal.executable.object`.

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
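The split described in the commit message can be sketched in plain C++. This is a minimal illustration under stated assumptions, not the IREE implementation: `UKernelSpec`, `LoweringConfig`, `selectConfig`, and `lower` are hypothetical names, the `"import:rocm"` string is a placeholder rather than real attribute syntax, and using tile size 0 to mean "do not tile" is a convention of this sketch.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for the ukernel attribute: a ukernel name plus
// function definition attributes (placeholder strings, not real syntax).
struct UKernelSpec {
  std::string name;
  std::vector<std::string> defAttrs;
};

// Hypothetical stand-in for a lowering config. The ukernel field is set
// only when the op was selected for ukernel lowering.
struct LoweringConfig {
  std::vector<int> reductionTileSizes;
  std::optional<UKernelSpec> ukernel;
};

// Selection step (the role KernelConfig plays): decide once whether the op
// becomes a ukernel, and pick tile sizes that suit that decision.
LoweringConfig selectConfig(const std::string &opName, bool targetHasUKernel) {
  LoweringConfig config;
  if (targetHasUKernel) {
    config.ukernel = UKernelSpec{"ukernel_" + opName, {"import:rocm"}};
    // Sketch convention: 0 means "leave the reduction dimension untiled",
    // so the ukernel gets to see the whole reduction loop.
    config.reductionTileSizes = {0};
  } else {
    // Simplified non-ukernel case: tile the reduction dimension to 1.
    config.reductionTileSizes = {1};
  }
  return config;
}

// Lowering step (the role GPULowerToUKernels plays): no decisions are made
// here, only a presence check on the recorded ukernel spec.
std::string lower(const LoweringConfig &config) {
  if (!config.ukernel)
    return "default-codegen";
  return "call @" + config.ukernel->name;
}
```

Calling `lower(selectConfig("argmax", false))` takes the default path, which mirrors what the gfx908 test in this commit checks: when no ukernel exists for the target, no ukernel is recorded in the config and nothing downstream has to re-derive that.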
Showing 21 changed files with 392 additions and 304 deletions.
30 changes: 30 additions & 0 deletions
compiler/plugins/target/ROCM/test/config_ukernel_argmax_gfx908.mlir
@@ -0,0 +1,30 @@
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx908 --pass-pipeline='builtin.module(iree-llvmgpu-select-lowering-strategy)' %s | FileCheck %s

// gfx908 a.k.a. CDNA1 is used here as an example of a GPU target that we don't have ukernels for.
// No need to add many ukernels here, just a quick check that we correctly do not select a ukernel.

func.func @argmax_2d_f32i64(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
} {
  %c0_i64 = arith.constant 0 : i64
  %cst = arith.constant 0xFF800000 : f32
  %0 = tensor.empty() : tensor<1xi64>
  %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64>
  %2 = tensor.empty() : tensor<1xf32>
  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32>
  %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) {
  ^bb0(%in: f32, %out: f32, %out_0: i64):
    %5 = linalg.index 1 : index
    %6 = arith.index_cast %5 : index to i64
    %7 = arith.maximumf %in, %out : f32
    %8 = arith.cmpf ogt, %in, %out : f32
    %9 = arith.select %8, %6, %out_0 : i64
    linalg.yield %7, %9 : f32, i64
  } -> (tensor<1xf32>, tensor<1xi64>)
  return %4#1 : tensor<1xi64>
}

// CHECK-NOT: lowering_config<{{.*}}ukernel
// CHECK-LABEL: func @argmax_2d_f32i64(
// CHECK: linalg.generic
// CHECK-NOT: hal.executable.objects