Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL][DOC] Expand complex extension with complex support for sycl::marray #11792

Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 247 additions & 12 deletions sycl/doc/extensions/experimental/sycl_ext_oneapi_complex.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,20 @@ specification.*
While {dpcpp} has support for `std::complex` in device code, it limits the
complex interface and operations to the existing C++ standard. This proposal
defines a SYCL complex extension based on but independent of the `std::complex`
interface. This framework would allow for further development of complex math
within oneAPI. Possible areas for deviation with `std::complex` include adding
complex support for `marray` and `vec` and overloading mathematical
functions to handle the element-wise operations.
interface.

The proposed framework not only encompasses complex support for traditional use
cases but also accommodate for advanced mathematical features and data
jle-quel marked this conversation as resolved.
Show resolved Hide resolved
structures.

Specifically, we propose to incorporate complex support for `sycl::marray`.
This addition will empower developers to store complex numbers seamlessly
within arrays, opening up new possibilities for data manipulation and
jle-quel marked this conversation as resolved.
Show resolved Hide resolved
computation.

Furthermore, this extension involves overloading existing mathematical
functions to facilitate scalar operation on complex numbers as well as
element-wise operations on complex marrays.

== Specification

Expand Down Expand Up @@ -211,17 +221,125 @@ namespace sycl::ext::oneapi::experimental {
} // namespace sycl::ext::oneapi::experimental
```

=== Mathematical operations
=== Marray Complex Class Specialization

This proposal also introduces the specialization of the marray class to
jle-quel marked this conversation as resolved.
Show resolved Hide resolved
support SYCL complex. The marray class undergoes slight modification for this
specialization, primarily involving the removal of operators that are
inapplicable. No new functions or operators are introduced to the marray class.

The marray complex specialization maintains the principles of trivial
copyability (as seen in the Complex class description), with the
jle-quel marked this conversation as resolved.
Show resolved Hide resolved
`is_device_copyable` type trait resolving to `std::true_type`.

The marray definition used within this proposal assumes that any operator the
`sycl::marray` class defines is only implemented if the marray's value type
also implements the operator.

For instance,
`sycl::marray<sycl::ext::oneapi::experimental::complex<T>, NumElements>` does
not implement the modulus operator since
`sycl::ext::oneapi::experimental::complex<T>` does not support it.
jle-quel marked this conversation as resolved.
Show resolved Hide resolved

```C++
namespace sycl {

// Specialization of exiting `marray` class for `sycl::ext::oneapi::experimental::complex`
template <typename T, std::size_t NumElements>
class marray<sycl::ext::oneapi::experimental::complex<T>, NumElements> {
public:

/* ... */

friend marray operator %(const marray &lhs, const marray &rhs) = delete;
friend marray operator %(const marray &lhs, const value_type &rhs) = delete;
friend marray operator %(const value_type &lhs, const marray &rhs) = delete;

friend marray &operator %=(marray &lhs, const marray &rhs) = delete;
friend marray &operator %=(marray &lhs, const value_type &rhs) = delete;
friend marray &operator %=(value_type &lhs, const marray &rhs) = delete;

friend marray operator ++(marray &lhs, int) = delete;
friend marray &operator ++(marray & rhs) = delete;

friend marray operator --(marray &lhs, int) = delete;
friend marray &operator --(marray & rhs) = delete;

friend marray operator &(const marray &lhs, const marray &rhs) = delete;
friend marray operator &(const marray &lhs, const value_type &rhs) = delete;

friend marray operator |(const marray &lhs, const marray &rhs) = delete;
friend marray operator |(const marray &lhs, const value_type &rhs) = delete;

friend marray operator ^(const marray &lhs, const marray &rhs) = delete;
friend marray operator ^(const marray &lhs, const value_type &rhs) = delete;

friend marray &operator &=(marray & lhs, const marray & rhs) = delete;
friend marray &operator &=(marray & lhs, const value_type & rhs) = delete;
friend marray &operator &=(value_type & lhs, const marray & rhs) = delete;

friend marray &operator |=(marray & lhs, const marray & rhs) = delete;
friend marray &operator |=(marray & lhs, const value_type & rhs) = delete;
friend marray &operator |=(value_type & lhs, const marray & rhs) = delete;

friend marray &operator ^=(marray & lhs, const marray & rhs) = delete;
friend marray &operator ^=(marray & lhs, const value_type & rhs) = delete;
friend marray &operator ^=(value_type & lhs, const marray & rhs) = delete;

friend marray<bool, NumElements> operator <<(const marray & lhs, const marray & rhs) = delete;
friend marray<bool, NumElements> operator <<(const marray & lhs, const value_type & rhs) = delete;
friend marray<bool, NumElements> operator <<(const value_type & lhs, const marray & rhs) = delete;

friend marray<bool, NumElements> operator >>(const marray & lhs, const marray & rhs) = delete;
friend marray<bool, NumElements> operator >>(const marray & lhs, const value_type & rhs) = delete;
friend marray<bool, NumElements> operator >>(const value_type & lhs, const marray & rhs) = delete;

friend marray &operator <<=(marray & lhs, const marray & rhs) = delete;
friend marray &operator <<=(marray & lhs, const value_type & rhs) = delete;

friend marray &operator >>=(marray & lhs, const marray & rhs) = delete;
friend marray &operator >>=(marray & lhs, const value_type & rhs) = delete;

This proposal adds to the `sycl::ext::oneapi::experimental` namespace, math
functions accepting the complex types `complex<sycl::half>`, `complex<float>`,
`complex<double>` as well as the scalar types `sycl::half`, `float` and `double`
for the SYCL math functions, `abs`, `acos`, `asin`, `atan`, `acosh`, `asinh`,
`atanh`, `arg`, `conj`, `cos`, `cosh`, `exp`, `log`, `log10`, `norm`, `polar`,
`pow`, `proj`, `sin`, `sinh`, `sqrt`, `tan`, and `tanh`.
friend marray<bool, NumElements> operator <(const marray & lhs, const marray & rhs) = delete;
friend marray<bool, NumElements> operator <(const marray & lhs, const value_type & rhs) = delete;
friend marray<bool, NumElements> operator <(const value_type & lhs, const marray & rhs) = delete;

friend marray<bool, NumElements> operator >(const marray & lhs, const marray & rhs) = delete;
friend marray<bool, NumElements> operator >(const marray & lhs, const value_type & rhs) = delete;
friend marray<bool, NumElements> operator >(const value_type & lhs, const marray & rhs) = delete;

friend marray<bool, NumElements> operator <=(const marray & lhs, const marray & rhs) = delete;
friend marray<bool, NumElements> operator <=(const marray & lhs, const value_type & rhs) = delete;
friend marray<bool, NumElements> operator <=(const value_type & lhs, const marray & rhs) = delete;

friend marray<bool, NumElements> operator >=(const marray & lhs, const marray & rhs) = delete;
friend marray<bool, NumElements> operator >=(const marray & lhs, const value_type & rhs) = delete;
friend marray<bool, NumElements> operator >=(const value_type & lhs, const marray & rhs) = delete;

friend marray operator ~(const marray &v) = delete;

friend marray<bool, NumElements> operator !(const marray &v) = delete;
};

} // namespace sycl
```

=== Scalar Mathematical operations

This proposal extends the `sycl::ext::oneapi::experimental` namespace math
functions to accept `complex<sycl::half>`, `complex<float>`, `complex<double>`
as well as the scalar types `sycl::half`, `float` and `double` for a range of
SYCL math functions.

Specifically, it adds support for `abs`, `acos`, `asin`, `atan`, `acosh`,
`asinh`, `atanh`, `arg`, `conj`, `cos`, `cosh`, `exp`, `log`, `log10`, `norm`,
`polar`, `pow`, `proj`, `sin`, `sinh`, `sqrt`, `tan`, and `tanh`.

Additionally, this extension introduces support for the `real` and `imag` free
functions, which the real and imaginary component, respectively.
jle-quel marked this conversation as resolved.
Show resolved Hide resolved

These functions are available in both host and device code, and each math
function should follow the C++ standard for handling NaN's and Inf values.
function should follow the C++ standard for handling `NaN` and `Inf` values.

Note: In the case of the `pow` function, additional overloads have been added
to ensure that for their first argument `base` and second argument `exponent`:
Expand Down Expand Up @@ -319,6 +437,123 @@ namespace sycl::ext::oneapi::experimental {
} // namespace sycl::ext::oneapi::experimental
```

=== Element-Wise Mathematical operations

In harmony with the complex scalar operations, this proposal extends
furthermore the `sycl::ext::oneapi::experimental`` namespace math functions
jle-quel marked this conversation as resolved.
Show resolved Hide resolved
to accept `sycl::marray<complex<T>>` for a range of SYCL math functions.

Specifically, it adds support for `abs`, `acos`, `asin`, `atan`, `acosh`,
`asinh`, `atanh`, `arg`, `conj`, `cos`, `cosh`, `exp`, `log`, `log10`, `norm`,
`polar`, `pow`, `proj`, `sin`, `sinh`, `sqrt`, `tan`, and `tanh`.

Additionally, this extension introduces support for the `real` and `imag` free
functions, which return marrays of scalar values representing the real and
imaginary components, respectively.
Pennycook marked this conversation as resolved.
Show resolved Hide resolved

In scenarios where mathematical functions involve both marray and scalar
parameters, two sets of overloads are introduced marray-scalar and
scalar-marray.

These mathematical operations are designed to execute element-wise across the
marray, ensuring that each operation is applied to every element within the
marray.

Moreover, this proposal includes overloads for mathematical functions between
marrays and scalar inputs. In these cases, the operations are executed across
the entire marray, with the scalar value held constant.

For consistency, these functions are available in both host and device code,
and each math function should follow the C++ standard for handling `NaN` and
`Inf` values.

```C++
namespace sycl/ext/oneapi/experimental {

/// VALUES:
/// Returns an marray of real components from the marray x.
template <typename T, std::size_t NumElements>
sycl::marray<T, NumElements> real(const marray<complex<T>, NumElements> &x);
/// Returns an marray of imaginary components from the marray x.
template <typename T, std::size_t NumElements>
sycl::marray<T, NumElements> imag(const marray<complex<T>, NumElements> &x);

/// Compute the magnitude for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<T, NumElements> abs(const marray<complex<T>, NumElements> &x);
/// Compute phase angle in radians for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<T, NumElements> arg(const marray<complex<T>, NumElements> &x);
/// Compute the squared magnitude for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<T, NumElements> norm(const marray<complex<T>, NumElements> &x);
/// Compute the conjugate for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> conj(const marray<complex<T>, NumElements> &x);
/// Compute the projection for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> proj(const marray<complex<T>, NumElements> &x);
/// Compute the projection for each real number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> proj(const marray<T, NumElements> &x);
/// Construct an marray, elementwise, of complex numbers from each polar coordinate in marray rho and scalar theta.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> polar(const marray<T, NumElements> &rho, T theta = 0);
/// Construct an marray, elementwise, of complex numbers from each polar coordinate in marray rho and marray theta.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> polar(const marray<T, NumElements> &rho, const marray<T, NumElements> &theta);
/// Construct an marray, elementwise, of complex numbers from each polar coordinate in scalar rho and marray theta.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> polar(T rho, const marray<T, NumElements> &theta);

/// TRANSCENDENTALS:
/// Compute the natural log for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> log(const marray<complex<T>, NumElements> &x);
/// Compute the base-10 log for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> log10(const marray<complex<T>, NumElements> &x);
/// Compute the square root for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> sqrt(const marray<complex<T>, NumElements> &x);
/// Compute the base-e exponent for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> exp(const marray<complex<T>, NumElements> &x);

/// Raise each complex element in x to the power of the corresponding decimal element in y.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<complex<T>, NumElements> &x, const marray<T, NumElements> &y);
/// Raise each complex element in x to the power of the decimal number y.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<complex<T>, NumElements> &x, T y);
/// Raise complex number x to the power of each decimal element in y.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<complex<T>, NumElements> &x, const marray<T, NumElements> &y);
/// Raise each complex element in x to the power of the corresponding complex element in y.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<complex<T>, NumElements> &x, const marray<complex<T>, NumElements> &y);
/// Raise each complex element in x to the power of the complex number y.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<complex<T>, NumElements> &x, const marray<complex<T>, NumElements> &y);
/// Raise complex number x to the power of each complex element in y.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<complex<T>, NumElements> &x, const marray<complex<T>, NumElements> &y);
/// Raise each decimal element in x to the power of the corresponding complex element in y.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<T, NumElements> &x, const marray<complex<T>, NumElements> &y);
/// Raise each decimal element in x to the power of the complex number y.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<T, NumElements> &x, const marray<complex<T>, NumElements> &y);
/// Raise decimal number x to the power of each complex element in y.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(T x, const marray<complex<T>, NumElements> &y);

/// Compute the inverse hyperbolic sine for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> asinh(const marray<complex<T>, NumElements> &x);
/// Compute the inverse hyperbolic cosine for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> acosh(const marray<complex<T>, NumElements> &x);
/// Compute the inverse hyperbolic tangent for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> atanh(const marray<complex<T>, NumElements> &x);
/// Compute the hyperbolic sine for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> sinh(const marray<complex<T>, NumElements> &x);
/// Compute the hyperbolic cosine for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> cosh(const marray<complex<T>, NumElements> &x);
/// Compute the hyperbolic tangent for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> tanh(const marray<complex<T>, NumElements> &x);
/// Compute the inverse sine for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> asin(const marray<complex<T>, NumElements> &x);
/// Compute the inverse cosine for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> acos(const marray<complex<T>, NumElements> &x);
/// Compute the inverse tangent for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> atan(const marray<complex<T>, NumElements> &x);
/// Compute the sine for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> sin(const marray<complex<T>, NumElements> &x);
/// Compute the cosine for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> cos(const marray<complex<T>, NumElements> &x);
/// Compute the tangent for each complex number in marray x.
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> tan(const marray<complex<T>, NumElements> &x);

} // namespace sycl::ext::oneapi::experimental
```

== Implementation notes

The complex mathematical operations can all be defined using SYCL built-ins.
Expand Down