Skip to content

Commit

Permalink
[SYCL] Fix edge cases in ctanh and cexp (#14329)
Browse files Browse the repository at this point in the history
Due to formatting, the test case diff is a bit hard to see. Here is an
overview of the new tests cases added and their old behavior:

`ctanh(-inf + nan*i) == -1 + 0*i` (previously was `1 + 0*i`)
`ctanh(-inf + nan*i) == -1 + 0*i` (previously was `1 + 0*i`)
`ctanh(-inf + nan*i) == -1 + 0*i` (previously was `1 + 0*i`)
`cexp(1e6 + 0*i) == inf + 0*i` (previously was `inf + nan*i`)
`cexp(1e6 + 0.1*i) == inf + inf*i` (old behavior, just adding more
coverage)
`cexp(1e6 + -0.*i) == inf + -inf*i` (old behavior, just adding more
coverage)
  • Loading branch information
jzc committed Jul 11, 2024
1 parent 0f0b699 commit 6c030d2
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 115 deletions.
15 changes: 11 additions & 4 deletions libdevice/fallback-complex-fp64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,14 @@ double __complex__ __devicelib_cexp(double __complex__ z) {
return z;
}
double __e = __spirv_ocl_exp(z_real);
return CMPLX((__e * __spirv_ocl_cos(z_imag)),
(__e * __spirv_ocl_sin(z_imag)));
double ret_real = __e * __spirv_ocl_cos(z_imag);
double ret_imag = __e * __spirv_ocl_sin(z_imag);

if (__spirv_IsNan(ret_real))
ret_real = 0.;
if (__spirv_IsNan(ret_imag))
ret_imag = 0.;
return CMPLX(ret_real, ret_imag);
}

DEVICE_EXTERN_C_INLINE
Expand Down Expand Up @@ -249,8 +255,9 @@ double __complex__ __devicelib_ctanh(double __complex__ z) {
double z_imag = __devicelib_cimag(z);
if (__spirv_IsInf(z_real)) {
if (!__spirv_IsFinite(z_imag))
return CMPLX(1.0, 0.0);
return CMPLX(1.0, __spirv_ocl_copysign(0.0, __spirv_ocl_sin(2.0 * z_imag)));
return CMPLX(__spirv_ocl_copysign(1.0, z_real), 0.0);
return CMPLX(__spirv_ocl_copysign(1.0, z_real),
__spirv_ocl_copysign(0.0, __spirv_ocl_sin(2.0 * z_imag)));
}
if (__spirv_IsNan(z_real) && z_imag == 0)
return z;
Expand Down
14 changes: 10 additions & 4 deletions libdevice/fallback-complex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,14 @@ float __complex__ __devicelib_cexpf(float __complex__ z) {
return z;
}
float __e = __spirv_ocl_exp(z_real);
return CMPLXF((__e * __spirv_ocl_cos(z_imag)),
(__e * __spirv_ocl_sin(z_imag)));
float ret_real = __e * __spirv_ocl_cos(z_imag);
float ret_imag = __e * __spirv_ocl_sin(z_imag);

if (__spirv_IsNan(ret_real))
ret_real = 0.f;
if (__spirv_IsNan(ret_imag))
ret_imag = 0.f;
return CMPLXF(ret_real, ret_imag);
}

DEVICE_EXTERN_C_INLINE
Expand Down Expand Up @@ -249,8 +255,8 @@ float __complex__ __devicelib_ctanhf(float __complex__ z) {
float z_imag = __devicelib_cimagf(z);
if (__spirv_IsInf(z_real)) {
if (!__spirv_IsFinite(z_imag))
return CMPLXF(1.0f, 0.0f);
return CMPLXF(1.0f,
return CMPLXF(__spirv_ocl_copysign(1.0f, z_real), 0.0f);
return CMPLXF(__spirv_ocl_copysign(1.0f, z_real),
__spirv_ocl_copysign(0.0f, __spirv_ocl_sin(2.0f * z_imag)));
}
if (__spirv_IsNan(z_real) && z_imag == 0)
Expand Down
160 changes: 93 additions & 67 deletions sycl/test-e2e/DeviceLib/std_complex_math_fp64_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,73 +24,78 @@ template <typename T> bool approx_equal_cmplx(complex<T> x, complex<T> y) {
approx_equal_fp(x.imag(), y.imag());
}

static constexpr auto TestArraySize1 = 57;
static constexpr auto TestArraySize2 = 10;
complex<double> ref1_results[] = {complex<double>(-1., 1.),
complex<double>(1., 3.),
complex<double>(-2., 10.),
complex<double>(-8., 31.),
complex<double>(1., 1.),
complex<double>(2., 1.),
complex<double>(2., 2.),
complex<double>(3., 4.),
complex<double>(2., 1.),
complex<double>(0., 1.),
complex<double>(2., 0.),
complex<double>(0., 0.),
complex<double>(0., 1.),
complex<double>(1., 1.),
complex<double>(2., 0.),
complex<double>(2., 3.),
complex<double>(1., 0.),
complex<double>(0., 1.),
complex<double>(-1., 0.),
complex<double>(0., M_E),
complex<double>(0., 0.),
complex<double>(0., M_PI_2),
complex<double>(0., M_PI),
complex<double>(1., M_PI_2),
complex<double>(0., 0.),
complex<double>(1., 0.),
complex<double>(1., 0.),
complex<double>(-1., 0.),
complex<double>(-INFINITY, 0.),
complex<double>(1., 0.),
complex<double>(10., 0.),
complex<double>(100., 0.),
complex<double>(200., 0.),
complex<double>(1., 2.),
complex<double>(INFINITY, 0.),
complex<double>(INFINITY, 0.),
complex<double>(0., 1.),
complex<double>(M_PI_2, 0.),
complex<double>(0., 0.),
complex<double>(1., 0.),
complex<double>(INFINITY, 0.),
complex<double>(0., 0.),
complex<double>(1., 0.),
complex<double>(0., 0.),
complex<double>(INFINITY, M_PI_2),
complex<double>(INFINITY, 0.),
complex<double>(0., M_PI_2),
complex<double>(INFINITY, M_PI_2),
complex<double>(INFINITY, 0.),
complex<double>(0., 0.),
complex<double>(0., M_PI_2),

std::array<complex<double>, TestArraySize1> ref1_results = {
complex<double>(-1., 1.),
complex<double>(1., 3.),
complex<double>(-2., 10.),
complex<double>(-8., 31.),
complex<double>(1., 1.),
complex<double>(2., 1.),
complex<double>(2., 2.),
complex<double>(3., 4.),
complex<double>(2., 1.),
complex<double>(0., 1.),
complex<double>(2., 0.),
complex<double>(0., 0.),
complex<double>(0., 1.),
complex<double>(1., 1.),
complex<double>(2., 0.),
complex<double>(2., 3.),
complex<double>(1., 0.),
complex<double>(0., 1.),
complex<double>(-1., 0.),
complex<double>(0., M_E),
complex<double>(0., 0.),
complex<double>(0., M_PI_2),
complex<double>(0., M_PI),
complex<double>(1., M_PI_2),
complex<double>(0., 0.),
complex<double>(1., 0.),
complex<double>(1., 0.),
complex<double>(-1., 0.),
complex<double>(-INFINITY, 0.),
complex<double>(1., 0.),
complex<double>(10., 0.),
complex<double>(100., 0.),
complex<double>(200., 0.),
complex<double>(1., 2.),
complex<double>(INFINITY, 0.),
complex<double>(INFINITY, 0.),
complex<double>(0., 1.),
complex<double>(M_PI_2, 0.),
complex<double>(0., 0.),
complex<double>(1., 0.),
complex<double>(INFINITY, 0.),
complex<double>(0., 0.),
complex<double>(1., 0.),
complex<double>(0., 0.),
complex<double>(INFINITY, M_PI_2),
complex<double>(INFINITY, 0.),
complex<double>(0., M_PI_2),
complex<double>(INFINITY, M_PI_2),
complex<double>(INFINITY, 0.),
complex<double>(0., 0.),
complex<double>(0., M_PI_2),
complex<double>(1., -4.),
complex<double>(18., -7.),
complex<double>(1.557407724654902, 0.),
complex<double>(0, 0.761594155955765),
complex<double>(M_PI_2, 0.),
complex<double>(M_PI_2, 0.549306144334055),
complex<double>(-1., 0.),
complex<double>(-1., 0.),
complex<double>(-1., 0.),
complex<double>(INFINITY, 0.),
complex<double>(INFINITY, INFINITY),
complex<double>(INFINITY, -INFINITY)};

complex<double>(1., -4.),
complex<double>(18., -7.),
complex<double>(1.557407724654902, 0.),
complex<double>(0, 0.761594155955765),
complex<double>(M_PI_2, 0.),
complex<double>(M_PI_2, 0.549306144334055)};
double ref2_results[] = {0., 25., 169., INFINITY, 0.,
5., 13., INFINITY, 0., M_PI_2};

std::array<double, TestArraySize2> ref2_results = {
0., 25., 169., INFINITY, 0., 5., 13., INFINITY, 0., M_PI_2};
static constexpr auto TestArraySize1 = std::size(ref1_results);
static constexpr auto TestArraySize2 = 10;

void device_complex_test(s::queue &deviceQueue) {
int device_complex_test(s::queue &deviceQueue) {
s::range<1> numOfItems1{TestArraySize1};
s::range<1> numOfItems2{TestArraySize2};
std::array<complex<double>, TestArraySize1> result1;
Expand Down Expand Up @@ -172,6 +177,13 @@ void device_complex_test(s::queue &deviceQueue) {
buf_out1_access[index++] = std::tan(complex<double>(0., 1.));
buf_out1_access[index++] = std::asin(complex<double>(1., 0.));
buf_out1_access[index++] = std::atan(complex<double>(0., 2.));
buf_out1_access[index++] = std::tanh(complex<double>(-INFINITY, NAN));
buf_out1_access[index++] =
std::tanh(complex<double>(-INFINITY, -INFINITY));
buf_out1_access[index++] = std::tanh(complex<double>(-INFINITY, -2.));
buf_out1_access[index++] = std::exp(complex<double>(1e6, 0.));
buf_out1_access[index++] = std::exp(complex<double>(1e6, 0.1));
buf_out1_access[index++] = std::exp(complex<double>(1e6, -0.1));

index = 0;
buf_out2_access[index++] = std::norm(complex<double>(0., 0.));
Expand All @@ -188,16 +200,30 @@ void device_complex_test(s::queue &deviceQueue) {
});
}

int n_fails = 0;
for (size_t idx = 0; idx < TestArraySize1; ++idx) {
assert(approx_equal_cmplx(result1[idx], ref1_results[idx]));
if (!approx_equal_cmplx(result1[idx], ref1_results[idx])) {
++n_fails;
std::cout << "test array 1 fail at index " << idx << "\n";
std::cout << "expected: " << ref1_results[idx] << "\n";
std::cout << "actual: " << result1[idx] << "\n";
}
}
for (size_t idx = 0; idx < TestArraySize2; ++idx) {
assert(approx_equal_fp(result2[idx], ref2_results[idx]));
if (!approx_equal_fp(result2[idx], ref2_results[idx])) {
++n_fails;
std::cout << "test array 2 fail at index " << idx << "\n";
std::cout << "expected: " << ref2_results[idx] << "\n";
std::cout << "actual: " << result2[idx] << "\n";
}
}
return n_fails;
}

int main() {
s::queue deviceQueue;
device_complex_test(deviceQueue);
std::cout << "Pass" << std::endl;
auto n_fails = device_complex_test(deviceQueue);
if (n_fails == 0)
std::cout << "Pass" << std::endl;
return n_fails;
}
Loading

0 comments on commit 6c030d2

Please sign in to comment.