From 60bd6d19b284bb4f891a3322d50c7e0b10f69c2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Scipione?= Date: Wed, 17 Jul 2024 03:10:43 -0700 Subject: [PATCH] [BLAS][portBLAS] Add try/catch for portblas runtime exception & minor fix (#525) Signed-off-by: nscipione * Catch PortBLAS's unsupported exceptions and rethrow as mkl::unimplemented. * Add missing checks for the device having double support in tests. --- src/blas/backends/portblas/portblas_common.hpp | 14 ++++++++++++-- tests/unit_tests/blas/extensions/omatcopy2.cpp | 4 ++++ tests/unit_tests/blas/extensions/omatcopy2_usm.cpp | 4 ++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/blas/backends/portblas/portblas_common.hpp b/src/blas/backends/portblas/portblas_common.hpp index 52ec2c6ea..1624749e8 100644 --- a/src/blas/backends/portblas/portblas_common.hpp +++ b/src/blas/backends/portblas/portblas_common.hpp @@ -199,7 +199,12 @@ struct throw_if_unsupported_by_device { auto fn = [](auto&&... targs) { \ portBLASFunc(std::forward(targs)...); \ }; \ - std::apply(fn, args); \ + try { \ + std::apply(fn, args); \ + } \ + catch (const ::blas::unsupported_exception& e) { \ + throw unimplemented("blas", e.what()); \ + } \ } \ else { \ throw unimplemented("blas", "portBLAS function"); \ @@ -215,7 +220,12 @@ struct throw_if_unsupported_by_device { auto fn = [](auto&&... targs) { \ return portblasFunc(std::forward(targs)...).back(); \ }; \ - return std::apply(fn, args); \ + try { \ + return std::apply(fn, args); \ + } \ + catch (const ::blas::unsupported_exception& e) { \ + throw unimplemented("blas", e.what()); \ + } \ } \ else { \ throw unimplemented("blas", "portBLAS function"); \ diff --git a/tests/unit_tests/blas/extensions/omatcopy2.cpp b/tests/unit_tests/blas/extensions/omatcopy2.cpp index 50dcd0f6b..d0407c324 100644 --- a/tests/unit_tests/blas/extensions/omatcopy2.cpp +++ b/tests/unit_tests/blas/extensions/omatcopy2.cpp @@ -177,6 +177,8 @@ TEST_P(Omatcopy2Tests, RealSinglePrecision) { } TEST_P(Omatcopy2Tests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()))); } @@ -185,6 +187,8 @@ TEST_P(Omatcopy2Tests, ComplexSinglePrecision) { } TEST_P(Omatcopy2Tests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()))); } diff --git a/tests/unit_tests/blas/extensions/omatcopy2_usm.cpp b/tests/unit_tests/blas/extensions/omatcopy2_usm.cpp index d12331c8e..d2103d243 100644 --- a/tests/unit_tests/blas/extensions/omatcopy2_usm.cpp +++ b/tests/unit_tests/blas/extensions/omatcopy2_usm.cpp @@ -186,6 +186,8 @@ TEST_P(Omatcopy2UsmTests, RealSinglePrecision) { } TEST_P(Omatcopy2UsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()))); } @@ -194,6 +196,8 @@ TEST_P(Omatcopy2UsmTests, ComplexSinglePrecision) { } TEST_P(Omatcopy2UsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()))); }