Skip to content

Commit

Permalink
[MLIR][TOSA] Fix f16/bf16 support for MaxPool2D (llvm#69332)
Browse files Browse the repository at this point in the history
Currently, the MaxPool2D operation in the TOSA MLIR dialect does not
accept half-precision Fp16 and Bf16 tensors, converse to what is stated
in the [TOSA
Specification](https://www.mlplatform.org/tosa/tosa_spec.html#_max_pool2d).

This patch fixes the verifier to accept the two datatypes for
input/output tensors, and adds related LIT test cases in Tosa/ops.mlir
  • Loading branch information
dchauhan-arm authored and ljfitz committed Feb 22, 2024
1 parent 3a5f724 commit a498541
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {

// Determine what the initial value needs to be for the max pool op.
TypedAttr initialAttr;
if (resultETy.isF32())
if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
initialAttr = rewriter.getFloatAttr(
resultETy, APFloat::getLargest(
cast<FloatType>(resultETy).getFloatSemantics(), true));
Expand Down
18 changes: 16 additions & 2 deletions mlir/test/Dialect/Tosa/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,26 @@ func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -
}

// -----
// CHECK-LABEL: max_pool2d
func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
// CHECK-LABEL: max_pool2d_f32
func.func @test_max_pool2d_f32(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
%0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
return %0 : tensor<1x32x32x8xf32>
}

// -----
// CHECK-LABEL: max_pool2d_bf16
func.func @test_max_pool2d_bf16(%arg0: tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16> {
%0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16>
return %0 : tensor<1x32x32x8xbf16>
}

// -----
// CHECK-LABEL: max_pool2d_f16
func.func @test_max_pool2d_f16(%arg0: tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16> {
%0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16>
return %0 : tensor<1x32x32x8xf16>
}

// -----
// CHECK-LABEL: rfft2d
func.func @test_rfft2d(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
Expand Down

0 comments on commit a498541

Please sign in to comment.