diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 783698952b36f8..633b599e1c0769 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -768,7 +768,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
 
     // Determine what the initial value needs to be for the max pool op.
     TypedAttr initialAttr;
-    if (resultETy.isF32())
+    if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
       initialAttr = rewriter.getFloatAttr(
           resultETy, APFloat::getLargest(
                          cast<FloatType>(resultETy).getFloatSemantics(), true));
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 7d7f2d31a4244c..80bd1640b6e10e 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -83,12 +83,26 @@ func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -
 }
 
 // -----
-// CHECK-LABEL: max_pool2d
-func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+// CHECK-LABEL: max_pool2d_f32
+func.func @test_max_pool2d_f32(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
   %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
   return %0 : tensor<1x32x32x8xf32>
 }
 
+// -----
+// CHECK-LABEL: max_pool2d_bf16
+func.func @test_max_pool2d_bf16(%arg0: tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16> {
+  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16>
+  return %0 : tensor<1x32x32x8xbf16>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_f16
+func.func @test_max_pool2d_f16(%arg0: tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16> {
+  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16>
+  return %0 : tensor<1x32x32x8xf16>
+}
+
 // -----
 // CHECK-LABEL: rfft2d
 func.func @test_rfft2d(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {