Skip to content

Commit

Permalink
[BLAS][portBLAS] Add try/catch for portblas runtime exception & minor…
Browse files Browse the repository at this point in the history
… fix (#525)

Signed-off-by: nscipione <nicolo.scipione@codeplay.com>

* Catch PortBLAS's unsupported exceptions and rethrow as mkl::unimplemented.
* Add missing checks for the device having double support in tests.
  • Loading branch information
s-Nick authored Jul 17, 2024
1 parent b4955ca commit 60bd6d1
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/blas/backends/portblas/portblas_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,12 @@ struct throw_if_unsupported_by_device {
auto fn = [](auto&&... targs) { \
portBLASFunc(std::forward<decltype(targs)>(targs)...); \
}; \
std::apply(fn, args); \
try { \
std::apply(fn, args); \
} \
catch (const ::blas::unsupported_exception& e) { \
throw unimplemented("blas", e.what()); \
} \
} \
else { \
throw unimplemented("blas", "portBLAS function"); \
Expand All @@ -215,7 +220,12 @@ struct throw_if_unsupported_by_device {
auto fn = [](auto&&... targs) { \
return portblasFunc(std::forward<decltype(targs)>(targs)...).back(); \
}; \
return std::apply(fn, args); \
try { \
return std::apply(fn, args); \
} \
catch (const ::blas::unsupported_exception& e) { \
throw unimplemented("blas", e.what()); \
} \
} \
else { \
throw unimplemented("blas", "portBLAS function"); \
Expand Down
4 changes: 4 additions & 0 deletions tests/unit_tests/blas/extensions/omatcopy2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ TEST_P(Omatcopy2Tests, RealSinglePrecision) {
}

TEST_P(Omatcopy2Tests, RealDoublePrecision) {
CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam()));

EXPECT_TRUEORSKIP(test<double>(std::get<0>(GetParam()), std::get<1>(GetParam())));
}

Expand All @@ -185,6 +187,8 @@ TEST_P(Omatcopy2Tests, ComplexSinglePrecision) {
}

TEST_P(Omatcopy2Tests, ComplexDoublePrecision) {
CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam()));

EXPECT_TRUEORSKIP(test<std::complex<double>>(std::get<0>(GetParam()), std::get<1>(GetParam())));
}

Expand Down
4 changes: 4 additions & 0 deletions tests/unit_tests/blas/extensions/omatcopy2_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ TEST_P(Omatcopy2UsmTests, RealSinglePrecision) {
}

TEST_P(Omatcopy2UsmTests, RealDoublePrecision) {
CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam()));

EXPECT_TRUEORSKIP(test<double>(std::get<0>(GetParam()), std::get<1>(GetParam())));
}

Expand All @@ -194,6 +196,8 @@ TEST_P(Omatcopy2UsmTests, ComplexSinglePrecision) {
}

TEST_P(Omatcopy2UsmTests, ComplexDoublePrecision) {
CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam()));

EXPECT_TRUEORSKIP(test<std::complex<double>>(std::get<0>(GetParam()), std::get<1>(GetParam())));
}

Expand Down

0 comments on commit 60bd6d1

Please sign in to comment.