Skip to content

Commit

Permalink
[mlir][MLProgram] Add MLProgram to MemRef bufferization pass (llvm#75103)
Browse files Browse the repository at this point in the history

There is currently no lowering out of `ml_program` in the LLVM
repository. This change adds a lowering to `memref` so that it can be
lowered all the way to LLVM. This lowering was taken from the
[reference backend in torch-mlir](llvm/torch-mlir@f416953).

I had tried implementing the `BufferizableOpInterface` for `ml_program`
instead of adding a new pass but that did not work because
`OneShotBufferize` does not visit module-level ops like
`ml_program.global`.
  • Loading branch information
ryan-holt-1 authored Jan 30, 2024
1 parent fc7c79b commit fa10121
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MLPROGRAM_BUFFERIZABLEOPINTERFACEIMPL_H
#define MLIR_DIALECT_MLPROGRAM_BUFFERIZABLEOPINTERFACEIMPL_H

namespace mlir {
class DialectRegistry;

namespace ml_program {
/// Registers external models implementing `BufferizableOpInterface` for the
/// `ml_program` dialect ops (`global`, `global_load`, `global_store`) so that
/// bufferization can lower them to their `memref` equivalents.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace ml_program
} // namespace mlir

#endif // MLIR_DIALECT_MLPROGRAM_BUFFERIZABLEOPINTERFACEIMPL_H
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
Expand Down Expand Up @@ -160,6 +161,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
memref::registerValueBoundsOpInterfaceExternalModels(registry);
memref::registerMemorySlotExternalModels(registry);
ml_program::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerBufferDeallocationOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerValueBoundsOpInterfaceExternalModels(registry);
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,10 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
<< "\n//===-------------------------------------------===//\n");
}

// Return early if the top-level op is entirely gone.
if (erasedOps.contains(op))
return success();

// Fold all to_memref(to_tensor(x)) pairs.
for (Operation *op : toMemrefOps) {
rewriter.setInsertionPoint(op);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}

// Bufferize all other ops.
for (Operation &op : moduleOp.getOps()) {
for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
// Functions were already bufferized.
if (isa<func::FuncOp>(&op))
continue;
Expand Down
159 changes: 159 additions & 0 deletions mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::ml_program;

namespace mlir {
namespace ml_program {
namespace {

/// Shared base for the external models below. Supplies the default analysis
/// answers common to all three ml_program ops: results do not alias any
/// operand, and the buffer relation is unknown.
template <typename Interface, typename Op>
struct ExternalModelBase
    : public BufferizableOpInterface::ExternalModel<Interface, Op> {

  /// None of the ops handled here produce values aliasing their operands.
  AliasingValueList getAliasingValues(Operation *, OpOperand &,
                                      const AnalysisState &) const {
    return {};
  }

  /// Conservatively report no known relation between result and operand
  /// buffers.
  BufferRelation bufferRelation(Operation *, OpResult,
                                const AnalysisState &) const {
    return BufferRelation::Unknown;
  }
};

/// Bufferization of ml_program.global into a memref.global
struct GlobalOpInterface
    : public ExternalModelBase<GlobalOpInterface, GlobalOp> {

  /// The global declaration itself performs no memory reads.
  bool bufferizesToMemoryRead(Operation *, OpOperand &,
                              const AnalysisState &) const {
    return false;
  }

  /// The global declaration itself performs no memory writes.
  bool bufferizesToMemoryWrite(Operation *, OpOperand &,
                               const AnalysisState &) const {
    return false;
  }

  /// Report tensor semantics so the bufferization driver visits this op even
  /// though it has no tensor operands or results (its declared type is a
  /// tensor).
  bool hasTensorSemantics(Operation *) const { return true; }

  /// Replaces the op with a memref.global of the equivalent memref type
  /// (static identity layout). Mutability maps to the memref.global
  /// `constant` flag (inverted).
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &) const {
    auto globalOp = cast<GlobalOp>(op);
    // An initial value is required: memref.global is created with the
    // global's value as its initializer.
    if (!globalOp.getValue().has_value())
      return globalOp.emitError("global op must have a value");

    auto tensorType = cast<TensorType>(globalOp.getType());
    auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);

    replaceOpWithNewBufferizedOp<memref::GlobalOp>(
        rewriter, globalOp, globalOp.getSymName(),
        /*sym_visibility=*/globalOp.getSymVisibilityAttr(),
        /*type=*/cast<MemRefType>(memrefType),
        /*initial_value=*/globalOp.getValue().value(),
        /*constant=*/!globalOp.getIsMutable(),
        /*alignment=*/nullptr);

    return success();
  }
};

/// Bufferization of ml_program.global_load into a memref.get_global
struct GlobalLoadOpInterface
    : public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> {

  /// The load only materializes a reference to the global's buffer; it does
  /// not itself read memory.
  bool bufferizesToMemoryRead(Operation *, OpOperand &,
                              const AnalysisState &) const {
    return false;
  }

  /// No memory is written by the load.
  bool bufferizesToMemoryWrite(Operation *, OpOperand &,
                               const AnalysisState &) const {
    return false;
  }

  /// The loaded tensor must not be updated in place: it is backed by the
  /// global's storage.
  bool isWritable(Operation *, Value, const AnalysisState &) const {
    return false;
  }

  /// Replaces the op with a memref.get_global of the equivalent memref type
  /// (static identity layout).
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &) const {
    auto loadOp = cast<GlobalLoadOp>(op);

    auto resultType = cast<TensorType>(loadOp.getType());
    auto bufferType = getMemRefTypeWithStaticIdentityLayout(resultType);
    auto globalName = loadOp.getGlobalAttr().getLeafReference();

    replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(rewriter, loadOp,
                                                      bufferType, globalName);
    return success();
  }
};

/// Bufferization of ml_program.global_store into a memref.get_global and
/// memcpy
struct GlobalStoreOpInterface
    : public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> {

  /// No memory read is attributed to the store itself.
  bool bufferizesToMemoryRead(Operation *, OpOperand &,
                              const AnalysisState &) const {
    return false;
  }

  /// Storing overwrites the global's backing buffer.
  bool bufferizesToMemoryWrite(Operation *, OpOperand &,
                               const AnalysisState &) const {
    return true;
  }

  /// Lowers the store to a memref.get_global of the destination buffer plus a
  /// memcpy from the stored value's buffer, then erases the original op.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto storeOp = cast<GlobalStoreOp>(op);
    auto loc = storeOp.getLoc();

    auto valueType = cast<TensorType>(storeOp.getValue().getType());
    auto bufferType = getMemRefTypeWithStaticIdentityLayout(valueType);
    auto globalName = storeOp.getGlobalAttr().getLeafReference();

    // Destination: the buffer backing the global symbol.
    auto dstBuffer =
        rewriter.create<memref::GetGlobalOp>(loc, bufferType, globalName);

    // Source: the buffer holding the stored tensor value.
    auto srcBuffer = getBuffer(rewriter, storeOp.getValue(), options);
    if (failed(srcBuffer)) {
      return failure();
    }

    // Copy the value into the global's storage.
    if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, dstBuffer))) {
      return failure();
    }
    rewriter.eraseOp(storeOp);

    return success();
  }
};
} // namespace

/// Attaches the external bufferization models above to the ml_program ops.
/// Registration is deferred via a dialect extension so it only runs once the
/// MLProgram dialect is loaded into the context.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, MLProgramDialect *) {
    GlobalOp::attachInterface<GlobalOpInterface>(*ctx);
    GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*ctx);
    GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*ctx);
  });
}
} // namespace ml_program
} // namespace mlir
1 change: 1 addition & 0 deletions mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRMLProgramTransforms
BufferizableOpInterfaceImpl.cpp
PipelineGlobalOps.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
52 changes: 52 additions & 0 deletions mlir/test/Dialect/MLProgram/one-shot-bufferize.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// RUN: mlir-opt %s -one-shot-bufferize -split-input-file | FileCheck %s

// CHECK-LABEL: memref.global "private" @global
ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64>

// Load-modify-store through the global. The checks expect an alloc plus copy
// for the tensor.insert (the loaded tensor is not writable in place), and a
// final copy back into the global's buffer for the store.
// CHECK-LABEL: func.func @global_load_store
func.func @global_load_store() -> i64 {
  // CHECK-DAG: %[[CST127:.*]] = arith.constant 127
  // CHECK-DAG: %[[GLOBAL_1:.*]] = memref.get_global @global
  // CHECK: %[[VALUE:.*]] = memref.load %[[GLOBAL_1]][]
  // CHECK: %[[NEW_VALUE:.*]] = arith.muli %[[VALUE]], %[[CST127]]
  // CHECK: %[[ALLOC:.*]] = memref.alloc()
  // CHECK: memref.copy %[[GLOBAL_1]], %[[ALLOC]]
  // CHECK: memref.store %[[NEW_VALUE]], %[[ALLOC]][]
  // CHECK: %[[GLOBAL_2:.*]] = memref.get_global @global
  // CHECK: memref.copy %[[ALLOC]], %[[GLOBAL_2]]
  // CHECK: return %[[NEW_VALUE]]
  %c127 = arith.constant 127 : i64
  %0 = ml_program.global_load @global : tensor<i64>
  %extracted = tensor.extract %0[] : tensor<i64>
  %1 = arith.muli %extracted, %c127 : i64
  %inserted = tensor.insert %1 into %0[] : tensor<i64>
  ml_program.global_store @global = %inserted : tensor<i64>
  return %1 : i64
}

// -----

// CHECK-LABEL: memref.global "private" @global
ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64>

// Read-after-write hazard: %1 is read after %0 is modified, so the checks
// expect the insert to go through a separate alloc + copy rather than an
// in-place update of the global's buffer.
// CHECK-LABEL: func.func @raw_hazard
func.func @raw_hazard() -> i64 {
  // CHECK-DAG: %[[CST127:.*]] = arith.constant 127
  // CHECK-DAG: %[[GLOBAL_1:.*]] = memref.get_global @global
  // CHECK-DAG: %[[GLOBAL_2:.*]] = memref.get_global @global
  // CHECK-DAG: %[[ALLOC:.*]] = memref.alloc()
  // CHECK: memref.copy %[[GLOBAL_1]], %[[ALLOC]]
  // CHECK: memref.store %[[CST127]], %[[ALLOC]][]
  // CHECK: %[[VAL:.*]] = memref.load %[[GLOBAL_2]][]
  // CHECK: %[[GLOBAL_3:.*]] = memref.get_global @global
  // CHECK: memref.copy %[[ALLOC]], %[[GLOBAL_3]]
  // CHECK: return %[[VAL]]
  %c127 = arith.constant 127 : i64
  %0 = ml_program.global_load @global : tensor<i64>
  %1 = ml_program.global_load @global : tensor<i64>
  %inserted = tensor.insert %c127 into %0[] : tensor<i64>
  %extracted = tensor.extract %1[] : tensor<i64>
  ml_program.global_store @global = %inserted : tensor<i64>
  return %extracted : i64
}

0 comments on commit fa10121

Please sign in to comment.