diff --git a/include/oneapi/mkl/rng/device/detail/beta_impl.hpp b/include/oneapi/mkl/rng/device/detail/beta_impl.hpp new file mode 100644 index 000000000..e412ee157 --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/beta_impl.hpp @@ -0,0 +1,468 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_BETA_IMPL_HPP_ +#define _MKL_RNG_DEVICE_BETA_IMPL_HPP_ + +#include "vm_wrappers.hpp" + +namespace oneapi::mkl::rng::device::detail { + +enum class beta_algorithm { Johnk = 0, Atkinson1, Atkinson2, Atkinson3, Cheng, p1, q1, p1q1 }; + +// log(4)=1.3862944.. +template +inline DataType log4() { + if constexpr (std::is_same_v) + return 0x1.62e42fefa39efp+0; + else + return 0x1.62e43p+0f; +} + +// K=0.85225521765372429631847 +template +inline DataType beta_k() { + if constexpr (std::is_same_v) + return 0x1.b45acbbf56123p-1; + else + return 0x1.b45accp-1f; +} + +// C=-0.956240971340815081432202 +template +inline DataType beta_c() { + if constexpr (std::is_same_v) + return -0x1.e9986aa60216p-1; + else + return -0x1.e9986ap-1f; +} + +template +class distribution_base> { +public: + struct param_type { + param_type(RealType p, RealType q, RealType a, RealType b) : p_(p), q_(q), a_(a), b_(b) {} + RealType p_; + RealType q_; + RealType a_; + RealType b_; + }; + + distribution_base(RealType p, RealType q, RealType a, RealType b) + : p_(p), + q_(q), + a_(a), + b_(b), + count_(0) { + set_algorithm(); +#ifndef __SYCL_DEVICE_ONLY__ + if (p <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "beta", "p <= 0"); + } + else if (q <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "beta", "q <= 0"); + } + else if (b <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "beta", "b <= 0"); + } +#endif + } + + RealType p() const { + return p_; + } + + RealType q() const { + return q_; + } + + RealType a() const { + return a_; + } + + RealType b() const { + return b_; + } + + std::size_t count_rejected_numbers() const { + return count_; + } + + param_type param() const { + return param_type(p_, q_, a_, b_); + } + + void param(const param_type& pt) { +#ifndef __SYCL_DEVICE_ONLY__ + if (pt.p_ <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "beta", "p <= 0"); + } + else if (pt.q_ <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "beta", "q <= 0"); + } + else if (pt.b_ <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "beta", "b <= 0"); + } +#endif + p_ = pt.p_; + q_ = pt.q_; + a_ = pt.a_; + b_ = pt.b_; + set_algorithm(); + } + +protected: + template + T pq_kernel(T& z) { + for (std::int32_t i = 0; i < n; i++) { + if (p_ == RealType(1.0)) { + z[i] = pow_wrapper(z[i], RealType(1) / q_); + z[i] = RealType(1.0) - z[i]; + } + if (q_ == RealType(1.0)) { + z[i] = pow_wrapper(z[i], 
RealType(1) / p_); + } + } + count_ = 0; + + // p1q1 + return a_ + b_ * z; + } + + template + T acc_rej_kernel(T& z, EngineType& engine) { + RealType s, t; + + RealType flKoef1, flKoef2, flKoef3, flKoef4, flKoef5, flKoef6; + RealType flDeg[2]; + + if (algorithm_ == beta_algorithm::Atkinson1) { + RealType flInv_s[2], flTmp[2]; + flTmp[0] = p_ * (RealType(1.0) - p_); + flTmp[1] = q_ * (RealType(1.0) - q_); + + flTmp[0] = sqrt_wrapper(flTmp[0]); + flTmp[1] = sqrt_wrapper(flTmp[1]); + + t = flTmp[0] / (flTmp[0] + flTmp[1]); + + s = q_ * t; + s = s / (s + p_ * (RealType(1.0) - t)); + + flInv_s[0] = RealType(1.0) / s; + flInv_s[1] = RealType(1.0) / (RealType(1.0) - s); + flDeg[0] = RealType(1.0) / p_; + flDeg[1] = RealType(1.0) / q_; + + flInv_s[0] = pow_wrapper(flInv_s[0], flDeg[0]); + flInv_s[1] = pow_wrapper(flInv_s[1], flDeg[1]); + + flKoef1 = t * flInv_s[0]; + flKoef2 = (RealType(1.0) - t) * flInv_s[1]; + flKoef3 = RealType(1.0) - q_; + flKoef4 = RealType(1.0) - p_; + flKoef5 = RealType(1.0) / (RealType(1.0) - t); + flKoef6 = RealType(1.0) / t; + } + else if (algorithm_ == beta_algorithm::Atkinson2) { + RealType flInv_s[2], flTmp; + + t = RealType(1.0) - p_; + t /= (t + q_); + + flTmp = RealType(1.0) - t; + flTmp = pow_wrapper(flTmp, q_); + s = q_ * t; + s /= (s + p_ * flTmp); + + flInv_s[0] = RealType(1.0) / s; + flInv_s[1] = RealType(1.0) / (RealType(1.0) - s); + flDeg[0] = RealType(1.0) / p_; + flDeg[1] = RealType(1.0) / q_; + + flInv_s[0] = pow_wrapper(flInv_s[0], flDeg[0]); + flInv_s[1] = pow_wrapper(flInv_s[1], flDeg[1]); + + flKoef1 = t * flInv_s[0]; + flKoef2 = (RealType(1.0) - t) * flInv_s[1]; + flKoef3 = RealType(1.0) - q_; + flKoef4 = RealType(1.0) - p_; + } + else if (algorithm_ == beta_algorithm::Atkinson3) { + RealType flInv_s[2], flTmp; + + t = RealType(1.0) - q_; + t /= (t + p_); + + flTmp = RealType(1.0) - t; + flTmp = pow_wrapper(flTmp, p_); + s = p_ * t; + s /= (s + q_ * flTmp); + + flInv_s[0] = RealType(1.0) / s; + flInv_s[1] = RealType(1.0) / (RealType(1.0) - s); + flDeg[0] = RealType(1.0) / q_; + flDeg[1] = RealType(1.0) / p_; + + flInv_s[0] = pow_wrapper(flInv_s[0], flDeg[0]); + flInv_s[1] = pow_wrapper(flInv_s[1], flDeg[1]); + + flKoef1 = t * flInv_s[0]; + flKoef2 = (RealType(1.0) - t) * flInv_s[1]; + flKoef3 = RealType(1.0) - p_; + flKoef4 = RealType(1.0) - q_; + } + else if (algorithm_ == beta_algorithm::Cheng) { + flKoef1 = p_ + q_; + flKoef2 = (flKoef1 - RealType(2.0)) / (RealType(2.0) * p_ * q_ - flKoef1); + flKoef2 = sqrt_wrapper(flKoef2); + flKoef3 = p_ + RealType(1.0) / flKoef2; + } + + RealType z1, z2; + + count_ = 0; + for (int i = 0; i < n; i++) { + while (1) { // looping until satisfied + z1 = engine.generate_single(RealType(0), RealType(1)); + z2 = engine.generate_single(RealType(0), RealType(1)); + + if (algorithm_ == beta_algorithm::Johnk) { + RealType flU1, flU2, flSum; + z1 = ln_wrapper(z1) / p_; + z2 = ln_wrapper(z2) / q_; + + z1 = exp_wrapper(z1); + z2 = exp_wrapper(z2); + + flU1 = z1; + flU2 = z2; + flSum = flU1 + flU2; + if (flSum > RealType(0.0) && flSum <= RealType(1.0)) { + z[i] = flU1 / flSum; + break; + } + } + if (algorithm_ == beta_algorithm::Atkinson1) { + RealType flU, flExp, flX, flLn; + z2 = ln_wrapper(z2); + + flU = z1; + flExp = z2; + if (flU <= s) { + flU = pow_wrapper(flU, flDeg[0]); + flX = flKoef1 * flU; + flLn = (RealType(1.0) - flX) * flKoef5; + flLn = ln_wrapper(flLn); + if (flKoef3 * flLn + flExp <= RealType(0.0)) { + z[i] = flX; + break; + } + } + else { + flU = RealType(1.0) - flU; + flU = pow_wrapper(flU, flDeg[1]); + flX 
= RealType(1.0) - flKoef2 * flU; + + flLn = flX * flKoef6; + flLn = ln_wrapper(flLn); + if (flKoef4 * flLn + flExp <= RealType(0.0)) { + z[i] = flX; + break; + } + } + } + if (algorithm_ == beta_algorithm::Atkinson2) { + RealType flU, flExp, flX, flLn; + z2 = ln_wrapper(z2); + + flU = z1; + flExp = z2; + if (flU <= s) { + flU = pow_wrapper(flU, flDeg[0]); + flX = flKoef1 * flU; + flLn = (RealType(1.0) - flX); + flLn = ln_wrapper(flLn); + if (flKoef3 * flLn + flExp <= RealType(0.0)) { + z[i] = flX; + break; + } + } + else { + flU = RealType(1.0) - flU; + flU = pow_wrapper(flU, flDeg[1]); + flX = RealType(1.0) - flKoef2 * flU; + + flLn = flX / t; + flLn = ln_wrapper(flLn); + if (flKoef4 * flLn + flExp <= RealType(0.0)) { + z[i] = flX; + break; + } + } + } + if (algorithm_ == beta_algorithm::Atkinson3) { + RealType flU, flExp, flX, flLn; + z2 = ln_wrapper(z2); + + flU = z1; + flExp = z2; + if (flU <= s) { + flU = pow_wrapper(flU, flDeg[0]); + flX = flKoef1 * flU; + flLn = (RealType(1.0) - flX); + flLn = ln_wrapper(flLn); + if (flKoef3 * flLn + flExp <= RealType(0.0)) { + z[i] = RealType(1.0) - flX; + break; + } + } + else { + flU = RealType(1.0) - flU; + flU = pow_wrapper(flU, flDeg[1]); + flX = RealType(1.0) - flKoef2 * flU; + + flLn = flX / t; + flLn = ln_wrapper(flLn); + if (flKoef4 * flLn + flExp <= RealType(0.0)) { + z[i] = RealType(1.0) - flX; + break; + } + } + } + if (algorithm_ == beta_algorithm::Cheng) { + RealType flU1, flU2, flV, flW, flInv; + RealType flTmp[2]; + flU1 = z1; + flU2 = z2; + + flV = flU1 / (RealType(1.0) - flU1); + + flV = ln_wrapper(flV); + + flV = flKoef2 * flV; + + flW = flV; + flW = exp_wrapper(flW); + flW = p_ * flW; + flInv = RealType(1.0) / (q_ + flW); + flTmp[0] = flKoef1 * flInv; + flTmp[1] = flU1 * flU1 * flU2; + for (int i = 0; i < 2; i++) { + flTmp[i] = ln_wrapper(flTmp[i]); + } + + if (flKoef1 * flTmp[0] + flKoef3 * flV - log4() >= flTmp[1]) { + z[i] = flW * flInv; + break; + } + } + ++count_; + } + } + return a_ + b_ * z; + } + + template + auto generate(EngineType& engine) -> + typename std::conditional>::type { + sycl::vec res{}; + if (algorithm_ == beta_algorithm::p1 || algorithm_ == beta_algorithm::q1 || + algorithm_ == beta_algorithm::p1q1) { + res = engine.generate(RealType(0), RealType(1)); + res = pq_kernel(res); + } + else { + res = acc_rej_kernel(res, engine); + } + if constexpr (std::is_same_v) { + for (std::int32_t i = 0; i < EngineType::vec_size; i++) { + if (res[i] < a_) + res[i] = a_; + if (res[i] > a_ + b_) + res[i] = a_ + b_; + } + } + return res; + } + + template + RealType generate_single(EngineType& engine) { + RealType res{}; + sycl::vec z{ res }; + if (algorithm_ == beta_algorithm::p1 || algorithm_ == beta_algorithm::q1 || + algorithm_ == beta_algorithm::p1q1) { + z[0] = engine.generate_single(RealType(0), RealType(1)); + res = pq_kernel<1>(z); + } + else { + res = acc_rej_kernel<1>(z, engine); + } + if constexpr (std::is_same_v) { + for (std::int32_t i = 0; i < EngineType::vec_size; i++) { + if (res[i] < a_) + res[i] = a_; + if (res[i] > a_ + b_) + res[i] = a_ + b_; + } + } + return res; + } + + void set_algorithm() { + if (p_ < RealType(1.0) && q_ < RealType(1.0)) { + if (q_ + beta_k() * p_ * p_ + beta_c() <= RealType(0.0)) { + algorithm_ = beta_algorithm::Johnk; + } + else { + algorithm_ = beta_algorithm::Atkinson1; + } + } + else if (p_ < RealType(1.0) && q_ > RealType(1.0)) { + algorithm_ = beta_algorithm::Atkinson2; + } + else if (p_ > RealType(1.0) && q_ < RealType(1.0)) { + algorithm_ = beta_algorithm::Atkinson3; + } + 
else if (p_ > RealType(1.0) && q_ > RealType(1.0)) { + algorithm_ = beta_algorithm::Cheng; + } + else if (p_ == RealType(1.0) && q_ != RealType(1.0)) { + algorithm_ = beta_algorithm::p1; + } + else if (q_ == RealType(1.0) && p_ != RealType(1.0)) { + algorithm_ = beta_algorithm::q1; + } + else { + algorithm_ = beta_algorithm::p1q1; + } + } + + RealType p_; + RealType q_; + RealType a_; + RealType b_; + std::size_t count_; + beta_algorithm algorithm_; +}; + +} // namespace oneapi::mkl::rng::device::detail + +#endif // _MKL_RNG_DEVICE_BETA_IMPL_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/distribution_base.hpp b/include/oneapi/mkl/rng/device/detail/distribution_base.hpp index e728a564c..575ea27f7 100644 --- a/include/oneapi/mkl/rng/device/detail/distribution_base.hpp +++ b/include/oneapi/mkl/rng/device/detail/distribution_base.hpp @@ -53,6 +53,12 @@ class bits; template class exponential; +template +class beta; + +template +class gamma; + template class poisson; @@ -69,5 +75,7 @@ class bernoulli; #include "oneapi/mkl/rng/device/detail/exponential_impl.hpp" #include "oneapi/mkl/rng/device/detail/poisson_impl.hpp" #include "oneapi/mkl/rng/device/detail/bernoulli_impl.hpp" +#include "oneapi/mkl/rng/device/detail/beta_impl.hpp" +#include "oneapi/mkl/rng/device/detail/gamma_impl.hpp" #endif // _MKL_RNG_DISTRIBUTION_BASE_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp b/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp index cf712f0e5..9419fc154 100644 --- a/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp @@ -68,10 +68,7 @@ class distribution_base> auto generate(EngineType& engine) -> typename std::conditional>::type { - using OutType = typename std::conditional>::type; - - OutType res = engine.generate(RealType(0), RealType(1)); + auto res = engine.generate(RealType(0), RealType(1)); if constexpr (EngineType::vec_size == 1) { res = ln_wrapper(res); } @@ -82,7 +79,19 @@ class distribution_base> } res = a_ - res * beta_; if constexpr (std::is_same::value) { - res = sycl::fmax(res, OutType{ a_ }); +#ifndef __HIPSYCL__ + res = sycl::fmax(res, a_); +#else + // a workaround for hipSYCL (AdaptiveCpp) + if constexpr (EngineType::vec_size == 1) { + res = std::fmax(res, a_); + } + else { + for (int i = 0; i < EngineType::vec_size; i++) { + res[i] = std::fmax(res[i], a_); + } + } +#endif } return res; } @@ -93,7 +102,19 @@ class distribution_base> res = ln_wrapper(res); res = a_ - res * beta_; if constexpr (std::is_same::value) { +#ifndef __HIPSYCL__ res = sycl::fmax(res, a_); +#else + // a workaround for hipSYCL (AdaptiveCpp) + if constexpr (EngineType::vec_size == 1) { + res = std::fmax(res, a_); + } + else { + for (int i = 0; i < EngineType::vec_size; i++) { + res[i] = std::fmax(res[i], a_); + } + } +#endif } return res; } @@ -105,6 +126,13 @@ class distribution_base> oneapi::mkl::rng::device::poisson>; friend class distribution_base< oneapi::mkl::rng::device::poisson>; + friend class distribution_base>; + friend class distribution_base< + oneapi::mkl::rng::device::gamma>; + friend class distribution_base< + oneapi::mkl::rng::device::gamma>; + friend class distribution_base< + oneapi::mkl::rng::device::gamma>; }; } // namespace oneapi::mkl::rng::device::detail diff --git a/include/oneapi/mkl/rng/device/detail/gamma_impl.hpp b/include/oneapi/mkl/rng/device/detail/gamma_impl.hpp new file mode 100644 index 000000000..11397a69d --- /dev/null +++ 
b/include/oneapi/mkl/rng/device/detail/gamma_impl.hpp @@ -0,0 +1,287 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_GAMMA_IMPL_HPP_ +#define _MKL_RNG_DEVICE_GAMMA_IMPL_HPP_ + +#include "vm_wrappers.hpp" + +namespace oneapi::mkl::rng::device::detail { + +enum class gamma_algorithm { Exponential = 0, Vaduva, EPD_Transform, Marsaglia }; + +// 1/3 +template +inline DataType gamma_c1() { + if constexpr (std::is_same_v) + return 0x1.5555555555555p-2; + else + return 0x1.555556p-2f; +} + +// 0.0331 +template +inline DataType gamma_c2() { + if constexpr (std::is_same_v) + return 0x1.0f27bb2fec56dp-5; + else + return 0x1.0f27bcp-5f; +} + +// 0.6 +template +inline DataType gamma_c06() { + if constexpr (std::is_same_v) + return 0x1.3333333333333p-1; + else + return 0x1.333334p-1f; +} + +template +class distribution_base> { +public: + struct param_type { + param_type(RealType alpha, RealType a, RealType beta) : alpha_(alpha), a_(a), beta_(beta) {} + RealType alpha_; + RealType a_; + RealType beta_; + }; + + distribution_base(RealType alpha, RealType a, RealType beta) + : alpha_(alpha), + a_(a), + beta_(beta), + count_(0) { + set_algorithm(); +#ifndef __SYCL_DEVICE_ONLY__ + if (alpha <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "gamma", "alpha <= 0"); + } + else if (beta <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "gamma", "beta <= 0"); + } +#endif + } + + RealType alpha() const { + return alpha_; + } + + RealType a() const { + return a_; + } + + RealType beta() const { + return beta_; + } + + std::size_t count_rejected_numbers() const { + return count_; + } + + param_type param() const { + return param_type(alpha_, a_, beta_); + } + + void param(const param_type& pt) { +#ifndef __SYCL_DEVICE_ONLY__ + if (pt.alpha_ <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "gamma", "alpha <= 0"); + } + else if (pt.beta_ <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "gamma", "beta <= 0"); + } +#endif + alpha_ = pt.alpha_; + a_ = pt.a_; + beta_ = pt.beta_; + set_algorithm(); + } + +protected: + void set_algorithm() { + if (alpha_ <= RealType(1.0)) { + if (alpha_ == RealType(1.0)) { + algorithm_ = gamma_algorithm::Exponential; + } + else if (alpha_ > gamma_c06()) { + algorithm_ = gamma_algorithm::Vaduva; + } + else { + algorithm_ = gamma_algorithm::EPD_Transform; + } + } + else { + algorithm_ = gamma_algorithm::Marsaglia; + } + } + + template + inline std::pair gauss_BM2_for_Marsaglia(const sycl::vec& vec) { + T tmp, sin, cos, gauss_1, gauss_2; + tmp = ln_wrapper(vec[0]); + tmp = sqrt_wrapper(T(-2.0) * tmp); + sin = sincospi_wrapper(T(2) * vec[2], cos); + gauss_1 = (tmp * sin); + gauss_2 = (tmp * cos); + return { gauss_1, gauss_2 }; + } + + template + T acc_rej_kernel(T& z, 
EngineType& engine) { + RealType flC, flD; + if (algorithm_ == gamma_algorithm::Vaduva) { + flC = RealType(1.0) / alpha_; + flD = (RealType(1.0) - alpha_) * + exp_wrapper(ln_wrapper(alpha_) * alpha_ / (RealType(1.0) - alpha_)); + } + else if (algorithm_ == gamma_algorithm::EPD_Transform) { + flC = RealType(1.0) / alpha_; + flD = (RealType(1.0) - alpha_); + } + else if (algorithm_ == gamma_algorithm::Marsaglia) { + flD = alpha_ - gamma_c1(); + flC = sqrt_wrapper(RealType(1.0) / (RealType(9.0) * alpha_ - RealType(3.0))); + } + + count_ = 0; + RealType z1, z2, z3, z4; + for (int i = 0; i < n; i++) { + while (1) { // looping until satisfied + if (!flag_) { + z1 = engine.generate_single(RealType(0), RealType(1)); + z2 = engine.generate_single(RealType(0), RealType(1)); + } + + if (algorithm_ == gamma_algorithm::Vaduva) { + z1 = -ln_wrapper(z1); + z2 = -ln_wrapper(z2); + z[i] = powr_wrapper(z1, flC); + if (z1 + z2 >= z[i] + flD) { + break; + } + } + if (algorithm_ == gamma_algorithm::EPD_Transform) { + z2 = -ln_wrapper(z2); + if (z1 <= flD) { + z[i] = powr_wrapper(z1, flC); + if (z[i] <= z2) { + break; + } + } + else { + z1 = -ln_wrapper((RealType(1.0) - z1) * flC); + z[i] = powr_wrapper(flD + alpha_ * z1, flC); + if (z[i] <= z2 + z1) { + break; + } + } + } + if (algorithm_ == gamma_algorithm::Marsaglia) { + RealType local_uniform_2, local_gauss; + if (!flag_) { + z3 = engine.generate_single(RealType(0), RealType(1)); + z4 = engine.generate_single(RealType(0), RealType(1)); + auto gauss = + gauss_BM2_for_Marsaglia(sycl::vec{ z1, z2, z3, z4 }); + local_uniform_2 = z2; + local_gauss = gauss.first; + + saved_uniform_2_ = z4; + saved_gauss_ = gauss.second; + } + else { + local_uniform_2 = saved_uniform_2_; + local_gauss = saved_gauss_; + } + flag_ = !flag_; + z[i] = RealType(1.0) + flC * local_gauss; + if (z[i] > RealType(0.0)) { + z[i] = z[i] * z[i] * z[i]; + local_gauss = local_gauss * local_gauss; + if (local_uniform_2 < + RealType(1.0) - gamma_c2() * local_gauss * local_gauss) { + z[i] = flD * z[i]; + break; + } + else { + RealType local_uniform_1 = ln_wrapper(z[i]); + local_uniform_2 = ln_wrapper(local_uniform_2); + if (local_uniform_2 < + RealType(0.5) * local_gauss + + flD * (RealType(1.0) - z[i] + local_uniform_1)) { + z[i] = flD * z[i]; + break; + } + } + } + } + ++count_; + } + } + auto res = a_ + beta_ * z; + if constexpr (std::is_same_v) { + for (std::int32_t i = 0; i < EngineType::vec_size; i++) { + if (res[i] < a_) + res[i] = a_; + } + } + return res; + } + + template + auto generate(EngineType& engine) -> + typename std::conditional>::type { + if (algorithm_ == gamma_algorithm::Exponential) { + distribution_base> distr_exp(a_, beta_); + return distr_exp.generate(engine); + } + sycl::vec res{}; + res = acc_rej_kernel(res, engine); + + return res; + } + + template + RealType generate_single(EngineType& engine) { + if (algorithm_ == gamma_algorithm::Exponential) { + distribution_base> distr_exp(a_, beta_); + RealType z = distr_exp.generate_single(engine); + return z; + } + sycl::vec res{}; + res = acc_rej_kernel<1>(res, engine); + + return res[0]; + } + + RealType alpha_; + RealType a_; + RealType beta_; + RealType saved_gauss_; + RealType saved_uniform_2_; + bool flag_ = false; + std::size_t count_; + gamma_algorithm algorithm_; +}; + +} // namespace oneapi::mkl::rng::device::detail + +#endif // _MKL_RNG_DEVICE_GAMMA_IMPL_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp b/include/oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp index 8f1294ac2..72447bc5d 100644 --- 
a/include/oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp @@ -27,7 +27,7 @@ class mcg31m1; namespace detail { -template +template constexpr sycl::vec select_vector_a_mcg31m1() { if constexpr (VecSize == 1) return sycl::vec(UINT64_C(1)); @@ -56,7 +56,7 @@ constexpr sycl::vec select_vector_a_mcg31m1() { // hipSYCL (AdaptiveCpp) doesn't support constexpr sycl::vec constructor // that's why in case of hipSYCL backend sycl::vec is created as a local variable #ifndef __HIPSYCL__ -template +template struct mcg31m1_vector_a { static constexpr sycl::vec vector_a = select_vector_a_mcg31m1(); // powers of a diff --git a/include/oneapi/mkl/rng/device/detail/mcg59_impl.hpp b/include/oneapi/mkl/rng/device/detail/mcg59_impl.hpp index bc21eb607..a70bb323d 100644 --- a/include/oneapi/mkl/rng/device/detail/mcg59_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/mcg59_impl.hpp @@ -27,7 +27,7 @@ class mcg59; namespace detail { -template +template constexpr sycl::vec select_vector_a_mcg59() { if constexpr (VecSize == 1) return sycl::vec(UINT64_C(1)); @@ -57,7 +57,7 @@ constexpr sycl::vec select_vector_a_mcg59() { // hipSYCL (AdaptiveCpp) doesn't support constexpr sycl::vec constructor // that's why in case of hipSYCL backend sycl::vec is created as a local variable #ifndef __HIPSYCL__ -template +template struct mcg59_vector_a { static constexpr sycl::vec vector_a = select_vector_a_mcg59(); // powers of a @@ -165,7 +165,7 @@ class engine_base> { auto generate() -> typename std::conditional>::type { - return mcg59_impl::generate(this->state_); + return mcg59_impl::generate(this->state_).template convert(); } auto generate_bits() -> typename std::conditional +#include +#include "engine_base.hpp" + namespace oneapi::mkl::rng::device::detail { +static inline std::uint64_t umul_hi_64(const std::uint64_t a, const std::uint64_t b) { + const std::uint64_t a_lo = a & 0xFFFFFFFFULL; + const std::uint64_t a_hi = a >> 32; + const std::uint64_t b_lo = b & 0xFFFFFFFFULL; + const std::uint64_t b_hi = b >> 32; + + const std::uint64_t ab_hi = a_hi * b_hi; + const std::uint64_t ab_lo = a_lo * b_lo; + const std::uint64_t ab_md = a_hi * b_lo; + const std::uint64_t ba_md = b_hi * a_lo; + + const std::uint64_t bias = + ((ab_md & 0xFFFFFFFFULL) + (ba_md & 0xFFFFFFFFULL) + (ab_lo >> 32)) >> 32; + + return ab_hi + (ab_md >> 32) + (ba_md >> 32) + bias; +} + +template +static inline void generate_leftover(std::uint64_t range, Generator generate, std::uint64_t& res_64, + std::uint64_t& leftover) { + if constexpr (std::is_same_v>) { + std::uint32_t res_1 = generate(); + std::uint32_t res_2 = generate(); + std::uint32_t res_3 = generate(); + res_64 = (static_cast(res_3) << 62) + + (static_cast(res_2) << 31) + res_1; + } + else { + std::uint32_t res_1 = generate(); + std::uint32_t res_2 = generate(); + res_64 = (static_cast(res_2) << 32) + res_1; + } + + leftover = res_64 * range; +} + template class distribution_base> { public: @@ -62,6 +102,15 @@ class distribution_base> { } protected: + template + OutType generate_single_int(EngineType& engine) { + sycl::vec res_fp; + res_fp = engine.generate(static_cast(a_), static_cast(b_)); + res_fp = sycl::floor(res_fp); + OutType res = res_fp.template convert(); + return res; + } + template auto generate(EngineType& engine) -> typename std::conditional> { float>::type; OutType res; if constexpr (std::is_integral::value) { - if constexpr (EngineType::vec_size == 1) { - FpType res_fp = engine.generate(static_cast(a_), static_cast(b_)); - 
res_fp = sycl::floor(res_fp); - res = static_cast(res_fp); - return res; + if constexpr (std::is_same_v || + std::is_same_v) { + return generate_single_int(engine); } else { - sycl::vec res_fp; - res_fp = engine.generate(static_cast(a_), static_cast(b_)); - res_fp = sycl::floor(res_fp); - res = res_fp.template convert(); - return res; + // Lemire's sample rejection method to exclude bias for uniform numbers + // https://arxiv.org/abs/1805.10941 + + constexpr std::uint64_t uint_max64 = std::numeric_limits::max(); + constexpr std::uint64_t uint_max32 = std::numeric_limits::max(); + + std::uint64_t range = b_ - a_; + std::uint64_t threshold = (uint_max64 - range) % range; + + if (range <= uint_max32) + return generate_single_int(engine); + + if constexpr (EngineType::vec_size == 1) { + std::uint32_t res_1, res_2; + std::uint64_t res_64, leftover; + + generate_leftover( + range, [&engine]() { return engine.generate(); }, res_64, leftover); + + if (range == uint_max64) + return res_64; + + while (leftover < threshold) { + generate_leftover( + range, [&engine]() { return engine.generate(); }, res_64, leftover); + } + + res = a_ + umul_hi_64(res_64, range); + + return res; + } + else { + std::uint64_t leftover; + + sycl::vec res_1 = engine.generate(); + sycl::vec res_2 = engine.generate(); + sycl::vec res_64; + + if constexpr (std::is_same_v>) { + sycl::vec res_3 = engine.generate(); + + for (int i = 0; i < EngineType::vec_size; i++) { + res_64[i] = (static_cast(res_3[i]) << 62) + + (static_cast(res_2[i]) << 31) + res_1[i]; + } + } + else { + if constexpr (EngineType::vec_size == 3) { + res_64[0] = (static_cast(res_1[1]) << 32) + + static_cast(res_1[0]); + res_64[1] = (static_cast(res_2[0]) << 32) + + static_cast(res_1[2]); + res_64[2] = (static_cast(res_2[2]) << 32) + + static_cast(res_2[1]); + } + else { + for (int i = 0; i < EngineType::vec_size / 2; i++) { + res_64[i] = (static_cast(res_1[2 * i + 1]) << 32) + + static_cast(res_1[2 * i]); + res_64[i + EngineType::vec_size / 2] = + (static_cast(res_2[2 * i + 1]) << 32) + + static_cast(res_2[2 * i]); + } + } + } + + if (range == uint_max64) + return res_64.template convert(); + + for (int i = 0; i < EngineType::vec_size; i++) { + leftover = res_64[i] * range; + + while (leftover < threshold) { + generate_leftover( + range, [&engine]() { return engine.generate_single(); }, res_64[i], + leftover); + } + + res[i] = a_ + umul_hi_64(res_64[i], range); + } + + return res; + } } } else { res = engine.generate(a_, b_); if constexpr (std::is_same::value) { - res = sycl::fmax(res, OutType{ a_ }); - res = sycl::fmin(res, OutType{ b_ }); +#ifndef __HIPSYCL__ + res = sycl::fmax(res, a_); + res = sycl::fmin(res, b_); +#else + // a workaround for hipSYCL (AdaptiveCpp) + if constexpr (EngineType::vec_size == 1) { + res = sycl::fmax(res, a_); + res = sycl::fmin(res, b_); + } + else { + for (int i = 0; i < EngineType::vec_size; i++) { + res[i] = sycl::fmax(res[i], a_); + res[i] = sycl::fmin(res[i], b_); + } + } +#endif } } @@ -105,17 +244,70 @@ class distribution_base> { float>::type; Type res; if constexpr (std::is_integral::value) { - FpType res_fp = - engine.generate_single(static_cast(a_), static_cast(b_)); - res_fp = sycl::floor(res_fp); - res = static_cast(res_fp); - return res; + if constexpr (std::is_same_v || + std::is_same_v) { + FpType res_fp = + engine.generate_single(static_cast(a_), static_cast(b_)); + res_fp = sycl::floor(res_fp); + res = static_cast(res_fp); + return res; + } + else { + // Lemire's sample rejection method to exclude bias for 
uniform numbers + // https://arxiv.org/abs/1805.10941 + + constexpr std::uint64_t uint_max64 = std::numeric_limits::max(); + constexpr std::uint64_t uint_max32 = std::numeric_limits::max(); + + std::uint64_t range = b_ - a_; + std::uint64_t threshold = (uint_max64 - range) % range; + + if (range <= uint_max32) { + FpType res_fp = + engine.generate_single(static_cast(a_), static_cast(b_)); + res_fp = sycl::floor(res_fp); + res = static_cast(res_fp); + return res; + } + + std::uint32_t res_1, res_2; + std::uint64_t res_64, leftover; + + generate_leftover( + range, [&engine]() { return engine.generate_single(); }, res_64, leftover); + + if (range == uint_max64) + return res_64; + + while (leftover < threshold) { + generate_leftover( + range, [&engine]() { return engine.generate_single(); }, res_64, leftover); + } + + res = a_ + umul_hi_64(res_64, range); + + return res; + } } else { res = engine.generate_single(a_, b_); if constexpr (std::is_same::value) { +#ifndef __HIPSYCL__ res = sycl::fmax(res, a_); res = sycl::fmin(res, b_); +#else + // a workaround for hipSYCL (AdaptiveCpp) + if constexpr (EngineType::vec_size == 1) { + res = sycl::fmax(res, a_); + res = sycl::fmin(res, b_); + } + else { + for (int i = 0; i < EngineType::vec_size; i++) { + res[i] = sycl::fmax(res[i], a_); + res[i] = sycl::fmin(res[i], b_); + } + } +#endif } } diff --git a/include/oneapi/mkl/rng/device/detail/vm_wrappers.hpp b/include/oneapi/mkl/rng/device/detail/vm_wrappers.hpp index ec070c92c..850945a4c 100644 --- a/include/oneapi/mkl/rng/device/detail/vm_wrappers.hpp +++ b/include/oneapi/mkl/rng/device/detail/vm_wrappers.hpp @@ -56,6 +56,20 @@ static inline DataType ln_wrapper(DataType a) { return sycl::log(a); } +template +static inline DataType pow_wrapper(DataType a, DataType b) { + return sycl::pow(a, b); +} + +template +static inline DataType powr_wrapper(DataType a, DataType b) { + return sycl::powr(a, b); +} + +template +static inline DataType exp_wrapper(DataType a) { + return sycl::exp(a); +} } // namespace oneapi::mkl::rng::device::detail #endif // _MKL_RNG_DEVICE_VM_WRAPPERS_HPP_ diff --git a/include/oneapi/mkl/rng/device/distributions.hpp b/include/oneapi/mkl/rng/device/distributions.hpp index 21739f7f2..121e81aa3 100644 --- a/include/oneapi/mkl/rng/device/distributions.hpp +++ b/include/oneapi/mkl/rng/device/distributions.hpp @@ -62,7 +62,9 @@ class uniform : detail::distribution_base> { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value || + std::is_same::value, "oneMKL: rng/uniform: type is not supported"); using method_type = Method; @@ -71,12 +73,14 @@ class uniform : detail::distribution_base> { uniform() : detail::distribution_base>( - static_cast(0.0), + Type(0.0), std::is_integral::value - ? (std::is_same::value - ? (1 << 23) - : (std::numeric_limits::max)()) - : static_cast(1.0)) {} + ? ((std::is_same_v || std::is_same_v) + ? (std::numeric_limits::max)() + : (std::is_same::value + ? 
(1 << 23) + : (std::numeric_limits::max)())) + : Type(1.0)) {} explicit uniform(Type a, Type b) : detail::distribution_base>(a, b) {} explicit uniform(const param_type& pt) @@ -144,8 +148,7 @@ class gaussian : detail::distribution_base> { using param_type = typename detail::distribution_base>::param_type; gaussian() - : detail::distribution_base>(static_cast(0.0), - static_cast(1.0)) {} + : detail::distribution_base>(RealType(0.0), RealType(1.0)) {} explicit gaussian(RealType mean, RealType stddev) : detail::distribution_base>(mean, stddev) {} @@ -208,11 +211,10 @@ class lognormal : detail::distribution_base> { lognormal() : detail::distribution_base>( - static_cast(0.0), static_cast(1.0), - static_cast(0.0), static_cast(1.0)) {} + RealType(0.0), RealType(1.0), RealType(0.0), RealType(1.0)) {} - explicit lognormal(RealType m, RealType s, RealType displ = static_cast(0.0), - RealType scale = static_cast(1.0)) + explicit lognormal(RealType m, RealType s, RealType displ = RealType(0.0), + RealType scale = RealType(1.0)) : detail::distribution_base>(m, s, displ, scale) {} explicit lognormal(const param_type& pt) : detail::distribution_base>(pt.m_, pt.s_, pt.displ_, @@ -250,6 +252,157 @@ class lognormal : detail::distribution_base> { friend typename Distr::result_type generate_single(Distr& distr, Engine& engine); }; +// Class template oneapi::mkl::rng::device::beta +// +// Represents continuous beta random number distribution +// +// Supported types: +// float +// double +// +// Supported methods: +// oneapi::mkl::rng::device::beta_method::cja +// oneapi::mkl::rng::device::beta_method::cja_accurate +// +// Input arguments: +// p - shape. 1.0 by default +// q - shape. 0.0 by default +// a - displacement. 1.0 by default +// b - scalefactor. 1.0 by default +// +template +class beta : detail::distribution_base> { +public: + static_assert(std::is_same::value || + std::is_same::value, + "oneMKL: rng/beta: method is incorrect"); + + static_assert(std::is_same::value || std::is_same::value, + "oneMKL: rng/beta: type is not supported"); + + using method_type = Method; + using result_type = RealType; + using param_type = typename detail::distribution_base>::param_type; + + beta() + : detail::distribution_base>(RealType(1.0), RealType(1.0), + RealType(0.0), RealType(1.0)) {} + + explicit beta(RealType p, RealType q, RealType a, RealType b) + : detail::distribution_base>(p, q, a, b) {} + + explicit beta(const param_type& pt) + : detail::distribution_base>(pt.p_, pt.q_, pt.a_, pt.b_) {} + + RealType p() const { + return detail::distribution_base>::p(); + } + + RealType q() const { + return detail::distribution_base>::q(); + } + + RealType a() const { + return detail::distribution_base>::a(); + } + + RealType b() const { + return detail::distribution_base>::b(); + } + + param_type param() const { + return detail::distribution_base>::param(); + } + + std::size_t count_rejected_numbers() const { + return detail::distribution_base>::count_rejected_numbers(); + } + + void param(const param_type& pt) { + detail::distribution_base>::param(pt); + } + + template + friend auto generate(Distr& distr, Engine& engine) -> + typename std::conditional>::type; + template + friend typename Distr::result_type generate_single(Distr& distr, Engine& engine); +}; + +// Class template oneapi::mkl::rng::device::gamma +// +// Represents continuous gamma random number distribution +// +// Supported types: +// float +// double +// +// Supported methods: +// oneapi::mkl::rng::device::gamma_method::marsaglia +// 
oneapi::mkl::rng::device::gamma_method::marsaglia_accurate +// +// Input arguments: +// alpha - shape. 1.0 by default +// a - displacement. 0.0 by default +// beta - scalefactor. 1.0 by default +// +template +class gamma : detail::distribution_base> { +public: + static_assert(std::is_same::value || + std::is_same::value, + "oneMKL: rng/gamma: method is incorrect"); + + static_assert(std::is_same::value || std::is_same::value, + "oneMKL: rng/gamma: type is not supported"); + + using method_type = Method; + using result_type = RealType; + using param_type = typename detail::distribution_base>::param_type; + + gamma() + : detail::distribution_base>(RealType(1.0), RealType(0.0), + RealType(1.0)) {} + + explicit gamma(RealType alpha, RealType a, RealType beta) + : detail::distribution_base>(alpha, a, beta) {} + + explicit gamma(const param_type& pt) + : detail::distribution_base>(pt.alpha_, pt.a_, pt.beta_) {} + + RealType alpha() const { + return detail::distribution_base>::alpha(); + } + + RealType a() const { + return detail::distribution_base>::a(); + } + + RealType beta() const { + return detail::distribution_base>::beta(); + } + + std::size_t count_rejected_numbers() const { + return detail::distribution_base>::count_rejected_numbers(); + } + + param_type param() const { + return detail::distribution_base>::param(); + } + + void param(const param_type& pt) { + detail::distribution_base>::param(pt); + } + + template + friend auto generate(Distr& distr, Engine& engine) -> + typename std::conditional>::type; + template + friend typename Distr::result_type generate_single(Distr& distr, Engine& engine); +}; + // Class template oneapi::mkl::rng::device::uniform_bits // // Represents discrete uniform bits random number distribution @@ -334,8 +487,8 @@ class exponential : detail::distribution_base> { typename detail::distribution_base>::param_type; exponential() - : detail::distribution_base>( - static_cast(0.0), static_cast(1.0)) {} + : detail::distribution_base>(RealType(0.0), + RealType(1.0)) {} explicit exponential(RealType a, RealType beta) : detail::distribution_base>(a, beta) {} @@ -442,7 +595,11 @@ class bernoulli : detail::distribution_base> { "oneMKL: rng/bernoulli: method is incorrect"); static_assert(std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value, "oneMKL: rng/bernoulli: type is not supported"); using method_type = Method; diff --git a/include/oneapi/mkl/rng/device/types.hpp b/include/oneapi/mkl/rng/device/types.hpp index e5f74e25b..6f87917f8 100644 --- a/include/oneapi/mkl/rng/device/types.hpp +++ b/include/oneapi/mkl/rng/device/types.hpp @@ -57,6 +57,18 @@ struct icdf {}; using by_default = icdf; } // namespace bernoulli_method +namespace beta_method { +struct cja {}; +struct cja_accurate {}; +using by_default = cja; +} // namespace beta_method + +namespace gamma_method { +struct marsaglia {}; +struct marsaglia_accurate {}; +using by_default = marsaglia; +} // namespace gamma_method + } // namespace oneapi::mkl::rng::device #endif // _MKL_RNG_DEVICE_TYPES_HPP_ diff --git a/tests/unit_tests/rng/device/include/rng_device_test_common.hpp b/tests/unit_tests/rng/device/include/rng_device_test_common.hpp index 6b014f0ec..33533255e 100644 --- a/tests/unit_tests/rng/device/include/rng_device_test_common.hpp +++ b/tests/unit_tests/rng/device/include/rng_device_test_common.hpp @@ -191,8 +191,8 @@ struct statistics_device> { template struct statistics_device> { template - 
bool check(const std::vector& r, - const oneapi::mkl::rng::device::uniform& distr) { + bool check(const std::vector& r, + const oneapi::mkl::rng::device::uniform& distr) { double tM, tD, tQ; double a = distr.a(); double b = distr.b(); @@ -210,8 +210,46 @@ struct statistics_device template struct statistics_device> { template - bool check(const std::vector& r, - const oneapi::mkl::rng::device::uniform& distr) { + bool check(const std::vector& r, + const oneapi::mkl::rng::device::uniform& distr) { + double tM, tD, tQ; + double a = distr.a(); + double b = distr.b(); + + // Theoretical moments + tM = (a + b - 1.0) / 2.0; + tD = ((b - a) * (b - a) - 1.0) / 12.0; + tQ = (((b - a) * (b - a)) * ((1.0 / 80.0) * (b - a) * (b - a) - (1.0 / 24.0))) + + (7.0 / 240.0); + + return compare_moments(r, tM, tD, tQ); + } +}; + +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::mkl::rng::device::uniform& distr) { + double tM, tD, tQ; + double a = distr.a(); + double b = distr.b(); + + // Theoretical moments + tM = (a + b - 1.0) / 2.0; + tD = ((b - a) * (b - a) - 1.0) / 12.0; + tQ = (((b - a) * (b - a)) * ((1.0 / 80.0) * (b - a) * (b - a) - (1.0 / 24.0))) + + (7.0 / 240.0); + + return compare_moments(r, tM, tD, tQ); + } +}; + +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::mkl::rng::device::uniform& distr) { double tM, tD, tQ; double a = distr.a(); double b = distr.b(); @@ -315,6 +353,52 @@ struct statistics_device> { } }; +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::mkl::rng::device::beta& distr) { + double tM, tD, tQ; + double b, c, d, e, e2, b2, sum_pq; + Fp p = distr.p(); + Fp q = distr.q(); + Fp a = distr.a(); + Fp beta = distr.b(); + + b2 = beta * beta; + sum_pq = p + q; + b = (p + 1.0) / (sum_pq + 1.0); + c = (p + 2.0) / (sum_pq + 2.0); + d = (p + 3.0) / (sum_pq + 3.0); + e = p / sum_pq; + e2 = e * e; + + tM = a + e * beta; + tD = b2 * p * q / (sum_pq * sum_pq * (sum_pq + 1.0)); + tQ = b2 * b2 * (e * b * c * d - 4.0 * e2 * b * c + 6.0 * e2 * e * b - 3.0 * e2 * e2); + + return compare_moments(r, tM, tD, tQ); + } +}; + +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::mkl::rng::device::gamma& distr) { + double tM, tD, tQ; + Fp a = distr.a(); + Fp alpha = distr.alpha(); + Fp beta = distr.beta(); + + tM = a + beta * alpha; + tD = beta * beta * alpha; + tQ = beta * beta * beta * beta * 3 * alpha * (alpha + 2); + + return compare_moments(r, tM, tD, tQ); + } +}; + template struct statistics_device> { template diff --git a/tests/unit_tests/rng/device/moments/moments.cpp b/tests/unit_tests/rng/device/moments/moments.cpp index 36ce38ee8..3ae45d657 100644 --- a/tests/unit_tests/rng/device/moments/moments.cpp +++ b/tests/unit_tests/rng/device/moments/moments.cpp @@ -109,6 +109,52 @@ TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, UnsignedIntegerPrecision) { EXPECT_TRUEORSKIP((test3(GetParam()))); } +TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, Integer64Precision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int64_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int64_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + 
moments_test, + oneapi::mkl::rng::device::uniform< + std::int64_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, UnsignedInteger64Precision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint64_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint64_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint64_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + TEST_P(Philox4x32x10UniformAccDeviceMomentsTests, RealSinglePrecision) { rng_device_test, oneapi::mkl::rng::device::uniform< @@ -189,6 +235,52 @@ TEST_P(Philox4x32x10UniformAccDeviceMomentsTests, UnsignedIntegerPrecision) { EXPECT_TRUEORSKIP((test3(GetParam()))); } +TEST_P(Philox4x32x10UniformAccDeviceMomentsTests, Integer64Precision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int64_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int64_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int64_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10UniformAccDeviceMomentsTests, UnsignedInteger64Precision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint64_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint64_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint64_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + INSTANTIATE_TEST_SUITE_P(Philox4x32x10UniformStdDeviceMomentsTestsSuite, Philox4x32x10UniformStdDeviceMomentsTests, ::testing::ValuesIn(devices), ::DeviceNamePrint()); @@ -961,6 +1053,190 @@ INSTANTIATE_TEST_SUITE_P(Philox4x32x10ExponentialIcdfAccDeviceMomentsTestsSuite, Philox4x32x10ExponentialIcdfAccDeviceMomentsTests, ::testing::ValuesIn(devices), ::DeviceNamePrint()); +class Philox4x32x10BetaCjaDeviceMomentsTests : public ::testing::TestWithParam {}; + +class Philox4x32x10BetaCjaAccDeviceMomentsTests : public ::testing::TestWithParam {}; + +TEST_P(Philox4x32x10BetaCjaDeviceMomentsTests, RealSinglePrecision) { + rng_device_test, + oneapi::mkl::rng::device::beta>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::beta>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::beta>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10BetaCjaDeviceMomentsTests, 
RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test, + oneapi::mkl::rng::device::beta>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::beta>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::beta>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10BetaCjaAccDeviceMomentsTests, RealSinglePrecision) { + rng_device_test, + oneapi::mkl::rng::device::beta>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::beta>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::beta>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10BetaCjaAccDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test, + oneapi::mkl::rng::device::beta< + double, oneapi::mkl::rng::device::beta_method::cja_accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::beta< + double, oneapi::mkl::rng::device::beta_method::cja_accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::beta< + double, oneapi::mkl::rng::device::beta_method::cja_accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10BetaCjaDeviceMomentsTestsSuite, + Philox4x32x10BetaCjaDeviceMomentsTests, ::testing::ValuesIn(devices), + ::DeviceNamePrint()); + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10BetaCjaAccDeviceMomentsTestsSuite, + Philox4x32x10BetaCjaAccDeviceMomentsTests, ::testing::ValuesIn(devices), + ::DeviceNamePrint()); + +class Philox4x32x10GammaMarsagliaDeviceMomentsTests + : public ::testing::TestWithParam {}; + +class Philox4x32x10GammaMarsagliaAccDeviceMomentsTests + : public ::testing::TestWithParam {}; + +TEST_P(Philox4x32x10GammaMarsagliaDeviceMomentsTests, RealSinglePrecision) { + rng_device_test, + oneapi::mkl::rng::device::gamma>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::gamma>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::gamma>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10GammaMarsagliaDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test, + oneapi::mkl::rng::device::gamma>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::gamma>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::gamma>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10GammaMarsagliaAccDeviceMomentsTests, RealSinglePrecision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::gamma< + float, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::gamma< + float, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::gamma< + float, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10GammaMarsagliaAccDeviceMomentsTests, RealDoublePrecision) { + 
CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test< + moments_test, + oneapi::mkl::rng::device::gamma< + double, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::gamma< + double, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::gamma< + double, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10GammaMarsagliaDeviceMomentsTestsSuite, + Philox4x32x10GammaMarsagliaDeviceMomentsTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10GammaMarsagliaAccDeviceMomentsTestsSuite, + Philox4x32x10GammaMarsagliaAccDeviceMomentsTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + class Philox4x32x10PoissonDevroyeDeviceMomentsTests : public ::testing::TestWithParam {}; @@ -983,19 +1259,22 @@ TEST_P(Philox4x32x10PoissonDevroyeDeviceMomentsTests, IntegerPrecision) { } TEST_P(Philox4x32x10PoissonDevroyeDeviceMomentsTests, UnsignedIntegerPrecision) { - rng_device_test, - oneapi::mkl::rng::device::poisson< - uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::poisson< + std::uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::poisson< - uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::poisson< + std::uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::poisson< - uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::poisson< + std::uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); } @@ -1008,37 +1287,127 @@ class Philox4x32x10BernoulliIcdfDeviceMomentsTests : public ::testing::TestWithParam {}; TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, IntegerPrecision) { - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); } TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, UnsignedIntegerPrecision) { - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, 
+ oneapi::mkl::rng::device::bernoulli< + std::uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, Integer8Precision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::int8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::int8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::int8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, UnsignedInteger8Precision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::uint8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::uint8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::uint8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, Integer16Precision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::int16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::int16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::int16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, UnsignedInteger16Precision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::uint16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::uint16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::uint16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); }
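
For reference, gauss_BM2_for_Marsaglia() in gamma_impl.hpp is a Box-Muller transform: it turns two of the four uniforms into a pair of standard Gaussians, saving the second value of the pair for the next draw. Below is a minimal host-side equivalent under those assumptions, with std:: functions standing in for the device sqrt/ln/sincospi wrappers; the function name is illustrative and is not part of the patch.

// Illustrative host-side equivalent of gauss_BM2_for_Marsaglia() (not part of
// the patch): Box-Muller transform mapping two independent uniforms on (0, 1]
// to a pair of standard Gaussian deviates.
#include <cmath>
#include <utility>

template <typename T>
std::pair<T, T> box_muller_pair(T u1, T u2) {
    const T pi = std::acos(T(-1));
    const T radius = std::sqrt(T(-2) * std::log(u1)); // sqrt(-2 ln u1)
    const T angle = T(2) * pi * u2;                    // 2*pi*u2; sincospi(2*u2) on device
    return { radius * std::sin(angle), radius * std::cos(angle) };
}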
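
The 64-bit integer paths added to uniform_impl.hpp cite Lemire's rejection method (https://arxiv.org/abs/1805.10941). The sketch below is the textbook form of that rejection step, standalone and host-side: unsigned __int128 replaces the patch's umul_hi_64() helper, std::mt19937_64 replaces the device engines, and bounded_uint64 is an illustrative name rather than an oneMKL API.

// Map full-range 64-bit draws onto [0, range) without modulo bias (range > 0).
#include <cstdint>
#include <random>

std::uint64_t bounded_uint64(std::mt19937_64& gen, std::uint64_t range) {
    const std::uint64_t threshold = (std::uint64_t(0) - range) % range; // 2^64 mod range
    while (true) {
        const std::uint64_t x = gen();
        const unsigned __int128 m = static_cast<unsigned __int128>(x) * range;
        // low 64 bits decide acceptance, high 64 bits are the scaled result
        if (static_cast<std::uint64_t>(m) >= threshold)
            return static_cast<std::uint64_t>(m >> 64);
    }
}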
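
Finally, an illustrative usage sketch (not part of the patch) of the new device-side beta and gamma distributions inside a SYCL kernel. It assumes the existing oneMKL device RNG entry points: the oneapi/mkl/rng/device.hpp umbrella header, the philox4x32x10 engine, and the free generate() function; the seed and distribution parameters are arbitrary.

#include <cstddef>
#include <vector>
#include <sycl/sycl.hpp>
#include "oneapi/mkl/rng/device.hpp"

namespace rng_dev = oneapi::mkl::rng::device;

int main() {
    constexpr std::size_t n = 1024;
    sycl::queue queue;
    std::vector<float> r_beta(n), r_gamma(n);
    {
        sycl::buffer<float, 1> buf_beta(r_beta.data(), sycl::range<1>(n));
        sycl::buffer<float, 1> buf_gamma(r_gamma.data(), sycl::range<1>(n));
        queue.submit([&](sycl::handler& cgh) {
            auto acc_beta = buf_beta.get_access<sycl::access::mode::write>(cgh);
            auto acc_gamma = buf_gamma.get_access<sycl::access::mode::write>(cgh);
            cgh.parallel_for(sycl::range<1>(n), [=](sycl::item<1> item) {
                std::size_t id = item.get_id(0);
                // one engine per work-item, offset by the work-item id
                rng_dev::philox4x32x10<1> engine(42, id);
                // beta(p, q, a, b): shapes p and q, displacement a, scale factor b
                rng_dev::beta<float, rng_dev::beta_method::cja> beta_distr(2.0f, 5.0f, 0.0f, 1.0f);
                // gamma(alpha, a, beta): shape alpha, displacement a, scale factor beta
                rng_dev::gamma<float, rng_dev::gamma_method::marsaglia> gamma_distr(3.0f, 0.0f, 1.0f);
                acc_beta[id] = rng_dev::generate(beta_distr, engine);
                acc_gamma[id] = rng_dev::generate(gamma_distr, engine);
            });
        });
    }
    // r_beta / r_gamma now hold Beta(2, 5) and Gamma(3, 1) samples on the host
    return 0;
}

With vec_size 1 each work-item draws scalars; with a wider vec_size, generate() returns a sycl::vec<float, vec_size>. The cja and marsaglia method tags correspond to the by_default aliases added in types.hpp.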