forked from llvm/llvm-project
-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir][MLProgram] Add MLProgram to MemRef bufferization pass (llvm#75103
) There is currently no lowering out of `ml_program` in the LLVM repository. This change adds a lowering to `memref` so that it can be lowered all the way to LLVM. This lowering was taken from the [reference backend in torch-mlir](llvm/torch-mlir@f416953 ). I had tried implementing the `BufferizableOpInterface` for `ml_program` instead of adding a new pass but that did not work because `OneShotBufferize` does not visit module-level ops like `ml_program.global`.
- Loading branch information
1 parent
fc7c79b
commit fa10121
Showing
7 changed files
with
239 additions
and
1 deletion.
There are no files selected for viewing
20 changes: 20 additions & 0 deletions
20
mlir/include/mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_MLPROGRAM_BUFFERIZABLEOPINTERFACEIMPL_H | ||
#define MLIR_DIALECT_MLPROGRAM_BUFFERIZABLEOPINTERFACEIMPL_H | ||
|
||
namespace mlir { | ||
class DialectRegistry; | ||
|
||
namespace ml_program { | ||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); | ||
} // namespace ml_program | ||
} // namespace mlir | ||
|
||
#endif // MLIR_DIALECT_MLPROGRAM_BUFFERIZABLEOPINTERFACEIMPL_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
159 changes: 159 additions & 0 deletions
159
mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" | ||
|
||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" | ||
#include "mlir/Dialect/MLProgram/IR/MLProgram.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::bufferization; | ||
using namespace mlir::ml_program; | ||
|
||
namespace mlir { | ||
namespace ml_program { | ||
namespace { | ||
|
||
template <typename Interface, typename Op> | ||
struct ExternalModelBase | ||
: public BufferizableOpInterface::ExternalModel<Interface, Op> { | ||
|
||
AliasingValueList getAliasingValues(Operation *, OpOperand &, | ||
const AnalysisState &) const { | ||
return {}; | ||
} | ||
|
||
BufferRelation bufferRelation(Operation *, OpResult, | ||
const AnalysisState &) const { | ||
return BufferRelation::Unknown; | ||
} | ||
}; | ||
|
||
/// Bufferization of ml_program.global into a memref.global | ||
struct GlobalOpInterface | ||
: public ExternalModelBase<GlobalOpInterface, GlobalOp> { | ||
|
||
bool bufferizesToMemoryRead(Operation *, OpOperand &, | ||
const AnalysisState &) const { | ||
return false; | ||
} | ||
|
||
bool bufferizesToMemoryWrite(Operation *, OpOperand &, | ||
const AnalysisState &) const { | ||
return false; | ||
} | ||
|
||
bool hasTensorSemantics(Operation *) const { return true; } | ||
|
||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, | ||
const BufferizationOptions &) const { | ||
auto globalOp = cast<GlobalOp>(op); | ||
if (!globalOp.getValue().has_value()) | ||
return globalOp.emitError("global op must have a value"); | ||
|
||
auto tensorType = cast<TensorType>(globalOp.getType()); | ||
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); | ||
|
||
replaceOpWithNewBufferizedOp<memref::GlobalOp>( | ||
rewriter, globalOp, globalOp.getSymName(), | ||
/*sym_visibility=*/globalOp.getSymVisibilityAttr(), | ||
/*type=*/cast<MemRefType>(memrefType), | ||
/*initial_value=*/globalOp.getValue().value(), | ||
/*constant=*/!globalOp.getIsMutable(), | ||
/*alignment=*/nullptr); | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
/// Bufferization of ml_program.global_load into a memref.get_global | ||
struct GlobalLoadOpInterface | ||
: public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> { | ||
|
||
bool bufferizesToMemoryRead(Operation *, OpOperand &, | ||
const AnalysisState &) const { | ||
return false; | ||
} | ||
|
||
bool bufferizesToMemoryWrite(Operation *, OpOperand &, | ||
const AnalysisState &) const { | ||
return false; | ||
} | ||
|
||
bool isWritable(Operation *, Value, const AnalysisState &) const { | ||
return false; | ||
} | ||
|
||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, | ||
const BufferizationOptions &) const { | ||
auto globalLoadOp = cast<GlobalLoadOp>(op); | ||
|
||
auto tensorType = cast<TensorType>(globalLoadOp.getType()); | ||
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); | ||
|
||
replaceOpWithNewBufferizedOp<memref::GetGlobalOp>( | ||
rewriter, globalLoadOp, memrefType, | ||
globalLoadOp.getGlobalAttr().getLeafReference()); | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
/// Bufferization of ml_program.global_store into a memref.get_global and | ||
/// memcpy | ||
struct GlobalStoreOpInterface | ||
: public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> { | ||
|
||
bool bufferizesToMemoryRead(Operation *, OpOperand &, | ||
const AnalysisState &) const { | ||
return false; | ||
} | ||
|
||
bool bufferizesToMemoryWrite(Operation *, OpOperand &, | ||
const AnalysisState &) const { | ||
return true; | ||
} | ||
|
||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, | ||
const BufferizationOptions &options) const { | ||
auto globalStoreOp = cast<GlobalStoreOp>(op); | ||
|
||
auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType()); | ||
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); | ||
|
||
auto loc = globalStoreOp.getLoc(); | ||
auto targetMemref = rewriter.create<memref::GetGlobalOp>( | ||
loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference()); | ||
|
||
auto sourceMemref = getBuffer(rewriter, globalStoreOp.getValue(), options); | ||
if (failed(sourceMemref)) { | ||
return failure(); | ||
} | ||
|
||
auto memcpy = | ||
options.createMemCpy(rewriter, loc, sourceMemref.value(), targetMemref); | ||
if (failed(memcpy)) { | ||
return failure(); | ||
} | ||
rewriter.eraseOp(globalStoreOp); | ||
|
||
return success(); | ||
} | ||
}; | ||
} // namespace | ||
|
||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { | ||
registry.addExtension(+[](MLIRContext *ctx, MLProgramDialect *) { | ||
GlobalOp::attachInterface<GlobalOpInterface>(*ctx); | ||
GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*ctx); | ||
GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*ctx); | ||
}); | ||
} | ||
} // namespace ml_program | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
// RUN: mlir-opt %s -one-shot-bufferize -split-input-file | FileCheck %s | ||
|
||
// CHECK-LABEL: memref.global "private" @global | ||
ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64> | ||
|
||
// CHECK-LABEL: func.func @global_load_store | ||
func.func @global_load_store() -> i64 { | ||
// CHECK-DAG: %[[CST127:.*]] = arith.constant 127 | ||
// CHECK-DAG: %[[GLOBAL_1:.*]] = memref.get_global @global | ||
// CHECK: %[[VALUE:.*]] = memref.load %[[GLOBAL_1]][] | ||
// CHECK: %[[NEW_VALUE:.*]] = arith.muli %[[VALUE]], %[[CST127]] | ||
// CHECK: %[[ALLOC:.*]] = memref.alloc() | ||
// CHECK: memref.copy %[[GLOBAL_1]], %[[ALLOC]] | ||
// CHECK: memref.store %[[NEW_VALUE]], %[[ALLOC]][] | ||
// CHECK: %[[GLOBAL_2:.*]] = memref.get_global @global | ||
// CHECK: memref.copy %[[ALLOC]], %[[GLOBAL_2]] | ||
// CHECK: return %[[NEW_VALUE]] | ||
%c127 = arith.constant 127 : i64 | ||
%0 = ml_program.global_load @global : tensor<i64> | ||
%extracted = tensor.extract %0[] : tensor<i64> | ||
%1 = arith.muli %extracted, %c127 : i64 | ||
%inserted = tensor.insert %1 into %0[] : tensor<i64> | ||
ml_program.global_store @global = %inserted : tensor<i64> | ||
return %1 : i64 | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: memref.global "private" @global | ||
ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64> | ||
|
||
// CHECK-LABEL: func.func @raw_hazard | ||
func.func @raw_hazard() -> i64 { | ||
// CHECK-DAG: %[[CST127:.*]] = arith.constant 127 | ||
// CHECK-DAG: %[[GLOBAL_1:.*]] = memref.get_global @global | ||
// CHECK-DAG: %[[GLOBAL_2:.*]] = memref.get_global @global | ||
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() | ||
// CHECK: memref.copy %[[GLOBAL_1]], %[[ALLOC]] | ||
// CHECK: memref.store %[[CST127]], %[[ALLOC]][] | ||
// CHECK: %[[VAL:.*]] = memref.load %[[GLOBAL_2]][] | ||
// CHECK: %[[GLOBAL_3:.*]] = memref.get_global @global | ||
// CHECK: memref.copy %[[ALLOC]], %[[GLOBAL_3]] | ||
// CHECK: return %[[VAL]] | ||
%c127 = arith.constant 127 : i64 | ||
%0 = ml_program.global_load @global : tensor<i64> | ||
%1 = ml_program.global_load @global : tensor<i64> | ||
%inserted = tensor.insert %c127 into %0[] : tensor<i64> | ||
%extracted = tensor.extract %1[] : tensor<i64> | ||
ml_program.global_store @global = %inserted : tensor<i64> | ||
return %extracted : i64 | ||
} | ||
|