[mlir][tosa] Fix lowering of tosa.conv2d (llvm#73240)
The lowering of tosa.conv2d produces an illegal tensor.empty operation
whose number of size operands does not match the number of dynamic
dimensions in the output type.

The fix is to base the generation of tensor.dim operations on the
result type of the conv2d operation rather than on the input type. The
problem and fix are very similar to llvm#72724, but for convolution.
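
For illustration, a minimal sketch of the failure mode (types are taken from the test added below; value names such as %c0, %dim0, and %init are illustrative, not the exact names the lowering emits):

%0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x5x4xf32>, tensor<4x3x3x4xf32>, tensor<4xf32>) -> tensor<?x4x3x4xf32>

// Before the fix: the lowering consulted the fully static input type, so it
// emitted tensor.empty with zero size operands even though the result type
// has one dynamic dimension -- an invalid tensor.empty:
//   %init = tensor.empty() : tensor<?x4x3x4xf32>

// After the fix: one tensor.dim is generated per dynamic dimension of the
// result type and fed to tensor.empty:
%c0 = arith.constant 0 : index
%dim0 = tensor.dim %input, %c0 : tensor<2x6x5x4xf32>
%init = tensor.empty(%dim0) : tensor<?x4x3x4xf32>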
sabauma committed Dec 1, 2023
1 parent 0d87e25 commit f58fb8c
Showing 2 changed files with 25 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -179,7 +179,7 @@ static SmallVector<Value> inferDynamicDimsForConv(
  for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
    int64_t inputDim = inputSizeDims[i];
    int64_t kernelDim = kernelSizeDims[i];
-   if (inputTy.isDynamicDim(inputDim)) {
+   if (resultTy.isDynamicDim(inputDim)) {
      auto padTop = padAttr[i * 2];
      auto padBottom = padAttr[i * 2 + 1];
      auto stride = strideAttr[i];
@@ -196,7 +196,7 @@ static SmallVector<Value> inferDynamicDimsForConv(

  // Get the batch/channels dimensions.
  for (int i = 0; i < inputRank; i++) {
-   if (inputTy.isDynamicDim(i) && !dynDims[i])
+   if (resultTy.isDynamicDim(i) && !dynDims[i])
      dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
  }

23 changes: 23 additions & 0 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -495,6 +495,29 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x

// -----

+// CHECK: [[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: [[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @conv2d_dyn_output(%input: tensor<2x6x5x4xf32>, %weights: tensor<4x3x3x4xf32>, %bias: tensor<4xf32>) {
+  // %[[C0:.+]] = arith.constant 0 : index
+  // %[[DIM0:.+]] = tensor.dim %input, %[[C0]] : tensor<2x6x5x4xf32>
+  // %[[INIT_CONV:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x4x3x4xf32>
+  // %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+  // %[[FILL:.+]] = linalg.fill
+  // %[[INIT_GENERIC:.+]] = tensor.empty([[DIM0]]) : tensor<?x4x3x4xf32>
+
+  // %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x4xf32>, tensor<4x3x3x4xf32>) outs(%[[INIT_CONV]] : tensor<?x4x3x4xf32>) -> tensor<?x4x3x4xf32>
+  // linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<4xf32>, tensor<?x4x3x4xf32>) outs(%[[INIT_GENERIC]] : tensor<?x4x3x4xf32>) {
+  //   %[[ADD:.+]] = arith.addf
+  //   linalg.yield %[[ADD]] : f32
+  // } -> tensor<?x4x3x4xf32>
+
+  %0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x5x4xf32>, tensor<4x3x3x4xf32>, tensor<4xf32>) -> tensor<?x4x3x4xf32>
+  return
+}
+
+// -----
+
// CHECK-LABEL: @conv2d_padded_f32
func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
// CHECK: %[[C0:.+]] = arith.constant 0
