Skip to content

Commit

Permalink
Merge remote-tracking branch 'xlnx/feature/fused-ops' into bump_to_f8…
Browse files Browse the repository at this point in the history
…eceb45
  • Loading branch information
mgehre-amd committed Dec 18, 2024
2 parents d1726f4 + 5518042 commit e4cc751
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 50 deletions.
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ def LinalgElementwiseOpFusionPass : Pass<"linalg-fuse-elementwise-ops"> {
let dependentDialects = [
"affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
];
let options = [
Option<"removeOutsDependency", "remove-outs-dependency", "bool",
/*default=*/"true",
"Replace out by tensor.empty">,
];
}

def LinalgNamedOpConversionPass: Pass<"linalg-named-op-conversion"> {
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1701,7 +1701,8 @@ using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;
/// when both operations are fusable elementwise operations.
void populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
const ControlFusionFn &controlElementwiseOpFusion);
const ControlFusionFn &controlElementwiseOpFusion,
bool replaceOutsDependency = true);

/// Function type which is used to control propagation of tensor.pack/unpack
/// ops.
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ class TruncFConversion : public OpConversionPattern<arith::TruncFOp> {
return rewriter.notifyMatchFailure(castOp,
"unsupported cast destination type");

if (!castOp.areCastCompatible(operandType, dstType))
if (!emitc::CastOp::areCastCompatible(operandType, dstType))
return rewriter.notifyMatchFailure(castOp, "cast-incompatible types");

rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType,
Expand Down Expand Up @@ -787,7 +787,7 @@ class ExtFConversion : public OpConversionPattern<arith::ExtFOp> {
return rewriter.notifyMatchFailure(castOp,
"unsupported cast destination type");

if (!castOp.areCastCompatible(operandType, dstType))
if (!emitc::CastOp::areCastCompatible(operandType, dstType))
return rewriter.notifyMatchFailure(castOp, "cast-incompatible types");

rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType,
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,10 @@ LogicalResult emitc::AssignOp::verify() {
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
Type input = inputs.front(), output = outputs.front();

// Opaque types are always allowed
if (isa<emitc::OpaqueType>(input) || isa<emitc::OpaqueType>(output))
return true;

// Cast to array is only possible from an array
if (isa<emitc::ArrayType>(input) != isa<emitc::ArrayType>(output))
return false;
Expand Down
11 changes: 7 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2134,11 +2134,13 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(

void mlir::linalg::populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
const ControlFusionFn &controlElementwiseOpsFusion) {
const ControlFusionFn &controlElementwiseOpsFusion,
bool removeOutsDependency) {
auto *context = patterns.getContext();
patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
RemoveOutsDependency>(context);
patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant>(context);
if (removeOutsDependency)
patterns.add<RemoveOutsDependency>(context);
// Add the patterns that clean up dead operands and results.
populateEraseUnusedOperandsAndResultsPatterns(patterns);
}
Expand Down Expand Up @@ -2180,7 +2182,8 @@ struct LinalgElementwiseOpFusionPass
};

// Add elementwise op fusion patterns.
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn,
removeOutsDependency);
populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
tensor::populateBubbleUpExpandShapePatterns(patterns);

Expand Down
12 changes: 3 additions & 9 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1349,19 +1349,13 @@ OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
return {};
}

static bool hasZeroSize(Type ty) {
auto ranked = dyn_cast<RankedTensorType>(ty);
if (!ranked)
return false;
return any_of(ranked.getShape(), [](auto d) { return d == 0; });
}

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
/// Remove operands that have zero elements.
bool changed = false;
for (size_t i = 0; i < getInput1().size(); ) {
auto input = getInput1()[i];
if (hasZeroSize(input.getType())) {
auto input = cast<RankedTensorType>(getInput1()[i].getType());
// Ensure that we have at least one operand left.
if (input.getDimSize(getAxis()) == 0 && getInput1().size() > 1) {
getInput1Mutable().erase(i);
changed = true;
} else {
Expand Down
74 changes: 42 additions & 32 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,16 @@ struct CppEmitter {
return operandExpression == emittedExpression;
};

/// Determine whether expression \p expressionOp should be emitted inline,
/// i.e. as part of its user. This function recommends inlining of any
/// expressions that can be inlined unless it is used by another expression,
/// under the assumption that any expression fusion/re-materialization was
/// taken care of by transformations run by the backend.
bool shouldBeInlined(ExpressionOp expressionOp);

/// This emitter will only emit translation units whos id matches this value.
StringRef willOnlyEmitTu() { return onlyTu; }

private:
using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
Expand Down Expand Up @@ -297,21 +307,22 @@ struct CppEmitter {
return lowestPrecedence();
return emittedExpressionPrecedence.back();
}

/// Determine whether expression \p op should be emitted in a deferred way.
bool hasDeferredEmission(Operation *op);
};
} // namespace

/// Determine whether expression \p op should be emitted in a deferred way.
static bool hasDeferredEmission(Operation *op) {
bool CppEmitter::hasDeferredEmission(Operation *op) {
if (llvm::isa_and_nonnull<emitc::ConstantOp>(op)) {
return !shouldUseConstantsAsVariables();
}

return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp,
emitc::MemberOfPtrOp, emitc::SubscriptOp>(op);
}

/// Determine whether expression \p expressionOp should be emitted inline, i.e.
/// as part of its user. This function recommends inlining of any expressions
/// that can be inlined unless it is used by another expression, under the
/// assumption that any expression fusion/re-materialization was taken care of
/// by transformations run by the backend.
static bool shouldBeInlined(ExpressionOp expressionOp) {
bool CppEmitter::shouldBeInlined(ExpressionOp expressionOp) {
// Do not inline if expression is marked as such.
if (expressionOp.getDoNotInline())
return false;
Expand Down Expand Up @@ -373,6 +384,25 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
static LogicalResult printOperation(CppEmitter &emitter,
emitc::ConstantOp constantOp) {
if (!emitter.shouldUseConstantsAsVariables()) {
std::string out;
llvm::raw_string_ostream ss(out);

/// Temporary emitter object that writes to our stream instead of the output
/// allowing for the capture and caching of the produced string.
CppEmitter sniffer = CppEmitter(ss, emitter.shouldDeclareVariablesAtTop(),
emitter.willOnlyEmitTu(),
emitter.shouldUseConstantsAsVariables());

ss << "(";
if (failed(sniffer.emitType(constantOp.getLoc(), constantOp.getType())))
return failure();
ss << ") ";

if (failed(
sniffer.emitAttribute(constantOp.getLoc(), constantOp.getValue())))
return failure();

emitter.cacheDeferredOpResult(constantOp.getResult(), out);
return success();
}

Expand Down Expand Up @@ -838,7 +868,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {

static LogicalResult printOperation(CppEmitter &emitter,
emitc::ExpressionOp expressionOp) {
if (shouldBeInlined(expressionOp))
if (emitter.shouldBeInlined(expressionOp))
return success();

Operation &op = *expressionOp.getOperation();
Expand Down Expand Up @@ -892,7 +922,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
if (!expressionOp)
return false;
return shouldBeInlined(expressionOp);
return emitter.shouldBeInlined(expressionOp);
};

os << "for (";
Expand Down Expand Up @@ -1114,7 +1144,7 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
functionOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
if (isa<emitc::ExpressionOp>(op->getParentOp()) ||
(isa<emitc::ExpressionOp>(op) &&
shouldBeInlined(cast<emitc::ExpressionOp>(op))))
emitter.shouldBeInlined(cast<emitc::ExpressionOp>(op))))
return WalkResult::skip();
for (OpResult result : op->getResults()) {
if (failed(emitter.emitVariableDeclaration(
Expand Down Expand Up @@ -1494,22 +1524,6 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {

LogicalResult CppEmitter::emitOperand(Value value) {
Operation *def = value.getDefiningOp();
if (!shouldUseConstantsAsVariables()) {
if (auto constant = dyn_cast_if_present<ConstantOp>(def)) {
os << "((";

if (failed(emitType(constant.getLoc(), constant.getType()))) {
return failure();
}
os << ") ";

if (failed(emitAttribute(constant.getLoc(), constant.getValue()))) {
return failure();
}
os << ")";
return success();
}
}

if (isPartOfCurrentExpression(value)) {
assert(def && "Expected operand to be defined by an operation");
Expand Down Expand Up @@ -1721,11 +1735,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
cacheDeferredOpResult(op.getResult(), op.getValue());
return success();
})
.Case<emitc::MemberOp>([&](auto op) {
cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
return success();
})
.Case<emitc::MemberOfPtrOp>([&](auto op) {
.Case<emitc::MemberOp, emitc::MemberOfPtrOp>([&](auto op) {
cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
return success();
})
Expand Down
4 changes: 4 additions & 0 deletions mlir/test/Dialect/EmitC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,15 @@ emitc.func private @extern(i32) attributes {specifiers = ["extern"]}

func.func @cast(%arg0: i32) {
%1 = emitc.cast %arg0: i32 to f32
%2 = emitc.cast %1: f32 to !emitc.opaque<"some type">
%3 = emitc.cast %2: !emitc.opaque<"some type"> to !emitc.size_t
return
}

func.func @cast_array(%arg : !emitc.array<4xf32>) {
%1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32> ref
%2 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.opaque<"some type">
%3 = emitc.cast %2: !emitc.opaque<"some type"> to !emitc.array<4xf32> ref
return
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: mlir-opt %s -p 'builtin.module(func.func(linalg-fuse-elementwise-ops{remove-outs-dependency=0}))' -split-input-file | FileCheck %s

#identity = affine_map<(d0) -> (d0)>

func.func @keep_outs_dependency(%arg: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: tensor.empty
%1 = linalg.generic {indexing_maps = [#identity, #identity], iterator_types = ["parallel"] } ins(%arg: tensor<4xf32>) outs(%arg: tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%exp = arith.negf %in: f32
linalg.yield %exp : f32
} -> tensor<4xf32>
return %1 : tensor<4xf32>
}
26 changes: 26 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,32 @@ func.func @concat_fold_zero(%arg0: tensor<?x0xf32>, %arg1: tensor<?x1xf32>, %arg
%0 = tosa.concat %arg0, %arg1, %arg2 {axis = 1 : i32}: (tensor<?x0xf32>, tensor<?x1xf32>, tensor<?x2xf32>) -> tensor<?x3xf32>
return %0 : tensor<?x3xf32>
}
// -----

// CHECK-LABEL: @concat_fold_zero
func.func @concat_fold_zero_all(%arg0: tensor<?x0xf32>, %arg1: tensor<?x0xf32>) -> tensor<?x0xf32> {
// CHECK: return %arg1
%0 = tosa.concat %arg0, %arg1 {axis = 1 : i32}: (tensor<?x0xf32>, tensor<?x0xf32>) -> tensor<?x0xf32>
return %0 : tensor<?x0xf32>
}

// -----

// CHECK-LABEL: @concat_fold_zero
func.func @concat_fold_zero_different_axis(%arg0: tensor<0x2xf32>, %arg1: tensor<0x4xf32>) -> tensor<0x6xf32> {
// CHECK: tosa.concat %arg0, %arg1
%0 = tosa.concat %arg0, %arg1 {axis = 1 : i32}: (tensor<0x2xf32>, tensor<0x4xf32>) -> tensor<0x6xf32>
return %0 : tensor<0x6xf32>
}

// -----

// CHECK-LABEL: @concat_fold_zero_size
func.func @concat_fold_zero_size(%arg0: tensor<?x0xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x2xf32>) -> tensor<?x3xf32> {
// CHECK: tosa.concat %arg1, %arg2 {axis = 1 : i32}
%0 = tosa.concat %arg0, %arg1, %arg2 {axis = 1 : i32}: (tensor<?x0xf32>, tensor<?x1xf32>, tensor<?x2xf32>) -> tensor<?x3xf32>
return %0 : tensor<?x3xf32>
}

// -----

Expand Down
50 changes: 48 additions & 2 deletions mlir/test/Target/Cpp/emitc-constants-as-variables.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,55 @@ func.func @test() {

return
}
// CPP-DEFAULT-LABEL: void test() {
// CPP-DEFAULT-NEXT: for (size_t v1 = (size_t) 0; v1 < (size_t) 10; v1 += (size_t) 1) {
// CPP-DEFAULT-NEXT: }
// CPP-DEFAULT-NEXT: return;
// CPP-DEFAULT-NEXT: }

// -----

func.func @test_subscript(%arg0: !emitc.array<4xf32>) -> (!emitc.lvalue<f32>) {
%c0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
%0 = emitc.subscript %arg0[%c0] : (!emitc.array<4xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
return %0 : !emitc.lvalue<f32>
}
// CPP-DEFAULT-LABEL: float test_subscript(float v1[4]) {
// CPP-DEFAULT-NEXT: return v1[(size_t) 0];
// CPP-DEFAULT-NEXT: }

// -----

// CPP-DEFAULT: void test() {
// CPP-DEFAULT-NEXT: for (size_t v1 = ((size_t) 0); v1 < ((size_t) 10); v1 += ((size_t) 1)) {
func.func @emitc_switch_ui64() {
%0 = "emitc.constant"(){value = 1 : ui64} : () -> ui64

emitc.switch %0 : ui64
default {
emitc.call_opaque "func2" (%0) : (ui64) -> ()
emitc.yield
}
return
}
// CPP-DEFAULT-LABEL: void emitc_switch_ui64() {
// CPP-DEFAULT: switch ((uint64_t) 1) {
// CPP-DEFAULT-NEXT: default: {
// CPP-DEFAULT-NEXT: func2((uint64_t) 1);
// CPP-DEFAULT-NEXT: break;
// CPP-DEFAULT-NEXT: }
// CPP-DEFAULT-NEXT: return;
// CPP-DEFAULT-NEXT: }

// -----

func.func @negative_values() {
%1 = "emitc.constant"() <{value = 10 : index}> : () -> !emitc.size_t
%2 = "emitc.constant"() <{value = -3000000000 : index}> : () -> !emitc.ssize_t

%3 = emitc.add %1, %2 : (!emitc.size_t, !emitc.ssize_t) -> !emitc.ssize_t

return
}
// CPP-DEFAULT-LABEL: void negative_values() {
// CPP-DEFAULT-NEXT: ssize_t v1 = (size_t) 10 + (ssize_t) -3000000000;
// CPP-DEFAULT-NEXT: return;
// CPP-DEFAULT-NEXT: }

0 comments on commit e4cc751

Please sign in to comment.