Skip to content

Commit

Permalink
Substitute zdnn calls for stick/unstick late, after most ZLow optimizations are performed (onnx#2812)
Browse files Browse the repository at this point in the history

Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com>
  • Loading branch information
AlexandreEichenberger authored May 9, 2024
1 parent 15e59bd commit 893cf89
Show file tree
Hide file tree
Showing 31 changed files with 1,032 additions and 963 deletions.
10 changes: 10 additions & 0 deletions src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
Expand Down Expand Up @@ -215,6 +216,10 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
addKrnlToAffinePasses(pm);
// Optimizations at ZLow that needs affine map in MemRef.
pm.addPass(zlow::createZLowRewritePass());
// Late generation of code for stick/unstick; this needs to run after a
// ZLowRewrite pass.
if (nnpaEnableCompilerStickUnstick)
pm.addPass(zlow::createZLowStickExpansionPass(enableParallel));
pm.addPass(mlir::createCanonicalizerPass());
// Normalize MemRefs.
normalizeMemRefsPasses(pm);
Expand All @@ -223,6 +228,11 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
addKrnlToAffinePasses(pm);
// Optimizations at ZLow after normalizing MemRefs.
pm.addPass(zlow::createZLowRewritePass());
// The createZLowStickExpansion pass may create parallel constructs;
// they need to be handled here.
if (nnpaEnableCompilerStickUnstick && enableParallel)
pm.addPass(mlir::createConvertSCFToOpenMPPass());

pm.addPass(mlir::createCanonicalizerPass());
// Constant folding for std.alloc.
pm.addNestedPass<func::FuncOp>(onnx_mlir::createFoldStdAllocPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -981,3 +981,11 @@ bool isSuitableForZDNN<ONNXBatchNormalizationInferenceModeOp>(

return true;
}

/// Check legality for ONNXReshapeOp.
/// A reshape is reported as suitable for the zAIU only when it is an
/// identity (no-op) reshape, because this pass removes such reshape ops
/// entirely rather than lowering them.
template <>
bool isSuitableForZDNN<ONNXReshapeOp>(
    ONNXReshapeOp op, const DimAnalysis *dimAnalysis) {
  const bool isNoopReshape = isIdentityReshape(op, dimAnalysis);
  return isNoopReshape;
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "src/Dialect/ONNX/ElementsAttr/WideNum.hpp"
#include "src/Dialect/ONNX/ONNXDimAnalysis.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
#include "src/Dialect/ONNX/OnnxElementsAttrBuilder.hpp"
#include "src/Support/TypeUtilities.hpp"
Expand Down Expand Up @@ -467,6 +468,31 @@ class AddSubWithRHSZeroExpandPattern : public OpRewritePattern<OP_TYPE> {
}
};

/// Rewrite pattern that folds away identity reshapes: when DimAnalysis
/// proves an ONNXReshapeOp produces the same shape as its input, the op is
/// a no-op and its result can be replaced by the input value directly.
class RemoveReshapeWithIdentityPattern
    : public OpRewritePattern<ONNXReshapeOp> {
public:
  using OpRewritePattern<ONNXReshapeOp>::OpRewritePattern;

  // Dimension analysis used to decide whether a reshape is an identity.
  DimAnalysis *dimAnalysis;

  RemoveReshapeWithIdentityPattern(
      MLIRContext *context, DimAnalysis *dimAnalysis)
      : OpRewritePattern<ONNXReshapeOp>(context, 1001),
        dimAnalysis(dimAnalysis) {}

  LogicalResult matchAndRewrite(
      ONNXReshapeOp reshapeOp, PatternRewriter &rewriter) const override {
    // Only identity reshapes qualify for removal.
    if (isIdentityReshape(reshapeOp, dimAnalysis)) {
      // Forward the reshape's input data to all users of its result.
      rewriter.replaceOp(reshapeOp.getOperation(), reshapeOp.getData());
      return success();
    }
    return failure();
  }
};

//===----------------------------------------------------------------------===//
// Rewrite ONNX ops to ZHigh ops and ONNX ops for ZHigh.
//===----------------------------------------------------------------------===//
Expand All @@ -482,6 +508,8 @@ void getRewriteONNXForZHighPatterns(
patterns.getContext(), dimAnalysis);
patterns.insert<AddSubWithRHSZeroExpandPattern<ONNXSubOp>>(
patterns.getContext(), dimAnalysis);
patterns.insert<RemoveReshapeWithIdentityPattern>(
patterns.getContext(), dimAnalysis);
}

void getRewriteONNXForZHighDynamicallyLegal(
Expand Down Expand Up @@ -643,6 +671,13 @@ void getRewriteONNXForZHighDynamicallyLegal(
return isSuitableForZDNN<ONNXConvOp>(op) ||
!canInferencePadsForNNPAConv(op);
});
addDynamicallyLegalOpFor<ONNXReshapeOp>(target, dimAnalysis,
[](ONNXReshapeOp op, const DimAnalysis *dimAnalysis) {
// Get rid of identity reshape here, as it impacts stick/unstick.
// So all reshape are legal, unless it is an identity reshape, in which
// case there is a rule here to remove it.
return !isIdentityReshape(op, dimAnalysis);
});
}

struct RewriteONNXForZHighPass
Expand Down
Loading

0 comments on commit 893cf89

Please sign in to comment.