From f58fb8c209a5179f8f2e02e2a0816c9b1f1edb1b Mon Sep 17 00:00:00 2001
From: Spenser Bauman
Date: Fri, 1 Dec 2023 10:33:14 -0500
Subject: [PATCH] [mlir][tosa] Fix lowering of tosa.conv2d (#73240)

The lowering of tosa.conv2d produces an illegal tensor.empty operation
where the number of inputs does not match the number of dynamic
dimensions in the output type.

The fix is to base the generation of tensor.dim operations off the
result type of the conv2d operation, rather than the input type. The
problem and fix are very similar to
https://github.com/llvm/llvm-project/pull/72724, but for convolution.
---
 .../TosaToLinalg/TosaToLinalgNamed.cpp        |  4 ++--
 .../TosaToLinalg/tosa-to-linalg-named.mlir    | 23 +++++++++++++++++++
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index b30651976eeb93..0accd9d1986a1e 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -179,7 +179,7 @@ static SmallVector<Value> inferDynamicDimsForConv(
   for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
     int64_t inputDim = inputSizeDims[i];
     int64_t kernelDim = kernelSizeDims[i];
-    if (inputTy.isDynamicDim(inputDim)) {
+    if (resultTy.isDynamicDim(inputDim)) {
       auto padTop = padAttr[i * 2];
       auto padBottom = padAttr[i * 2 + 1];
       auto stride = strideAttr[i];
@@ -196,7 +196,7 @@ static SmallVector<Value> inferDynamicDimsForConv(
 
   // Get the batch/channels dimensions.
   for (int i = 0; i < inputRank; i++) {
-    if (inputTy.isDynamicDim(i) && !dynDims[i])
+    if (resultTy.isDynamicDim(i) && !dynDims[i])
       dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
   }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index bbdd1bad799865..230001f7633b57 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -495,6 +495,29 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
 
 // -----
 
+// CHECK: [[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: [[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @conv2d_dyn_output(%input: tensor<2x6x5x4xf32>, %weights: tensor<4x3x3x4xf32>, %bias: tensor<4xf32>) {
+  // %[[C0:.+]] = arith.constant 0 : index
+  // %[[DIM0:.+]] = tensor.dim %input, %[[C0]] : tensor<2x6x5x4xf32>
+  // %[[INIT_CONV:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x4x3x4xf32>
+  // %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+  // %[[FILL:.+]] = linalg.fill
+  // %[[INIT_GENERIC:.+]] = tensor.empty([[DIM0]]) : tensor<?x4x3x4xf32>
+
+  // %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x4xf32>, tensor<4x3x3x4xf32>) outs(%[[INIT_CONV]] : tensor<?x4x3x4xf32>) -> tensor<?x4x3x4xf32>
+  // linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<4xf32>, tensor<?x4x3x4xf32>) outs(%[[INIT_GENERIC]] : tensor<?x4x3x4xf32>) {
+  // %[[ADD:.+]] = arith.addf
+  // linalg.yield %[[ADD]] : f32
+  // } -> tensor<?x4x3x4xf32>
+
+  %0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x5x4xf32>, tensor<4x3x3x4xf32>, tensor<4xf32>) -> tensor<?x4x3x4xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @conv2d_padded_f32
 func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK: %[[C0:.+]] = arith.constant 0
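
Failure mode, sketched (illustrative IR only, not part of the patch or the
upstream test suite): with a fully static input and a dynamic result type,
the old code walked the *input* type, found no dynamic dimensions, and so
passed no dynamic sizes to tensor.empty, even though the result type
required one. Assuming the shapes from the new test case, the lowering
before and after the fix looks roughly like:

  // Before (invalid): the result type has one dynamic dimension, but no
  // dynamic size operands were collected, because the input type
  // tensor<2x6x5x4xf32> is fully static.
  %init = tensor.empty() : tensor<?x4x3x4xf32>

  // After: dynamic sizes are driven by the result type, and the dynamic
  // batch dimension is read off the input with tensor.dim.
  %c0 = arith.constant 0 : index
  %dim0 = tensor.dim %input, %c0 : tensor<2x6x5x4xf32>
  %init = tensor.empty(%dim0) : tensor<?x4x3x4xf32>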