Skip to content

Commit

Permalink
[MLIR][Bufferization] BufferResultsToOutParams: Add option to add att…
Browse files Browse the repository at this point in the history
…ribute to output arguments

Adds a new pass option `add-result-attr` that will make the pass add the attribute
`{bufferize.result}` to each argument that was converted from a result.

To be able to test this, the pass option was added to the tablegen.
And then the existing manual option struct `BufferResultsToOutParamsOptions`
was renamed to `BufferResultsToOutParamsOpts` to not conflict.

Reviewers: TinaAMD

Reviewed By: TinaAMD

Pull Request: #116
  • Loading branch information
mgehre-amd authored Feb 29, 2024
1 parent ef0da4b commit ee4d46c
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 11 deletions.
10 changes: 7 additions & 3 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();

// Options struct for BufferResultsToOutParams pass.
// Note: defined only here, not in tablegen.
struct BufferResultsToOutParamsOptions {
struct BufferResultsToOutParamsOpts {
/// Memcpy function: Generate a memcpy between two memrefs.
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
Expand All @@ -159,17 +159,21 @@ struct BufferResultsToOutParamsOptions {
};

std::optional<MemCpyFn> memCpyFn;

/// If true, the pass adds a "bufferize.result" attribute to each output
/// parameter.
bool addResultAttribute = false;
};

/// Creates a pass that converts memref function results to out-params.
std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
const BufferResultsToOutParamsOptions &options = {});
const BufferResultsToOutParamsOpts &options = {});

/// Replace buffers that are returned from a function with an out parameter.
/// Also update all call sites.
LogicalResult
promoteBufferResultsToOutParams(ModuleOp module,
const BufferResultsToOutParamsOptions &options);
const BufferResultsToOutParamsOpts &options);

/// Creates a pass that drops memref function results that are equivalent to a
/// function argument.
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
buffers for results need to be allocated in the caller. This currently only
works for static shaped memrefs.
}];
let options = [
Option<"addResultAttribute", "add-result-attr", "bool",
/*default=*/"false",
"Add the attribute 'bufferize.result' to all output parameters.">,
];
let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
let dependentDialects = ["memref::MemRefDialect"];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace bufferization {
} // namespace mlir

using namespace mlir;
using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;

/// Return `true` if the given MemRef type has a fully dynamic layout.
static bool hasFullyDynamicLayoutMap(MemRefType type) {
Expand All @@ -47,7 +47,8 @@ static bool hasStaticIdentityLayout(MemRefType type) {
// Any args appended to the entry block are added to `appendedEntryArgs`.
static LogicalResult
updateFuncOp(func::FuncOp func,
SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
SmallVectorImpl<BlockArgument> &appendedEntryArgs,
bool addResultAttribute) {
auto functionType = func.getFunctionType();

// Collect information about the results will become appended arguments.
Expand Down Expand Up @@ -80,6 +81,10 @@ updateFuncOp(func::FuncOp func,
for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
func.setArgAttrs(functionType.getNumInputs() + i,
func.getResultAttrs(*erasedIndicesIt));
if (addResultAttribute)
func.setArgAttr(functionType.getNumInputs() + i,
StringAttr::get(func.getContext(), "bufferize.result"),
UnitAttr::get(func.getContext()));
}

// Erase the results.
Expand Down Expand Up @@ -127,7 +132,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
// temporary buffers for newly introduced out params.
static LogicalResult
updateCalls(ModuleOp module,
const bufferization::BufferResultsToOutParamsOptions &options) {
const bufferization::BufferResultsToOutParamsOpts &options) {
bool didFail = false;
SymbolTable symtab(module);
module.walk([&](func::CallOp op) {
Expand Down Expand Up @@ -189,12 +194,13 @@ updateCalls(ModuleOp module,

LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
ModuleOp module,
const bufferization::BufferResultsToOutParamsOptions &options) {
const bufferization::BufferResultsToOutParamsOpts &options) {
for (auto func : module.getOps<func::FuncOp>()) {
if (!options.filterFn(&func))
continue;
SmallVector<BlockArgument, 6> appendedEntryArgs;
if (failed(updateFuncOp(func, appendedEntryArgs)))
if (failed(
updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
return failure();
if (func.isExternal())
continue;
Expand All @@ -218,21 +224,25 @@ struct BufferResultsToOutParamsPass
: bufferization::impl::BufferResultsToOutParamsBase<
BufferResultsToOutParamsPass> {
explicit BufferResultsToOutParamsPass(
const bufferization::BufferResultsToOutParamsOptions &options)
const bufferization::BufferResultsToOutParamsOpts &options)
: options(options) {}

void runOnOperation() override {
// Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
if (addResultAttribute)
options.addResultAttribute = true;

if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
options)))
return signalPassFailure();
}

private:
bufferization::BufferResultsToOutParamsOptions options;
bufferization::BufferResultsToOutParamsOpts options;
};
} // namespace

std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass(
const bufferization::BufferResultsToOutParamsOptions &options) {
const bufferization::BufferResultsToOutParamsOpts &options) {
return std::make_unique<BufferResultsToOutParamsPass>(options);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{add-result-attr})' -split-input-file -verify-diagnostics %s | FileCheck %s

// CHECK-LABEL: basic
// CHECK-SAME: memref<f32> {bufferize.result})
func.func @basic() -> (memref<f32>) {
%0 = "test.source"() : () -> (memref<f32>)
return %0 : memref<f32>
}

// -----

// CHECK-LABEL: multiple_results
// CHECK-SAME: memref<1xf32> {bufferize.result},
// CHECK-SAME: memref<2xf32> {bufferize.result})
func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) {
%0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>)
return %0, %1 : memref<1xf32>, memref<2xf32>
}

0 comments on commit ee4d46c

Please sign in to comment.