From 6c030d2167092df6bd85d2f389741891b052c5e6 Mon Sep 17 00:00:00 2001 From: Justin Cai Date: Thu, 11 Jul 2024 05:28:59 -0700 Subject: [PATCH] [SYCL] Fix edge cases in ctanh and cexp (#14329) Due to formatting, the test case diff is a bit hard to see. Here is an overview of the new tests cases added and their old behavior: `ctanh(-inf + nan*i) == -1 + 0*i` (previously was `1 + 0*i`) `ctanh(-inf + nan*i) == -1 + 0*i` (previously was `1 + 0*i`) `ctanh(-inf + nan*i) == -1 + 0*i` (previously was `1 + 0*i`) `cexp(1e6 + 0*i) == inf + 0*i` (previously was `inf + nan*i`) `cexp(1e6 + 0.1*i) == inf + inf*i` (old behavior, just adding more coverage) `cexp(1e6 + -0.*i) == inf + -inf*i` (old behavior, just adding more coverage) --- libdevice/fallback-complex-fp64.cpp | 15 +- libdevice/fallback-complex.cpp | 14 +- .../DeviceLib/std_complex_math_fp64_test.cpp | 160 ++++++++++-------- .../DeviceLib/std_complex_math_test.cpp | 111 +++++++----- 4 files changed, 185 insertions(+), 115 deletions(-) diff --git a/libdevice/fallback-complex-fp64.cpp b/libdevice/fallback-complex-fp64.cpp index 5c4bc67924998..273ece4358067 100644 --- a/libdevice/fallback-complex-fp64.cpp +++ b/libdevice/fallback-complex-fp64.cpp @@ -153,8 +153,14 @@ double __complex__ __devicelib_cexp(double __complex__ z) { return z; } double __e = __spirv_ocl_exp(z_real); - return CMPLX((__e * __spirv_ocl_cos(z_imag)), - (__e * __spirv_ocl_sin(z_imag))); + double ret_real = __e * __spirv_ocl_cos(z_imag); + double ret_imag = __e * __spirv_ocl_sin(z_imag); + + if (__spirv_IsNan(ret_real)) + ret_real = 0.; + if (__spirv_IsNan(ret_imag)) + ret_imag = 0.; + return CMPLX(ret_real, ret_imag); } DEVICE_EXTERN_C_INLINE @@ -249,8 +255,9 @@ double __complex__ __devicelib_ctanh(double __complex__ z) { double z_imag = __devicelib_cimag(z); if (__spirv_IsInf(z_real)) { if (!__spirv_IsFinite(z_imag)) - return CMPLX(1.0, 0.0); - return CMPLX(1.0, __spirv_ocl_copysign(0.0, __spirv_ocl_sin(2.0 * z_imag))); + return CMPLX(__spirv_ocl_copysign(1.0, z_real), 0.0); + return CMPLX(__spirv_ocl_copysign(1.0, z_real), + __spirv_ocl_copysign(0.0, __spirv_ocl_sin(2.0 * z_imag))); } if (__spirv_IsNan(z_real) && z_imag == 0) return z; diff --git a/libdevice/fallback-complex.cpp b/libdevice/fallback-complex.cpp index daa8c234fbc88..e3f58b9eeb019 100644 --- a/libdevice/fallback-complex.cpp +++ b/libdevice/fallback-complex.cpp @@ -154,8 +154,14 @@ float __complex__ __devicelib_cexpf(float __complex__ z) { return z; } float __e = __spirv_ocl_exp(z_real); - return CMPLXF((__e * __spirv_ocl_cos(z_imag)), - (__e * __spirv_ocl_sin(z_imag))); + float ret_real = __e * __spirv_ocl_cos(z_imag); + float ret_imag = __e * __spirv_ocl_sin(z_imag); + + if (__spirv_IsNan(ret_real)) + ret_real = 0.f; + if (__spirv_IsNan(ret_imag)) + ret_imag = 0.f; + return CMPLXF(ret_real, ret_imag); } DEVICE_EXTERN_C_INLINE @@ -249,8 +255,8 @@ float __complex__ __devicelib_ctanhf(float __complex__ z) { float z_imag = __devicelib_cimagf(z); if (__spirv_IsInf(z_real)) { if (!__spirv_IsFinite(z_imag)) - return CMPLXF(1.0f, 0.0f); - return CMPLXF(1.0f, + return CMPLXF(__spirv_ocl_copysign(1.0f, z_real), 0.0f); + return CMPLXF(__spirv_ocl_copysign(1.0f, z_real), __spirv_ocl_copysign(0.0f, __spirv_ocl_sin(2.0f * z_imag))); } if (__spirv_IsNan(z_real) && z_imag == 0) diff --git a/sycl/test-e2e/DeviceLib/std_complex_math_fp64_test.cpp b/sycl/test-e2e/DeviceLib/std_complex_math_fp64_test.cpp index 9b903f23a6592..ca3d093416555 100644 --- a/sycl/test-e2e/DeviceLib/std_complex_math_fp64_test.cpp +++ b/sycl/test-e2e/DeviceLib/std_complex_math_fp64_test.cpp @@ -24,73 +24,78 @@ template bool approx_equal_cmplx(complex x, complex y) { approx_equal_fp(x.imag(), y.imag()); } -static constexpr auto TestArraySize1 = 57; -static constexpr auto TestArraySize2 = 10; +complex ref1_results[] = {complex(-1., 1.), + complex(1., 3.), + complex(-2., 10.), + complex(-8., 31.), + complex(1., 1.), + complex(2., 1.), + complex(2., 2.), + complex(3., 4.), + complex(2., 1.), + complex(0., 1.), + complex(2., 0.), + complex(0., 0.), + complex(0., 1.), + complex(1., 1.), + complex(2., 0.), + complex(2., 3.), + complex(1., 0.), + complex(0., 1.), + complex(-1., 0.), + complex(0., M_E), + complex(0., 0.), + complex(0., M_PI_2), + complex(0., M_PI), + complex(1., M_PI_2), + complex(0., 0.), + complex(1., 0.), + complex(1., 0.), + complex(-1., 0.), + complex(-INFINITY, 0.), + complex(1., 0.), + complex(10., 0.), + complex(100., 0.), + complex(200., 0.), + complex(1., 2.), + complex(INFINITY, 0.), + complex(INFINITY, 0.), + complex(0., 1.), + complex(M_PI_2, 0.), + complex(0., 0.), + complex(1., 0.), + complex(INFINITY, 0.), + complex(0., 0.), + complex(1., 0.), + complex(0., 0.), + complex(INFINITY, M_PI_2), + complex(INFINITY, 0.), + complex(0., M_PI_2), + complex(INFINITY, M_PI_2), + complex(INFINITY, 0.), + complex(0., 0.), + complex(0., M_PI_2), -std::array, TestArraySize1> ref1_results = { - complex(-1., 1.), - complex(1., 3.), - complex(-2., 10.), - complex(-8., 31.), - complex(1., 1.), - complex(2., 1.), - complex(2., 2.), - complex(3., 4.), - complex(2., 1.), - complex(0., 1.), - complex(2., 0.), - complex(0., 0.), - complex(0., 1.), - complex(1., 1.), - complex(2., 0.), - complex(2., 3.), - complex(1., 0.), - complex(0., 1.), - complex(-1., 0.), - complex(0., M_E), - complex(0., 0.), - complex(0., M_PI_2), - complex(0., M_PI), - complex(1., M_PI_2), - complex(0., 0.), - complex(1., 0.), - complex(1., 0.), - complex(-1., 0.), - complex(-INFINITY, 0.), - complex(1., 0.), - complex(10., 0.), - complex(100., 0.), - complex(200., 0.), - complex(1., 2.), - complex(INFINITY, 0.), - complex(INFINITY, 0.), - complex(0., 1.), - complex(M_PI_2, 0.), - complex(0., 0.), - complex(1., 0.), - complex(INFINITY, 0.), - complex(0., 0.), - complex(1., 0.), - complex(0., 0.), - complex(INFINITY, M_PI_2), - complex(INFINITY, 0.), - complex(0., M_PI_2), - complex(INFINITY, M_PI_2), - complex(INFINITY, 0.), - complex(0., 0.), - complex(0., M_PI_2), + complex(1., -4.), + complex(18., -7.), + complex(1.557407724654902, 0.), + complex(0, 0.761594155955765), + complex(M_PI_2, 0.), + complex(M_PI_2, 0.549306144334055), + complex(-1., 0.), + complex(-1., 0.), + complex(-1., 0.), + complex(INFINITY, 0.), + complex(INFINITY, INFINITY), + complex(INFINITY, -INFINITY)}; - complex(1., -4.), - complex(18., -7.), - complex(1.557407724654902, 0.), - complex(0, 0.761594155955765), - complex(M_PI_2, 0.), - complex(M_PI_2, 0.549306144334055)}; +double ref2_results[] = {0., 25., 169., INFINITY, 0., + 5., 13., INFINITY, 0., M_PI_2}; -std::array ref2_results = { - 0., 25., 169., INFINITY, 0., 5., 13., INFINITY, 0., M_PI_2}; +static constexpr auto TestArraySize1 = std::size(ref1_results); +static constexpr auto TestArraySize2 = 10; -void device_complex_test(s::queue &deviceQueue) { +int device_complex_test(s::queue &deviceQueue) { s::range<1> numOfItems1{TestArraySize1}; s::range<1> numOfItems2{TestArraySize2}; std::array, TestArraySize1> result1; @@ -172,6 +177,13 @@ void device_complex_test(s::queue &deviceQueue) { buf_out1_access[index++] = std::tan(complex(0., 1.)); buf_out1_access[index++] = std::asin(complex(1., 0.)); buf_out1_access[index++] = std::atan(complex(0., 2.)); + buf_out1_access[index++] = std::tanh(complex(-INFINITY, NAN)); + buf_out1_access[index++] = + std::tanh(complex(-INFINITY, -INFINITY)); + buf_out1_access[index++] = std::tanh(complex(-INFINITY, -2.)); + buf_out1_access[index++] = std::exp(complex(1e6, 0.)); + buf_out1_access[index++] = std::exp(complex(1e6, 0.1)); + buf_out1_access[index++] = std::exp(complex(1e6, -0.1)); index = 0; buf_out2_access[index++] = std::norm(complex(0., 0.)); @@ -188,16 +200,30 @@ void device_complex_test(s::queue &deviceQueue) { }); } + int n_fails = 0; for (size_t idx = 0; idx < TestArraySize1; ++idx) { - assert(approx_equal_cmplx(result1[idx], ref1_results[idx])); + if (!approx_equal_cmplx(result1[idx], ref1_results[idx])) { + ++n_fails; + std::cout << "test array 1 fail at index " << idx << "\n"; + std::cout << "expected: " << ref1_results[idx] << "\n"; + std::cout << "actual: " << result1[idx] << "\n"; + } } for (size_t idx = 0; idx < TestArraySize2; ++idx) { - assert(approx_equal_fp(result2[idx], ref2_results[idx])); + if (!approx_equal_fp(result2[idx], ref2_results[idx])) { + ++n_fails; + std::cout << "test array 2 fail at index " << idx << "\n"; + std::cout << "expected: " << ref2_results[idx] << "\n"; + std::cout << "actual: " << result2[idx] << "\n"; + } } + return n_fails; } int main() { s::queue deviceQueue; - device_complex_test(deviceQueue); - std::cout << "Pass" << std::endl; + auto n_fails = device_complex_test(deviceQueue); + if (n_fails == 0) + std::cout << "Pass" << std::endl; + return n_fails; } diff --git a/sycl/test-e2e/DeviceLib/std_complex_math_test.cpp b/sycl/test-e2e/DeviceLib/std_complex_math_test.cpp index d85171c4cf64f..3a75f78243a23 100644 --- a/sycl/test-e2e/DeviceLib/std_complex_math_test.cpp +++ b/sycl/test-e2e/DeviceLib/std_complex_math_test.cpp @@ -24,37 +24,34 @@ template bool approx_equal_cmplx(complex x, complex y) { approx_equal_fp(x.imag(), y.imag()); } -static constexpr auto TestArraySize1 = 41; -static constexpr auto TestArraySize2 = 10; -static constexpr auto TestArraySize3 = 16; - -std::array, TestArraySize1> ref1_results = { - complex(-1.f, 1.f), complex(1.f, 3.f), - complex(-2.f, 10.f), complex(-8.f, 31.f), - complex(1.f, 1.f), complex(2.f, 1.f), - complex(2.f, 2.f), complex(3.f, 4.f), - complex(2.f, 1.f), complex(0.f, 1.f), - complex(2.f, 0.f), complex(0.f, 0.f), - complex(1.f, 0.f), complex(0.f, 1.f), - complex(-1.f, 0.f), complex(0.f, M_E), - complex(0.f, 0.f), complex(0.f, M_PI_2), - complex(0.f, M_PI), complex(1.f, M_PI_2), - complex(0.f, 0.f), complex(1.f, 0.f), - complex(1.f, 0.f), complex(-1.f, 0.f), - complex(-INFINITY, 0.f), complex(1.f, 0.f), - complex(10.f, 0.f), complex(100.f, 0.f), - complex(200.f, 0.f), complex(1.f, 2.f), - complex(INFINITY, 0.f), complex(INFINITY, 0.f), - complex(0.f, 1.f), complex(0.f, 0.f), - complex(1.f, 0.f), complex(INFINITY, 0.f), - complex(0.f, 0.f), complex(0.f, M_PI_2), - complex(1.f, -4.f), complex(18.f, -7.f), - complex(M_PI_2, 0.549306f)}; - -std::array ref2_results = { - 0.f, 25.f, 169.f, INFINITY, 0.f, 5.f, 13.f, INFINITY, 0.f, M_PI_2}; - -std::array, TestArraySize3> ref3_results = { +complex ref1_results[] = { + complex(-1.f, 1.f), complex(1.f, 3.f), + complex(-2.f, 10.f), complex(-8.f, 31.f), + complex(1.f, 1.f), complex(2.f, 1.f), + complex(2.f, 2.f), complex(3.f, 4.f), + complex(2.f, 1.f), complex(0.f, 1.f), + complex(2.f, 0.f), complex(0.f, 0.f), + complex(1.f, 0.f), complex(0.f, 1.f), + complex(-1.f, 0.f), complex(0.f, M_E), + complex(0.f, 0.f), complex(0.f, M_PI_2), + complex(0.f, M_PI), complex(1.f, M_PI_2), + complex(0.f, 0.f), complex(1.f, 0.f), + complex(1.f, 0.f), complex(-1.f, 0.f), + complex(-INFINITY, 0.f), complex(1.f, 0.f), + complex(10.f, 0.f), complex(100.f, 0.f), + complex(200.f, 0.f), complex(1.f, 2.f), + complex(INFINITY, 0.f), complex(INFINITY, 0.f), + complex(0.f, 1.f), complex(0.f, 0.f), + complex(1.f, 0.f), complex(INFINITY, 0.f), + complex(0.f, 0.f), complex(0.f, M_PI_2), + complex(1.f, -4.f), complex(18.f, -7.f), + complex(M_PI_2, 0.549306f), complex(INFINITY, 0.f), + complex(INFINITY, INFINITY), complex(INFINITY, -INFINITY)}; + +float ref2_results[] = {0.f, 25.f, 169.f, INFINITY, 0.f, + 5.f, 13.f, INFINITY, 0.f, M_PI_2}; + +complex ref3_results[] = { complex(0.f, 1.f), complex(1.f, 1.f), complex(2.f, 0.f), complex(2.f, 3.f), complex(M_PI_2, 0.f), complex(0.f, 0.f), @@ -63,9 +60,14 @@ std::array, TestArraySize3> ref3_results = { complex(0.f, M_PI_2), complex(INFINITY, M_PI_2), complex(INFINITY, 0.f), complex(1.557408f, 0.f), complex(0.f, 0.761594f), complex(M_PI_2, 0.f), + complex(-1.f, 0.f), complex(-1.f, 0.f), + complex(-1.f, 0.f)}; -}; -void device_complex_test_1(s::queue &deviceQueue) { +static constexpr auto TestArraySize1 = std::size(ref1_results); +static constexpr auto TestArraySize2 = std::size(ref2_results); +static constexpr auto TestArraySize3 = std::size(ref3_results); + +int device_complex_test_1(s::queue &deviceQueue) { s::range<1> numOfItems1{TestArraySize1}; s::range<1> numOfItems2{TestArraySize2}; std::array, TestArraySize1> result1; @@ -131,6 +133,9 @@ void device_complex_test_1(s::queue &deviceQueue) { buf_out1_access[index++] = std::conj(complex(1.f, 4.f)); buf_out1_access[index++] = std::conj(complex(18.f, 7.f)); buf_out1_access[index++] = std::atan(complex(0.f, 2.f)); + buf_out1_access[index++] = std::exp(complex(1e6f, 0.f)); + buf_out1_access[index++] = std::exp(complex(1e6f, 0.1f)); + buf_out1_access[index++] = std::exp(complex(1e6f, -0.1f)); index = 0; buf_out2_access[index++] = std::norm(complex(0.f, 0.f)); @@ -147,12 +152,24 @@ void device_complex_test_1(s::queue &deviceQueue) { }); } + int n_fails = 0; for (size_t idx = 0; idx < TestArraySize1; ++idx) { - assert(approx_equal_cmplx(result1[idx], ref1_results[idx])); + if (!approx_equal_cmplx(result1[idx], ref1_results[idx])) { + ++n_fails; + std::cout << "test array 1 fail at index " << idx << "\n"; + std::cout << "expected: " << ref1_results[idx] << "\n"; + std::cout << "actual: " << result1[idx] << "\n"; + } } for (size_t idx = 0; idx < TestArraySize2; ++idx) { - assert(approx_equal_fp(result2[idx], ref2_results[idx])); + if (!approx_equal_fp(result2[idx], ref2_results[idx])) { + ++n_fails; + std::cout << "test array 2 fail at index " << idx << "\n"; + std::cout << "expected: " << ref2_results[idx] << "\n"; + std::cout << "actual: " << result2[idx] << "\n"; + } } + return n_fails; } // The MSVC implementation of some complex math functions depends on @@ -160,7 +177,7 @@ void device_complex_test_1(s::queue &deviceQueue) { // functions can only work on Windows with fp64 extension support from // underlying device. #ifndef _WIN32 -void device_complex_test_2(s::queue &deviceQueue) { +int device_complex_test_2(s::queue &deviceQueue) { s::range<1> numOfItems1{TestArraySize3}; std::array, TestArraySize3> result3; { @@ -185,13 +202,24 @@ void device_complex_test_2(s::queue &deviceQueue) { buf_out1_access[index++] = std::tan(complex(1.f, 0.f)); buf_out1_access[index++] = std::tan(complex(0.f, 1.f)); buf_out1_access[index++] = std::asin(complex(1.f, 0.f)); + buf_out1_access[index++] = std::tanh(complex(-INFINITY, NAN)); + buf_out1_access[index++] = + std::tanh(complex(-INFINITY, -INFINITY)); + buf_out1_access[index++] = std::tanh(complex(-INFINITY, -2.f)); }); }); } + int n_fails = 0; for (size_t idx = 0; idx < TestArraySize3; ++idx) { - assert(approx_equal_cmplx(result3[idx], ref3_results[idx])); + if (!approx_equal_cmplx(result3[idx], ref3_results[idx])) { + ++n_fails; + std::cout << "test array 3 fail at index " << idx << "\n"; + std::cout << "expected: " << ref3_results[idx] << "\n"; + std::cout << "actual: " << result3[idx] << "\n"; + } } + return n_fails; } #endif int main() { @@ -208,9 +236,12 @@ int main() { } #endif - device_complex_test_1(deviceQueue); + int n_fails = 0; + n_fails += device_complex_test_1(deviceQueue); #ifndef _WIN32 - device_complex_test_2(deviceQueue); + n_fails += device_complex_test_2(deviceQueue); #endif - std::cout << "Pass" << std::endl; + if (n_fails == 0) + std::cout << "Pass" << std::endl; + return n_fails; }