From 731bfb1b6be64eb668c2b58b34739ada2815433f Mon Sep 17 00:00:00 2001 From: "Fedorov, Andrey" Date: Wed, 25 Sep 2024 05:47:43 -0700 Subject: [PATCH 01/14] modified headers --- .../mkl/rng/device/detail/beta_impl.hpp | 464 ++++++++++++++++++ .../rng/device/detail/distribution_base.hpp | 8 + .../rng/device/detail/exponential_impl.hpp | 14 +- .../mkl/rng/device/detail/gamma_impl.hpp | 285 +++++++++++ .../mkl/rng/device/detail/mcg31m1_impl.hpp | 4 +- .../mkl/rng/device/detail/mcg59_impl.hpp | 6 +- .../mkl/rng/device/detail/uniform_impl.hpp | 190 ++++++- .../mkl/rng/device/detail/vm_wrappers.hpp | 14 + .../oneapi/mkl/rng/device/distributions.hpp | 187 ++++++- include/oneapi/mkl/rng/device/types.hpp | 12 + 10 files changed, 1142 insertions(+), 42 deletions(-) create mode 100755 include/oneapi/mkl/rng/device/detail/beta_impl.hpp create mode 100755 include/oneapi/mkl/rng/device/detail/gamma_impl.hpp diff --git a/include/oneapi/mkl/rng/device/detail/beta_impl.hpp b/include/oneapi/mkl/rng/device/detail/beta_impl.hpp new file mode 100755 index 000000000..9e64c5cde --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/beta_impl.hpp @@ -0,0 +1,464 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_BETA_IMPL_HPP_ +#define _MKL_RNG_DEVICE_BETA_IMPL_HPP_ + +#include "vm_wrappers.hpp" + +namespace oneapi::mkl::rng::device::detail { + +enum class beta_algorithm { Johnk = 0, Atkinson1, Atkinson2, Atkinson3, Cheng, p1, q1, p1q1 }; + +// log(4)=1.3862944.. +template +inline DataType log4() { + if constexpr (std::is_same_v) + return 0x1.62e42fefa39efp+0; + else + return 0x1.62e43p+0f; +} + +// K=0.85225521765372429631847 +template +inline DataType beta_k() { + if constexpr (std::is_same_v) + return 0x1.b45acbbf56123p-1; + else + return 0x1.b45accp-1f; +} + +// C=-0.956240971340815081432202 +template +inline DataType beta_c() { + if constexpr (std::is_same_v) + return -0x1.e9986aa60216p-1; + else + return -0x1.e9986ap-1f; +} + +template +class distribution_base> { +public: + struct param_type { + param_type(RealType p, RealType q, RealType a, RealType b) : p_(p), q_(q), a_(a), b_(b) {} + RealType p_; + RealType q_; + RealType a_; + RealType b_; + }; + + distribution_base(RealType p, RealType q, RealType a, RealType b) + : p_(p), + q_(q), + a_(a), + b_(b), + count_(0) { + set_algorithm(); +#ifndef __SYCL_DEVICE_ONLY__ + if (p <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "beta", "p <= 0"); + } + else if (q <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "beta", "q <= 0"); + } + else if (b <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "beta", "b <= 0"); + } +#endif + } + + RealType p() const { + return p_; + } + + RealType q() const { + return q_; + } + + RealType a() const { + return a_; + } + + RealType b() const { + return b_; + } + + std::size_t count_rejected_numbers() const { + return count_; + } + + param_type param() const { + return param_type(p_, q_, a_, b_); + } + + void param(const param_type& pt) { +#ifndef __SYCL_DEVICE_ONLY__ + if (pt.p_ <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "beta", "p <= 0"); + } + else if (pt.q_ <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "beta", "q <= 0"); + } + else if (pt.b_ <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "beta", "b <= 0"); + } +#endif + p_ = pt.p_; + q_ = pt.q_; + a_ = pt.a_; + b_ = pt.b_; + set_algorithm(); + } + +protected: + template + T pq_kernel(T& z) { + for (std::int32_t i = 0; i < n; i++) { + if (p_ == RealType(1.0)) { + z[i] = pow_wrapper(z[i], RealType(1) / q_); + z[i] = RealType(1.0) - z[i]; + } + if (q_ == RealType(1.0)) { + z[i] = pow_wrapper(z[i], RealType(1) / p_); + } + } + count_ = 0; + + // p1q1 + return a_ + b_ * z; + } + + template + T acc_rej_kernel(T& z, EngineType& engine) { + RealType s, t; + + RealType flKoef1, flKoef2, flKoef3, flKoef4, flKoef5, flKoef6; + RealType flDeg[2]; + + if (algorithm_ == beta_algorithm::Atkinson1) { + RealType flInv_s[2], flTmp[2]; + flTmp[0] = p_ * (RealType(1.0) - p_); + flTmp[1] = q_ * (RealType(1.0) - q_); + + flTmp[0] = sqrt_wrapper(flTmp[0]); + flTmp[1] = sqrt_wrapper(flTmp[1]); + + t = flTmp[0] / (flTmp[0] + flTmp[1]); + + s = q_ * t; + s = s / (s + p_ * (RealType(1.0) - t)); + + flInv_s[0] = RealType(1.0) / s; + flInv_s[1] = RealType(1.0) / (RealType(1.0) - s); + flDeg[0] = RealType(1.0) / p_; + flDeg[1] = RealType(1.0) / q_; + + flInv_s[0] = pow_wrapper(flInv_s[0], flDeg[0]); + flInv_s[1] = pow_wrapper(flInv_s[1], flDeg[1]); + + flKoef1 = t * flInv_s[0]; + flKoef2 = (RealType(1.0) - t) * flInv_s[1]; + flKoef3 = RealType(1.0) - q_; + flKoef4 = RealType(1.0) - p_; + flKoef5 = RealType(1.0) / (RealType(1.0) - t); + flKoef6 = RealType(1.0) / t; + } + else if (algorithm_ == beta_algorithm::Atkinson2) { + RealType flInv_s[2], flTmp; + + t = RealType(1.0) - p_; + t /= (t + q_); + + flTmp = RealType(1.0) - t; + flTmp = pow_wrapper(flTmp, q_); + s = q_ * t; + s /= (s + p_ * flTmp); + + flInv_s[0] = RealType(1.0) / s; + flInv_s[1] = RealType(1.0) / (RealType(1.0) - s); + flDeg[0] = RealType(1.0) / p_; + flDeg[1] = RealType(1.0) / q_; + + flInv_s[0] = pow_wrapper(flInv_s[0], flDeg[0]); + flInv_s[1] = pow_wrapper(flInv_s[1], flDeg[1]); + + flKoef1 = t * flInv_s[0]; + flKoef2 = (RealType(1.0) - t) * flInv_s[1]; + flKoef3 = RealType(1.0) - q_; + flKoef4 = RealType(1.0) - p_; + } + else if (algorithm_ == beta_algorithm::Atkinson3) { + RealType flInv_s[2], flTmp; + + t = RealType(1.0) - q_; + t /= (t + p_); + + flTmp = RealType(1.0) - t; + flTmp = pow_wrapper(flTmp, p_); + s = p_ * t; + s /= (s + q_ * flTmp); + + flInv_s[0] = RealType(1.0) / s; + flInv_s[1] = RealType(1.0) / (RealType(1.0) - s); + flDeg[0] = RealType(1.0) / q_; + flDeg[1] = RealType(1.0) / p_; + + flInv_s[0] = pow_wrapper(flInv_s[0], flDeg[0]); + flInv_s[1] = pow_wrapper(flInv_s[1], flDeg[1]); + + flKoef1 = t * flInv_s[0]; + flKoef2 = (RealType(1.0) - t) * flInv_s[1]; + flKoef3 = RealType(1.0) - p_; + flKoef4 = RealType(1.0) - q_; + } + else if (algorithm_ == beta_algorithm::Cheng) { + flKoef1 = p_ + q_; + flKoef2 = (flKoef1 - RealType(2.0)) / (RealType(2.0) * p_ * q_ - flKoef1); + flKoef2 = sqrt_wrapper(flKoef2); + flKoef3 = p_ + RealType(1.0) / flKoef2; + } + + RealType z1, z2; + + count_ = 0; + for (int i = 0; i < n; i++) { + while (1) { // looping until satisfied + z1 = engine.generate_single(RealType(0), RealType(1)); + z2 = engine.generate_single(RealType(0), RealType(1)); + + if (algorithm_ == beta_algorithm::Johnk) { + RealType flU1, flU2, flSum; + z1 = ln_wrapper(z1) / p_; + z2 = ln_wrapper(z2) / q_; + + z1 = exp_wrapper(z1); + z2 = exp_wrapper(z2); + + flU1 = z1; + flU2 = z2; + flSum = flU1 + flU2; + if (flSum > RealType(0.0) && flSum <= RealType(1.0)) { + z[i] = flU1 / flSum; + break; + } + } + if (algorithm_ == beta_algorithm::Atkinson1) { + RealType flU, flExp, flX, flLn; + z2 = ln_wrapper(z2); + + flU = z1; + flExp = z2; + if (flU <= s) { + flU = pow_wrapper(flU, flDeg[0]); + flX = flKoef1 * flU; + flLn = (RealType(1.0) - flX) * flKoef5; + flLn = ln_wrapper(flLn); + if (flKoef3 * flLn + flExp <= RealType(0.0)) { + z[i] = flX; + break; + } + } + else { + flU = RealType(1.0) - flU; + flU = pow_wrapper(flU, flDeg[1]); + flX = RealType(1.0) - flKoef2 * flU; + + flLn = flX * flKoef6; + flLn = ln_wrapper(flLn); + if (flKoef4 * flLn + flExp <= RealType(0.0)) { + z[i] = flX; + break; + } + } + } + if (algorithm_ == beta_algorithm::Atkinson2) { + RealType flU, flExp, flX, flLn; + z2 = ln_wrapper(z2); + + flU = z1; + flExp = z2; + if (flU <= s) { + flU = pow_wrapper(flU, flDeg[0]); + flX = flKoef1 * flU; + flLn = (RealType(1.0) - flX); + flLn = ln_wrapper(flLn); + if (flKoef3 * flLn + flExp <= RealType(0.0)) { + z[i] = flX; + break; + } + } + else { + flU = RealType(1.0) - flU; + flU = pow_wrapper(flU, flDeg[1]); + flX = RealType(1.0) - flKoef2 * flU; + + flLn = flX / t; + flLn = ln_wrapper(flLn); + if (flKoef4 * flLn + flExp <= RealType(0.0)) { + z[i] = flX; + break; + } + } + } + if (algorithm_ == beta_algorithm::Atkinson3) { + RealType flU, flExp, flX, flLn; + z2 = ln_wrapper(z2); + + flU = z1; + flExp = z2; + if (flU <= s) { + flU = pow_wrapper(flU, flDeg[0]); + flX = flKoef1 * flU; + flLn = (RealType(1.0) - flX); + flLn = ln_wrapper(flLn); + if (flKoef3 * flLn + flExp <= RealType(0.0)) { + z[i] = RealType(1.0) - flX; + break; + } + } + else { + flU = RealType(1.0) - flU; + flU = pow_wrapper(flU, flDeg[1]); + flX = RealType(1.0) - flKoef2 * flU; + + flLn = flX / t; + flLn = ln_wrapper(flLn); + if (flKoef4 * flLn + flExp <= RealType(0.0)) { + z[i] = RealType(1.0) - flX; + break; + } + } + } + if (algorithm_ == beta_algorithm::Cheng) { + RealType flU1, flU2, flV, flW, flInv; + RealType flTmp[2]; + flU1 = z1; + flU2 = z2; + + flV = flU1 / (RealType(1.0) - flU1); + + flV = ln_wrapper(flV); + + flV = flKoef2 * flV; + + flW = flV; + flW = exp_wrapper(flW); + flW = p_ * flW; + flInv = RealType(1.0) / (q_ + flW); + flTmp[0] = flKoef1 * flInv; + flTmp[1] = flU1 * flU1 * flU2; + for (int i = 0; i < 2; i++) { + flTmp[i] = ln_wrapper(flTmp[i]); + } + + if (flKoef1 * flTmp[0] + flKoef3 * flV - log4() >= flTmp[1]) { + z[i] = flW * flInv; + break; + } + } + ++count_; + } + } + return a_ + b_ * z; + } + + template + auto generate(EngineType& engine) -> + typename std::conditional>::type { + sycl::vec res{}; + if (algorithm_ == beta_algorithm::p1 || algorithm_ == beta_algorithm::q1 || + algorithm_ == beta_algorithm::p1q1) { + res = engine.generate(RealType(0), RealType(1)); + res = pq_kernel(res); + } + else { + res = acc_rej_kernel(res, engine); + } + if constexpr (std::is_same_v) { + if (res < a_) + res = a_; + if (res > a_ + b_) + res = a_ + b_; + } + return res; + } + + template + RealType generate_single(EngineType& engine) { + RealType res{}; + sycl::vec z{ res }; + if (algorithm_ == beta_algorithm::p1 || algorithm_ == beta_algorithm::q1 || + algorithm_ == beta_algorithm::p1q1) { + z[0] = engine.generate_single(RealType(0), RealType(1)); + res = pq_kernel<1>(z); + } + else { + res = acc_rej_kernel<1>(z, engine); + } + if constexpr (std::is_same_v) { + if (res < a_) + res = a_; + if (res > a_ + b_) + res = a_ + b_; + } + return res; + } + + void set_algorithm() { + if (p_ < RealType(1.0) && q_ < RealType(1.0)) { + if (q_ + beta_k() * p_ * p_ + beta_c() <= RealType(0.0)) { + algorithm_ = beta_algorithm::Johnk; + } + else { + algorithm_ = beta_algorithm::Atkinson1; + } + } + else if (p_ < RealType(1.0) && q_ > RealType(1.0)) { + algorithm_ = beta_algorithm::Atkinson2; + } + else if (p_ > RealType(1.0) && q_ < RealType(1.0)) { + algorithm_ = beta_algorithm::Atkinson3; + } + else if (p_ > RealType(1.0) && q_ > RealType(1.0)) { + algorithm_ = beta_algorithm::Cheng; + } + else if (p_ == RealType(1.0) && q_ != RealType(1.0)) { + algorithm_ = beta_algorithm::p1; + } + else if (q_ == RealType(1.0) && p_ != RealType(1.0)) { + algorithm_ = beta_algorithm::q1; + } + else { + algorithm_ = beta_algorithm::p1q1; + } + } + + RealType p_; + RealType q_; + RealType a_; + RealType b_; + std::size_t count_; + beta_algorithm algorithm_; +}; + +} // namespace oneapi::mkl::rng::device::detail + +#endif // _MKL_RNG_DEVICE_BETA_IMPL_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/distribution_base.hpp b/include/oneapi/mkl/rng/device/detail/distribution_base.hpp index e728a564c..575ea27f7 100644 --- a/include/oneapi/mkl/rng/device/detail/distribution_base.hpp +++ b/include/oneapi/mkl/rng/device/detail/distribution_base.hpp @@ -53,6 +53,12 @@ class bits; template class exponential; +template +class beta; + +template +class gamma; + template class poisson; @@ -69,5 +75,7 @@ class bernoulli; #include "oneapi/mkl/rng/device/detail/exponential_impl.hpp" #include "oneapi/mkl/rng/device/detail/poisson_impl.hpp" #include "oneapi/mkl/rng/device/detail/bernoulli_impl.hpp" +#include "oneapi/mkl/rng/device/detail/beta_impl.hpp" +#include "oneapi/mkl/rng/device/detail/gamma_impl.hpp" #endif // _MKL_RNG_DISTRIBUTION_BASE_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp b/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp index cf712f0e5..2550dba45 100644 --- a/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp @@ -68,10 +68,7 @@ class distribution_base> auto generate(EngineType& engine) -> typename std::conditional>::type { - using OutType = typename std::conditional>::type; - - OutType res = engine.generate(RealType(0), RealType(1)); + auto res = engine.generate(RealType(0), RealType(1)); if constexpr (EngineType::vec_size == 1) { res = ln_wrapper(res); } @@ -82,7 +79,7 @@ class distribution_base> } res = a_ - res * beta_; if constexpr (std::is_same::value) { - res = sycl::fmax(res, OutType{ a_ }); + res = sycl::fmax(res, a_); } return res; } @@ -105,6 +102,13 @@ class distribution_base> oneapi::mkl::rng::device::poisson>; friend class distribution_base< oneapi::mkl::rng::device::poisson>; + friend class distribution_base>; + friend class distribution_base< + oneapi::mkl::rng::device::gamma>; + friend class distribution_base< + oneapi::mkl::rng::device::gamma>; + friend class distribution_base< + oneapi::mkl::rng::device::gamma>; }; } // namespace oneapi::mkl::rng::device::detail diff --git a/include/oneapi/mkl/rng/device/detail/gamma_impl.hpp b/include/oneapi/mkl/rng/device/detail/gamma_impl.hpp new file mode 100755 index 000000000..17e9ef28b --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/gamma_impl.hpp @@ -0,0 +1,285 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_GAMMA_IMPL_HPP_ +#define _MKL_RNG_DEVICE_GAMMA_IMPL_HPP_ + +#include "vm_wrappers.hpp" + +namespace oneapi::mkl::rng::device::detail { + +enum class gamma_algorithm { Exponential = 0, Vaduva, EPD_Transform, Marsaglia }; + +// 1/3 +template +inline DataType gamma_c1() { + if constexpr (std::is_same_v) + return 0x1.5555555555555p-2; + else + return 0x1.555556p-2f; +} + +// 0.0331 +template +inline DataType gamma_c2() { + if constexpr (std::is_same_v) + return 0x1.0f27bb2fec56dp-5; + else + return 0x1.0f27bcp-5f; +} + +// 0.6 +template +inline DataType gamma_c06() { + if constexpr (std::is_same_v) + return 0x1.3333333333333p-1; + else + return 0x1.333334p-1f; +} + +template +class distribution_base> { +public: + struct param_type { + param_type(RealType alpha, RealType a, RealType beta) : alpha_(alpha), a_(a), beta_(beta) {} + RealType alpha_; + RealType a_; + RealType beta_; + }; + + distribution_base(RealType alpha, RealType a, RealType beta) + : alpha_(alpha), + a_(a), + beta_(beta), + count_(0) { + set_algorithm(); +#ifndef __SYCL_DEVICE_ONLY__ + if (alpha <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "gamma", "alpha <= 0"); + } + else if (beta <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "gamma", "beta <= 0"); + } +#endif + } + + RealType alpha() const { + return alpha_; + } + + RealType a() const { + return a_; + } + + RealType beta() const { + return beta_; + } + + std::size_t count_rejected_numbers() const { + return count_; + } + + param_type param() const { + return param_type(alpha_, a_, beta_); + } + + void param(const param_type& pt) { +#ifndef __SYCL_DEVICE_ONLY__ + if (pt.alpha_ <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "gamma", "alpha <= 0"); + } + else if (pt.beta_ <= RealType(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "gamma", "beta <= 0"); + } +#endif + alpha_ = pt.alpha_; + a_ = pt.a_; + beta_ = pt.beta_; + set_algorithm(); + } + +protected: + void set_algorithm() { + if (alpha_ <= RealType(1.0)) { + if (alpha_ == RealType(1.0)) { + algorithm_ = gamma_algorithm::Exponential; + } + else if (alpha_ > gamma_c06()) { + algorithm_ = gamma_algorithm::Vaduva; + } + else { + algorithm_ = gamma_algorithm::EPD_Transform; + } + } + else { + algorithm_ = gamma_algorithm::Marsaglia; + } + } + + template + inline std::pair gauss_BM2_for_Marsaglia(const sycl::vec& vec) { + T tmp, sin, cos, gauss_1, gauss_2; + tmp = ln_wrapper(vec[0]); + tmp = sqrt_wrapper(T(-2.0) * tmp); + sin = sincospi_wrapper(T(2) * vec[2], cos); + gauss_1 = (tmp * sin); + gauss_2 = (tmp * cos); + return { gauss_1, gauss_2 }; + } + + template + T acc_rej_kernel(T& z, EngineType& engine) { + RealType flC, flD; + if (algorithm_ == gamma_algorithm::Vaduva) { + flC = RealType(1.0) / alpha_; + flD = (RealType(1.0) - alpha_) * + exp_wrapper(ln_wrapper(alpha_) * alpha_ / (RealType(1.0) - alpha_)); + } + else if (algorithm_ == gamma_algorithm::EPD_Transform) { + flC = RealType(1.0) / alpha_; + flD = (RealType(1.0) - alpha_); + } + else if (algorithm_ == gamma_algorithm::Marsaglia) { + flD = alpha_ - gamma_c1(); + flC = sqrt_wrapper(RealType(1.0) / (RealType(9.0) * alpha_ - RealType(3.0))); + } + + count_ = 0; + RealType z1, z2, z3, z4; + for (int i = 0; i < n; i++) { + while (1) { // looping until satisfied + if (!flag_) { + z1 = engine.generate_single(RealType(0), RealType(1)); + z2 = engine.generate_single(RealType(0), RealType(1)); + } + + if (algorithm_ == gamma_algorithm::Vaduva) { + z1 = -ln_wrapper(z1); + z2 = -ln_wrapper(z2); + z[i] = powr_wrapper(z1, flC); + if (z1 + z2 >= z[i] + flD) { + break; + } + } + if (algorithm_ == gamma_algorithm::EPD_Transform) { + z2 = -ln_wrapper(z2); + if (z1 <= flD) { + z[i] = powr_wrapper(z1, flC); + if (z[i] <= z2) { + break; + } + } + else { + z1 = -ln_wrapper((RealType(1.0) - z1) * flC); + z[i] = powr_wrapper(flD + alpha_ * z1, flC); + if (z[i] <= z2 + z1) { + break; + } + } + } + if (algorithm_ == gamma_algorithm::Marsaglia) { + RealType local_uniform_2, local_gauss; + if (!flag_) { + z3 = engine.generate_single(RealType(0), RealType(1)); + z4 = engine.generate_single(RealType(0), RealType(1)); + auto gauss = + gauss_BM2_for_Marsaglia(sycl::vec{ z1, z2, z3, z4 }); + local_uniform_2 = z2; + local_gauss = gauss.first; + + saved_uniform_2_ = z4; + saved_gauss_ = gauss.second; + } + else { + local_uniform_2 = saved_uniform_2_; + local_gauss = saved_gauss_; + } + flag_ = !flag_; + z[i] = RealType(1.0) + flC * local_gauss; + if (z[i] > RealType(0.0)) { + z[i] = z[i] * z[i] * z[i]; + local_gauss = local_gauss * local_gauss; + if (local_uniform_2 < + RealType(1.0) - gamma_c2() * local_gauss * local_gauss) { + z[i] = flD * z[i]; + break; + } + else { + RealType local_uniform_1 = ln_wrapper(z[i]); + local_uniform_2 = ln_wrapper(local_uniform_2); + if (local_uniform_2 < + RealType(0.5) * local_gauss + + flD * (RealType(1.0) - z[i] + local_uniform_1)) { + z[i] = flD * z[i]; + break; + } + } + } + } + ++count_; + } + } + auto res = a_ + beta_ * z; + if constexpr (std::is_same_v) { + if (res < a_) + res = a_; + } + return res; + } + + template + auto generate(EngineType& engine) -> + typename std::conditional>::type { + if (algorithm_ == gamma_algorithm::Exponential) { + distribution_base> distr_exp(a_, beta_); + return distr_exp.generate(engine); + } + sycl::vec res{}; + res = acc_rej_kernel(res, engine); + + return res; + } + + template + RealType generate_single(EngineType& engine) { + if (algorithm_ == gamma_algorithm::Exponential) { + distribution_base> distr_exp(a_, beta_); + RealType z = distr_exp.generate_single(engine); + return z; + } + sycl::vec res{}; + res = acc_rej_kernel<1>(res, engine); + + return res[0]; + } + + RealType alpha_; + RealType a_; + RealType beta_; + RealType saved_gauss_; + RealType saved_uniform_2_; + bool flag_ = false; + std::size_t count_; + gamma_algorithm algorithm_; +}; + +} // namespace oneapi::mkl::rng::device::detail + +#endif // _MKL_RNG_DEVICE_GAMMA_IMPL_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp b/include/oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp index 8f1294ac2..72447bc5d 100644 --- a/include/oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp @@ -27,7 +27,7 @@ class mcg31m1; namespace detail { -template +template constexpr sycl::vec select_vector_a_mcg31m1() { if constexpr (VecSize == 1) return sycl::vec(UINT64_C(1)); @@ -56,7 +56,7 @@ constexpr sycl::vec select_vector_a_mcg31m1() { // hipSYCL (AdaptiveCpp) doesn't support constexpr sycl::vec constructor // that's why in case of hipSYCL backend sycl::vec is created as a local variable #ifndef __HIPSYCL__ -template +template struct mcg31m1_vector_a { static constexpr sycl::vec vector_a = select_vector_a_mcg31m1(); // powers of a diff --git a/include/oneapi/mkl/rng/device/detail/mcg59_impl.hpp b/include/oneapi/mkl/rng/device/detail/mcg59_impl.hpp index bc21eb607..a70bb323d 100644 --- a/include/oneapi/mkl/rng/device/detail/mcg59_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/mcg59_impl.hpp @@ -27,7 +27,7 @@ class mcg59; namespace detail { -template +template constexpr sycl::vec select_vector_a_mcg59() { if constexpr (VecSize == 1) return sycl::vec(UINT64_C(1)); @@ -57,7 +57,7 @@ constexpr sycl::vec select_vector_a_mcg59() { // hipSYCL (AdaptiveCpp) doesn't support constexpr sycl::vec constructor // that's why in case of hipSYCL backend sycl::vec is created as a local variable #ifndef __HIPSYCL__ -template +template struct mcg59_vector_a { static constexpr sycl::vec vector_a = select_vector_a_mcg59(); // powers of a @@ -165,7 +165,7 @@ class engine_base> { auto generate() -> typename std::conditional>::type { - return mcg59_impl::generate(this->state_); + return mcg59_impl::generate(this->state_).template convert(); } auto generate_bits() -> typename std::conditional + namespace oneapi::mkl::rng::device::detail { +static inline std::uint64_t umul_hi_64(const std::uint64_t a, const std::uint64_t b) { + const std::uint64_t a_lo = a & 0xFFFFFFFFULL; + const std::uint64_t a_hi = a >> 32; + const std::uint64_t b_lo = b & 0xFFFFFFFFULL; + const std::uint64_t b_hi = b >> 32; + + const std::uint64_t ab_hi = a_hi * b_hi; + const std::uint64_t ab_lo = a_lo * b_lo; + const std::uint64_t ab_md = a_hi * b_lo; + const std::uint64_t ba_md = b_hi * a_lo; + + const std::uint64_t bias = ((ab_md & 0xFFFFFFFFULL) + (ba_md & 0xFFFFFFFFULL) + (ab_lo >> 32)) >> 32; + + return ab_hi + (ab_md >> 32) + (ba_md >> 32) + bias; +} + +template +static inline void generate_leftover(std::uint64_t range, Generator generate, + std::uint64_t& res_64, std::uint64_t& leftover) { + if constexpr (std::is_same_v>) { + std::uint32_t res_1 = generate(); + std::uint32_t res_2 = generate(); + std::uint32_t res_3 = generate(); + res_64 = (static_cast(res_3) << 62) + + (static_cast(res_2) << 31) + res_1; + } + else { + std::uint32_t res_1 = generate(); + std::uint32_t res_2 = generate(); + res_64 = (static_cast(res_2) << 32) + res_1; + } + + leftover = res_64 * range; +} + template class distribution_base> { public: @@ -62,6 +99,15 @@ class distribution_base> { } protected: + template + OutType generate_single_int(EngineType& engine) { + sycl::vec res_fp; + res_fp = engine.generate(static_cast(a_), static_cast(b_)); + res_fp = sycl::floor(res_fp); + OutType res = res_fp.template convert(); + return res; + } + template auto generate(EngineType& engine) -> typename std::conditional> { float>::type; OutType res; if constexpr (std::is_integral::value) { - if constexpr (EngineType::vec_size == 1) { - FpType res_fp = engine.generate(static_cast(a_), static_cast(b_)); - res_fp = sycl::floor(res_fp); - res = static_cast(res_fp); - return res; + if constexpr (std::is_same_v || std::is_same_v) { + return generate_single_int(engine); } else { - sycl::vec res_fp; - res_fp = engine.generate(static_cast(a_), static_cast(b_)); - res_fp = sycl::floor(res_fp); - res = res_fp.template convert(); - return res; + // Lemire's sample rejection method to exclude bias for uniform numbers + // https://arxiv.org/abs/1805.10941 + + constexpr std::uint64_t uint_max64 = std::numeric_limits::max(); + constexpr std::uint64_t uint_max32 = std::numeric_limits::max(); + + std::uint64_t range = b_ - a_; + std::uint64_t threshold = (uint_max64 - range) % range; + + if (range <= uint_max32) + return generate_single_int(engine); + + if constexpr (EngineType::vec_size == 1) { + std::uint32_t res_1, res_2; + std::uint64_t res_64, leftover; + + generate_leftover(range, [&engine](){return engine.generate();}, + res_64, leftover); + + if (range == uint_max64) + return res_64; + + while (leftover < threshold) { + generate_leftover(range, [&engine](){return engine.generate();}, + res_64, leftover); + } + + res = a_ + umul_hi_64(res_64, range); + + return res; + } + else { + std::uint64_t leftover; + + sycl::vec res_1 = engine.generate(); + sycl::vec res_2 = engine.generate(); + sycl::vec res_64; + + if constexpr (std::is_same_v>) { + sycl::vec res_3 = engine.generate(); + + for (int i = 0; i < EngineType::vec_size; i++) { + res_64[i] = (static_cast(res_3[i]) << 62) + + (static_cast(res_2[i]) << 31) + res_1[i]; + } + } + else { + if constexpr (EngineType::vec_size == 3) { + res_64[0] = (static_cast(res_1[1]) << 32) + + static_cast(res_1[0]); + res_64[1] = (static_cast(res_2[0]) << 32) + + static_cast(res_1[2]); + res_64[2] = (static_cast(res_2[2]) << 32) + + static_cast(res_2[1]); + } else { + for (int i = 0; i < EngineType::vec_size / 2; i++) { + res_64[i] = (static_cast(res_1[2 * i + 1]) << 32) + + static_cast(res_1[2 * i]); + res_64[i + EngineType::vec_size / 2] = (static_cast(res_2[2 * i + 1]) << 32) + + static_cast(res_2[2 * i]); + } + } + } + + if (range == uint_max64) + return res_64.template convert(); + + for (int i = 0; i < EngineType::vec_size; i++) { + leftover = res_64[i] * range; + + while (leftover < threshold) { + generate_leftover(range, [&engine](){return engine.generate_single();}, + res_64[i], leftover); + } + + res[i] = a_ + umul_hi_64(res_64[i], range); + } + + return res; + } } } else { res = engine.generate(a_, b_); if constexpr (std::is_same::value) { - res = sycl::fmax(res, OutType{ a_ }); - res = sycl::fmin(res, OutType{ b_ }); + res = sycl::fmax(res, a_); + res = sycl::fmin(res, b_); } } @@ -105,11 +223,49 @@ class distribution_base> { float>::type; Type res; if constexpr (std::is_integral::value) { - FpType res_fp = - engine.generate_single(static_cast(a_), static_cast(b_)); - res_fp = sycl::floor(res_fp); - res = static_cast(res_fp); - return res; + if constexpr (std::is_same_v || std::is_same_v) { + FpType res_fp = + engine.generate_single(static_cast(a_), static_cast(b_)); + res_fp = sycl::floor(res_fp); + res = static_cast(res_fp); + return res; + } + else { + // Lemire's sample rejection method to exclude bias for uniform numbers + // https://arxiv.org/abs/1805.10941 + + constexpr std::uint64_t uint_max64 = std::numeric_limits::max(); + constexpr std::uint64_t uint_max32 = std::numeric_limits::max(); + + std::uint64_t range = b_ - a_; + std::uint64_t threshold = (uint_max64 - range) % range; + + if (range <= uint_max32) { + FpType res_fp = + engine.generate_single(static_cast(a_), static_cast(b_)); + res_fp = sycl::floor(res_fp); + res = static_cast(res_fp); + return res; + } + + std::uint32_t res_1, res_2; + std::uint64_t res_64, leftover; + + generate_leftover(range, [&engine](){return engine.generate_single();}, + res_64, leftover); + + if (range == uint_max64) + return res_64; + + while (leftover < threshold) { + generate_leftover(range, [&engine](){return engine.generate_single();}, + res_64, leftover); + } + + res = a_ + umul_hi_64(res_64, range); + + return res; + } } else { res = engine.generate_single(a_, b_); diff --git a/include/oneapi/mkl/rng/device/detail/vm_wrappers.hpp b/include/oneapi/mkl/rng/device/detail/vm_wrappers.hpp index ec070c92c..850945a4c 100644 --- a/include/oneapi/mkl/rng/device/detail/vm_wrappers.hpp +++ b/include/oneapi/mkl/rng/device/detail/vm_wrappers.hpp @@ -56,6 +56,20 @@ static inline DataType ln_wrapper(DataType a) { return sycl::log(a); } +template +static inline DataType pow_wrapper(DataType a, DataType b) { + return sycl::pow(a, b); +} + +template +static inline DataType powr_wrapper(DataType a, DataType b) { + return sycl::powr(a, b); +} + +template +static inline DataType exp_wrapper(DataType a) { + return sycl::exp(a); +} } // namespace oneapi::mkl::rng::device::detail #endif // _MKL_RNG_DEVICE_VM_WRAPPERS_HPP_ diff --git a/include/oneapi/mkl/rng/device/distributions.hpp b/include/oneapi/mkl/rng/device/distributions.hpp index 21739f7f2..5051a670a 100644 --- a/include/oneapi/mkl/rng/device/distributions.hpp +++ b/include/oneapi/mkl/rng/device/distributions.hpp @@ -62,7 +62,9 @@ class uniform : detail::distribution_base> { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value || + std::is_same::value, "oneMKL: rng/uniform: type is not supported"); using method_type = Method; @@ -71,12 +73,14 @@ class uniform : detail::distribution_base> { uniform() : detail::distribution_base>( - static_cast(0.0), + Type(0.0), std::is_integral::value - ? (std::is_same::value - ? (1 << 23) - : (std::numeric_limits::max)()) - : static_cast(1.0)) {} + ? ((std::is_same_v || std::is_same_v) + ? (std::numeric_limits::max)() + : (std::is_same::value + ? (1 << 23) + : (std::numeric_limits::max)())) + : Type(1.0)) {} explicit uniform(Type a, Type b) : detail::distribution_base>(a, b) {} explicit uniform(const param_type& pt) @@ -144,8 +148,7 @@ class gaussian : detail::distribution_base> { using param_type = typename detail::distribution_base>::param_type; gaussian() - : detail::distribution_base>(static_cast(0.0), - static_cast(1.0)) {} + : detail::distribution_base>(RealType(0.0), RealType(1.0)) {} explicit gaussian(RealType mean, RealType stddev) : detail::distribution_base>(mean, stddev) {} @@ -208,11 +211,10 @@ class lognormal : detail::distribution_base> { lognormal() : detail::distribution_base>( - static_cast(0.0), static_cast(1.0), - static_cast(0.0), static_cast(1.0)) {} + RealType(0.0), RealType(1.0), RealType(0.0), RealType(1.0)) {} - explicit lognormal(RealType m, RealType s, RealType displ = static_cast(0.0), - RealType scale = static_cast(1.0)) + explicit lognormal(RealType m, RealType s, RealType displ = RealType(0.0), + RealType scale = RealType(1.0)) : detail::distribution_base>(m, s, displ, scale) {} explicit lognormal(const param_type& pt) : detail::distribution_base>(pt.m_, pt.s_, pt.displ_, @@ -250,6 +252,157 @@ class lognormal : detail::distribution_base> { friend typename Distr::result_type generate_single(Distr& distr, Engine& engine); }; +// Class template oneapi::mkl::rng::device::beta +// +// Represents continuous beta random number distribution +// +// Supported types: +// float +// double +// +// Supported methods: +// oneapi::mkl::rng::device::beta_method::cja +// oneapi::mkl::rng::device::beta_method::cja_accurate +// +// Input arguments: +// p - shape. 1.0 by default +// q - shape. 0.0 by default +// a - displacement. 1.0 by default +// b - scalefactor. 1.0 by default +// +template +class beta : detail::distribution_base> { +public: + static_assert(std::is_same::value || + std::is_same::value, + "oneMKL: rng/beta: method is incorrect"); + + static_assert(std::is_same::value || std::is_same::value, + "oneMKL: rng/beta: type is not supported"); + + using method_type = Method; + using result_type = RealType; + using param_type = typename detail::distribution_base>::param_type; + + beta() + : detail::distribution_base>(RealType(1.0), RealType(1.0), + RealType(0.0), RealType(1.0)) {} + + explicit beta(RealType p, RealType q, RealType a, RealType b) + : detail::distribution_base>(p, q, a, b) {} + + explicit beta(const param_type& pt) + : detail::distribution_base>(pt.p_, pt.q_, pt.a_, pt.b_) {} + + RealType p() const { + return detail::distribution_base>::p(); + } + + RealType q() const { + return detail::distribution_base>::q(); + } + + RealType a() const { + return detail::distribution_base>::a(); + } + + RealType b() const { + return detail::distribution_base>::b(); + } + + param_type param() const { + return detail::distribution_base>::param(); + } + + std::size_t count_rejected_numbers() const { + return detail::distribution_base>::count_rejected_numbers(); + } + + void param(const param_type& pt) { + detail::distribution_base>::param(pt); + } + + template + friend auto generate(Distr& distr, Engine& engine) -> + typename std::conditional>::type; + template + friend typename Distr::result_type generate_single(Distr& distr, Engine& engine); +}; + +// Class template oneapi::mkl::rng::device::gamma +// +// Represents continuous gamma random number distribution +// +// Supported types: +// float +// double +// +// Supported methods: +// oneapi::mkl::rng::device::gamma_method::marsaglia +// oneapi::mkl::rng::device::gamma_method::marsaglia_accurate +// +// Input arguments: +// alpha - shape. 1.0 by default +// a - displacement. 0.0 by default +// beta - scalefactor. 1.0 by default +// +template +class gamma : detail::distribution_base> { +public: + static_assert(std::is_same::value || + std::is_same::value, + "oneMKL: rng/gamma: method is incorrect"); + + static_assert(std::is_same::value || std::is_same::value, + "oneMKL: rng/gamma: type is not supported"); + + using method_type = Method; + using result_type = RealType; + using param_type = typename detail::distribution_base>::param_type; + + gamma() + : detail::distribution_base>(RealType(1.0), RealType(0.0), + RealType(1.0)) {} + + explicit gamma(RealType alpha, RealType a, RealType beta) + : detail::distribution_base>(alpha, a, beta) {} + + explicit gamma(const param_type& pt) + : detail::distribution_base>(pt.alpha_, pt.a_, pt.beta_) {} + + RealType alpha() const { + return detail::distribution_base>::alpha(); + } + + RealType a() const { + return detail::distribution_base>::a(); + } + + RealType beta() const { + return detail::distribution_base>::beta(); + } + + std::size_t count_rejected_numbers() const { + return detail::distribution_base>::count_rejected_numbers(); + } + + param_type param() const { + return detail::distribution_base>::param(); + } + + void param(const param_type& pt) { + detail::distribution_base>::param(pt); + } + + template + friend auto generate(Distr& distr, Engine& engine) -> + typename std::conditional>::type; + template + friend typename Distr::result_type generate_single(Distr& distr, Engine& engine); +}; + // Class template oneapi::mkl::rng::device::uniform_bits // // Represents discrete uniform bits random number distribution @@ -334,8 +487,8 @@ class exponential : detail::distribution_base> { typename detail::distribution_base>::param_type; exponential() - : detail::distribution_base>( - static_cast(0.0), static_cast(1.0)) {} + : detail::distribution_base>(RealType(0.0), + RealType(1.0)) {} explicit exponential(RealType a, RealType beta) : detail::distribution_base>(a, beta) {} @@ -442,7 +595,11 @@ class bernoulli : detail::distribution_base> { "oneMKL: rng/bernoulli: method is incorrect"); static_assert(std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value, "oneMKL: rng/bernoulli: type is not supported"); using method_type = Method; diff --git a/include/oneapi/mkl/rng/device/types.hpp b/include/oneapi/mkl/rng/device/types.hpp index e5f74e25b..6f87917f8 100644 --- a/include/oneapi/mkl/rng/device/types.hpp +++ b/include/oneapi/mkl/rng/device/types.hpp @@ -57,6 +57,18 @@ struct icdf {}; using by_default = icdf; } // namespace bernoulli_method +namespace beta_method { +struct cja {}; +struct cja_accurate {}; +using by_default = cja; +} // namespace beta_method + +namespace gamma_method { +struct marsaglia {}; +struct marsaglia_accurate {}; +using by_default = marsaglia; +} // namespace gamma_method + } // namespace oneapi::mkl::rng::device #endif // _MKL_RNG_DEVICE_TYPES_HPP_ From 71dddf31bc57f15f0c938976c53e19ad08e4e283 Mon Sep 17 00:00:00 2001 From: "Fedorov, Andrey" Date: Wed, 25 Sep 2024 06:54:44 -0700 Subject: [PATCH 02/14] changed tests and make changes for sources --- .../mkl/rng/device/detail/beta_impl.hpp | 20 +- .../mkl/rng/device/detail/gamma_impl.hpp | 6 +- .../mkl/rng/device/detail/uniform_impl.hpp | 25 +-- include/oneapi/mkl/rng/distributions.hpp | 43 ++-- .../device/include/rng_device_test_common.hpp | 46 +++++ .../unit_tests/rng/device/moments/moments.cpp | 184 ++++++++++++++++++ 6 files changed, 270 insertions(+), 54 deletions(-) diff --git a/include/oneapi/mkl/rng/device/detail/beta_impl.hpp b/include/oneapi/mkl/rng/device/detail/beta_impl.hpp index 9e64c5cde..1336279b9 100755 --- a/include/oneapi/mkl/rng/device/detail/beta_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/beta_impl.hpp @@ -393,10 +393,12 @@ class distribution_base> { res = acc_rej_kernel(res, engine); } if constexpr (std::is_same_v) { - if (res < a_) - res = a_; - if (res > a_ + b_) - res = a_ + b_; + for(std::int32_t i = 0; i < EngineType::vec_size; i++) { + if (res[i] < a_) + res[i] = a_; + if (res[i] > a_ + b_) + res[i] = a_ + b_; + } } return res; } @@ -414,10 +416,12 @@ class distribution_base> { res = acc_rej_kernel<1>(z, engine); } if constexpr (std::is_same_v) { - if (res < a_) - res = a_; - if (res > a_ + b_) - res = a_ + b_; + for(std::int32_t i = 0; i < EngineType::vec_size; i++) { + if (res[i] < a_) + res[i] = a_; + if (res[i] > a_ + b_) + res[i] = a_ + b_; + } } return res; } diff --git a/include/oneapi/mkl/rng/device/detail/gamma_impl.hpp b/include/oneapi/mkl/rng/device/detail/gamma_impl.hpp index 17e9ef28b..fa46b2a88 100755 --- a/include/oneapi/mkl/rng/device/detail/gamma_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/gamma_impl.hpp @@ -237,8 +237,10 @@ class distribution_base> { } auto res = a_ + beta_ * z; if constexpr (std::is_same_v) { - if (res < a_) - res = a_; + for(std::int32_t i = 0; i < EngineType::vec_size; i++) { + if (res[i] < a_) + res[i] = a_; + } } return res; } diff --git a/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp b/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp index a4d66b62b..3eadb2a66 100644 --- a/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp @@ -21,6 +21,7 @@ #define _MKL_RNG_DEVICE_UNIFORM_IMPL_HPP_ #include +#include "engine_base.hpp" namespace oneapi::mkl::rng::device::detail { @@ -41,13 +42,13 @@ static inline std::uint64_t umul_hi_64(const std::uint64_t a, const std::uint64_ } template -static inline void generate_leftover(std::uint64_t range, Generator generate, +static inline void generate_leftover(std::uint64_t range, Generator generate, std::uint64_t& res_64, std::uint64_t& leftover) { if constexpr (std::is_same_v>) { std::uint32_t res_1 = generate(); std::uint32_t res_2 = generate(); std::uint32_t res_3 = generate(); - res_64 = (static_cast(res_3) << 62) + + res_64 = (static_cast(res_3) << 62) + (static_cast(res_2) << 31) + res_1; } else { @@ -125,7 +126,7 @@ class distribution_base> { else { // Lemire's sample rejection method to exclude bias for uniform numbers // https://arxiv.org/abs/1805.10941 - + constexpr std::uint64_t uint_max64 = std::numeric_limits::max(); constexpr std::uint64_t uint_max32 = std::numeric_limits::max(); @@ -139,14 +140,14 @@ class distribution_base> { std::uint32_t res_1, res_2; std::uint64_t res_64, leftover; - generate_leftover(range, [&engine](){return engine.generate();}, + generate_leftover(range, [&engine](){return engine.generate();}, res_64, leftover); if (range == uint_max64) return res_64; while (leftover < threshold) { - generate_leftover(range, [&engine](){return engine.generate();}, + generate_leftover(range, [&engine](){return engine.generate();}, res_64, leftover); } @@ -160,12 +161,12 @@ class distribution_base> { sycl::vec res_1 = engine.generate(); sycl::vec res_2 = engine.generate(); sycl::vec res_64; - + if constexpr (std::is_same_v>) { sycl::vec res_3 = engine.generate(); for (int i = 0; i < EngineType::vec_size; i++) { - res_64[i] = (static_cast(res_3[i]) << 62) + + res_64[i] = (static_cast(res_3[i]) << 62) + (static_cast(res_2[i]) << 31) + res_1[i]; } } @@ -186,7 +187,7 @@ class distribution_base> { } } } - + if (range == uint_max64) return res_64.template convert(); @@ -194,7 +195,7 @@ class distribution_base> { leftover = res_64[i] * range; while (leftover < threshold) { - generate_leftover(range, [&engine](){return engine.generate_single();}, + generate_leftover(range, [&engine](){return engine.generate_single();}, res_64[i], leftover); } @@ -233,7 +234,7 @@ class distribution_base> { else { // Lemire's sample rejection method to exclude bias for uniform numbers // https://arxiv.org/abs/1805.10941 - + constexpr std::uint64_t uint_max64 = std::numeric_limits::max(); constexpr std::uint64_t uint_max32 = std::numeric_limits::max(); @@ -251,14 +252,14 @@ class distribution_base> { std::uint32_t res_1, res_2; std::uint64_t res_64, leftover; - generate_leftover(range, [&engine](){return engine.generate_single();}, + generate_leftover(range, [&engine](){return engine.generate_single();}, res_64, leftover); if (range == uint_max64) return res_64; while (leftover < threshold) { - generate_leftover(range, [&engine](){return engine.generate_single();}, + generate_leftover(range, [&engine](){return engine.generate_single();}, res_64, leftover); } diff --git a/include/oneapi/mkl/rng/distributions.hpp b/include/oneapi/mkl/rng/distributions.hpp index 88d1e46e7..526fd4ab4 100644 --- a/include/oneapi/mkl/rng/distributions.hpp +++ b/include/oneapi/mkl/rng/distributions.hpp @@ -61,17 +61,24 @@ template class uniform { public: static_assert(std::is_same::value || - (std::is_same::value && - !std::is_same::value), + std::is_same::value, "rng uniform distribution method is incorrect"); - static_assert(std::is_same::value || std::is_same::value, + static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value, "rng uniform distribution type is not supported"); using method_type = Method; using result_type = Type; - uniform() : uniform(static_cast(0.0f), static_cast(1.0f)) {} + uniform() + : uniform(static_cast(0.0f), + std::is_integral::value + ? (std::is_same::value + ? (1 << 23) + : (std::numeric_limits::max)()) + : static_cast(1.0f)) {} explicit uniform(Type a, Type b) : a_(a), b_(b) { if (a >= b) { @@ -93,34 +100,6 @@ class uniform { Type b_; }; -template -class uniform { -public: - using method_type = Method; - using result_type = std::int32_t; - - uniform() : uniform(0, std::numeric_limits::max()) {} - - explicit uniform(std::int32_t a, std::int32_t b) : a_(a), b_(b) { - if (a >= b) { - throw oneapi::mkl::invalid_argument("rng", "uniform", - "parameters are incorrect, a >= b"); - } - } - - std::int32_t a() const { - return a_; - } - - std::int32_t b() const { - return b_; - } - -private: - std::int32_t a_; - std::int32_t b_; -}; - // Class template oneapi::mkl::rng::gaussian // // Represents continuous normal random number distribution diff --git a/tests/unit_tests/rng/device/include/rng_device_test_common.hpp b/tests/unit_tests/rng/device/include/rng_device_test_common.hpp index 6b014f0ec..69846eb4d 100644 --- a/tests/unit_tests/rng/device/include/rng_device_test_common.hpp +++ b/tests/unit_tests/rng/device/include/rng_device_test_common.hpp @@ -315,6 +315,52 @@ struct statistics_device> { } }; +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::mkl::rng::device::beta& distr) { + double tM, tD, tQ; + double b, c, d, e, e2, b2, sum_pq; + Fp p = distr.p(); + Fp q = distr.q(); + Fp a = distr.a(); + Fp beta = distr.b(); + + b2 = beta * beta; + sum_pq = p + q; + b = (p + 1.0) / (sum_pq + 1.0); + c = (p + 2.0) / (sum_pq + 2.0); + d = (p + 3.0) / (sum_pq + 3.0); + e = p / sum_pq; + e2 = e * e; + + tM = a + e * beta; + tD = b2 * p * q / (sum_pq * sum_pq * (sum_pq + 1.0)); + tQ = b2 * b2 * (e * b * c * d - 4.0 * e2 * b * c + 6.0 * e2 * e * b - 3.0 * e2 * e2); + + return compare_moments(r, tM, tD, tQ); + } +}; + +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::mkl::rng::device::gamma& distr) { + double tM, tD, tQ; + Fp a = distr.a(); + Fp alpha = distr.alpha(); + Fp beta = distr.beta(); + + tM = a + beta * alpha; + tD = beta * beta * alpha; + tQ = beta * beta * beta * beta * 3 * alpha * (alpha + 2); + + return compare_moments(r, tM, tD, tQ); + } +}; + template struct statistics_device> { template diff --git a/tests/unit_tests/rng/device/moments/moments.cpp b/tests/unit_tests/rng/device/moments/moments.cpp index 36ce38ee8..21810817a 100644 --- a/tests/unit_tests/rng/device/moments/moments.cpp +++ b/tests/unit_tests/rng/device/moments/moments.cpp @@ -961,6 +961,190 @@ INSTANTIATE_TEST_SUITE_P(Philox4x32x10ExponentialIcdfAccDeviceMomentsTestsSuite, Philox4x32x10ExponentialIcdfAccDeviceMomentsTests, ::testing::ValuesIn(devices), ::DeviceNamePrint()); +class Philox4x32x10BetaCjaDeviceMomentsTests + : public ::testing::TestWithParam {}; + +class Philox4x32x10BetaCjaAccDeviceMomentsTests + : public ::testing::TestWithParam {}; + +// implementation uses double precision for accuracy +TEST_P(Philox4x32x10BetaCjaDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test, + oneapi::mkl::rng::device::beta< + float, oneapi::mkl::rng::device::beta_method::cja>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::beta< + float, oneapi::mkl::rng::device::beta_method::cja>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::beta< + float, oneapi::mkl::rng::device::beta_method::cja>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::beta< + double, oneapi::mkl::rng::device::beta_method::cja>>> + test4; + EXPECT_TRUEORSKIP((test4(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::beta< + double, oneapi::mkl::rng::device::beta_method::cja>>> + test5; + EXPECT_TRUEORSKIP((test5(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::beta< + double, oneapi::mkl::rng::device::beta_method::cja>>> + test6; + EXPECT_TRUEORSKIP((test6(GetParam()))); +} + +// implementation uses double precision for accuracy +TEST_P(Philox4x32x10BetaCjaAccDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test< + moments_test, + oneapi::mkl::rng::device::beta< + float, oneapi::mkl::rng::device::beta_method::cja_accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::beta< + float, oneapi::mkl::rng::device::beta_method::cja_accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::beta< + float, oneapi::mkl::rng::device::beta_method::cja_accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::beta< + double, oneapi::mkl::rng::device::beta_method::cja_accurate>>> + test4; + EXPECT_TRUEORSKIP((test4(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::beta< + double, oneapi::mkl::rng::device::beta_method::cja_accurate>>> + test5; + EXPECT_TRUEORSKIP((test5(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::beta< + double, oneapi::mkl::rng::device::beta_method::cja_accurate>>> + test6; + EXPECT_TRUEORSKIP((test6(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10BetaCjaDeviceMomentsTestsSuite, + Philox4x32x10BetaCjaDeviceMomentsTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10BetaCjaAccDeviceMomentsTestsSuite, + Philox4x32x10BetaCjaAccDeviceMomentsTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + +class Philox4x32x10GammaMarsagliaDeviceMomentsTests + : public ::testing::TestWithParam {}; + +class Philox4x32x10GammaMarsagliaAccDeviceMomentsTests + : public ::testing::TestWithParam {}; + +// implementation uses double precision for accuracy +TEST_P(Philox4x32x10GammaMarsagliaDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test, + oneapi::mkl::rng::device::gamma< + float, oneapi::mkl::rng::device::gamma_method::marsaglia>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::gamma< + float, oneapi::mkl::rng::device::gamma_method::marsaglia>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::gamma< + float, oneapi::mkl::rng::device::gamma_method::marsaglia>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::gamma< + double, oneapi::mkl::rng::device::gamma_method::marsaglia>>> + test4; + EXPECT_TRUEORSKIP((test4(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::gamma< + double, oneapi::mkl::rng::device::gamma_method::marsaglia>>> + test5; + EXPECT_TRUEORSKIP((test5(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::gamma< + double, oneapi::mkl::rng::device::gamma_method::marsaglia>>> + test6; + EXPECT_TRUEORSKIP((test6(GetParam()))); +} + +// implementation uses double precision for accuracy +TEST_P(Philox4x32x10GammaMarsagliaAccDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test< + moments_test, + oneapi::mkl::rng::device::gamma< + float, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::gamma< + float, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::gamma< + float, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::gamma< + double, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> + test4; + EXPECT_TRUEORSKIP((test4(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::gamma< + double, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> + test5; + EXPECT_TRUEORSKIP((test5(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::gamma< + double, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> + test6; + EXPECT_TRUEORSKIP((test6(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10GammaMarsagliaDeviceMomentsTestsSuite, + Philox4x32x10GammaMarsagliaDeviceMomentsTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10GammaMarsagliaAccDeviceMomentsTestsSuite, + Philox4x32x10GammaMarsagliaAccDeviceMomentsTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + class Philox4x32x10PoissonDevroyeDeviceMomentsTests : public ::testing::TestWithParam {}; From ecd826e89ed472431e24030660be45416705499d Mon Sep 17 00:00:00 2001 From: "Fedorov, Andrey" Date: Wed, 25 Sep 2024 07:55:41 -0700 Subject: [PATCH 03/14] replace sycl fmin, fmax with std --- include/oneapi/mkl/rng/device/detail/uniform_impl.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp b/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp index 3eadb2a66..b6e2b5806 100644 --- a/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp @@ -209,8 +209,8 @@ class distribution_base> { else { res = engine.generate(a_, b_); if constexpr (std::is_same::value) { - res = sycl::fmax(res, a_); - res = sycl::fmin(res, b_); + res = std::fmax(res, a_); + res = std::fmin(res, b_); } } @@ -271,8 +271,8 @@ class distribution_base> { else { res = engine.generate_single(a_, b_); if constexpr (std::is_same::value) { - res = sycl::fmax(res, a_); - res = sycl::fmin(res, b_); + res = std::fmax(res, a_); + res = std::fmin(res, b_); } } From d54605e94c00960434bfa3ae86186b97426bb310 Mon Sep 17 00:00:00 2001 From: "Fedorov, Andrey" Date: Wed, 25 Sep 2024 08:02:23 -0700 Subject: [PATCH 04/14] add header --- include/oneapi/mkl/rng/device/detail/uniform_impl.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp b/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp index b6e2b5806..3f79c0c69 100644 --- a/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp @@ -21,6 +21,7 @@ #define _MKL_RNG_DEVICE_UNIFORM_IMPL_HPP_ #include +#include #include "engine_base.hpp" namespace oneapi::mkl::rng::device::detail { From 69c923bceba2b6165485d22ea7104ff73aca1806 Mon Sep 17 00:00:00 2001 From: "Fedorov, Andrey" Date: Wed, 25 Sep 2024 08:34:26 -0700 Subject: [PATCH 05/14] workaround for hipsycl --- .../mkl/rng/device/detail/uniform_impl.hpp | 36 ++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp b/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp index 3f79c0c69..632941662 100644 --- a/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp @@ -210,8 +210,22 @@ class distribution_base> { else { res = engine.generate(a_, b_); if constexpr (std::is_same::value) { - res = std::fmax(res, a_); - res = std::fmin(res, b_); +#ifndef __HIPSYCL__ + res = sycl::fmax(res, a_); + res = sycl::fmin(res, b_); +#else + // a workaround for hipSYCL (AdaptiveCpp) + if constexpr (EngineType::vec_size == 1) { + res = std::fmax(res, a_); + res = std::fmin(res, b_); + } + else{ + for (int i = 0; i < EngineType::vec_size; i++) { + res[i] = std::fmax(res[i], a_); + res[i] = std::fmin(res[i], b_); + } + } +#endif } } @@ -272,8 +286,22 @@ class distribution_base> { else { res = engine.generate_single(a_, b_); if constexpr (std::is_same::value) { - res = std::fmax(res, a_); - res = std::fmin(res, b_); +#ifndef __HIPSYCL__ + res = sycl::fmax(res, a_); + res = sycl::fmin(res, b_); +#else + // a workaround for hipSYCL (AdaptiveCpp) + if constexpr (EngineType::vec_size == 1) { + res = std::fmax(res, a_); + res = std::fmin(res, b_); + } + else{ + for (int i = 0; i < EngineType::vec_size; i++) { + res[i] = std::fmax(res[i], a_); + res[i] = std::fmin(res[i], b_); + } + } +#endif } } From e918b48d84d82a656d0838c43a7778479765d123 Mon Sep 17 00:00:00 2001 From: "Fedorov, Andrey" Date: Thu, 26 Sep 2024 03:00:55 -0700 Subject: [PATCH 06/14] added unifrom and bernoulli test --- .../unit_tests/rng/device/moments/moments.cpp | 168 +++++++++++++++++- 1 file changed, 162 insertions(+), 6 deletions(-) diff --git a/tests/unit_tests/rng/device/moments/moments.cpp b/tests/unit_tests/rng/device/moments/moments.cpp index 21810817a..6093579e4 100644 --- a/tests/unit_tests/rng/device/moments/moments.cpp +++ b/tests/unit_tests/rng/device/moments/moments.cpp @@ -109,6 +109,48 @@ TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, UnsignedIntegerPrecision) { EXPECT_TRUEORSKIP((test3(GetParam()))); } +TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, Integer64Precision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int64_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int64_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int64_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, UnsignedInteger64Precision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint64_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint64_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint64_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + TEST_P(Philox4x32x10UniformAccDeviceMomentsTests, RealSinglePrecision) { rng_device_test, oneapi::mkl::rng::device::uniform< @@ -189,6 +231,48 @@ TEST_P(Philox4x32x10UniformAccDeviceMomentsTests, UnsignedIntegerPrecision) { EXPECT_TRUEORSKIP((test3(GetParam()))); } +TEST_P(Philox4x32x10UniformAccDeviceMomentsTests, Integer64Precision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int64_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int64_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int64_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10UniformAccDeviceMomentsTests, UnsignedInteger64Precision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint64_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint64_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint64_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + INSTANTIATE_TEST_SUITE_P(Philox4x32x10UniformStdDeviceMomentsTestsSuite, Philox4x32x10UniformStdDeviceMomentsTests, ::testing::ValuesIn(devices), ::DeviceNamePrint()); @@ -1194,17 +1278,17 @@ class Philox4x32x10BernoulliIcdfDeviceMomentsTests TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, IntegerPrecision) { rng_device_test, oneapi::mkl::rng::device::bernoulli< - int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + std::int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); rng_device_test, oneapi::mkl::rng::device::bernoulli< - int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + std::int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); rng_device_test, oneapi::mkl::rng::device::bernoulli< - int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + std::int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); } @@ -1212,17 +1296,89 @@ TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, IntegerPrecision) { TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, UnsignedIntegerPrecision) { rng_device_test, oneapi::mkl::rng::device::bernoulli< - uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + std::uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::bernoulli< + std::uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::bernoulli< + std::uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, Integer8Precision) { + rng_device_test, + oneapi::mkl::rng::device::bernoulli< + std::int8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::bernoulli< + std::int8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::bernoulli< + std::int8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, UnsignedInteger8Precision) { + rng_device_test, + oneapi::mkl::rng::device::bernoulli< + std::uint8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::bernoulli< + std::uint8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::bernoulli< + std::uint8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, Integer16Precision) { + rng_device_test, + oneapi::mkl::rng::device::bernoulli< + std::int16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::bernoulli< + std::int16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::bernoulli< + std::int16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, UnsignedInteger16Precision) { + rng_device_test, + oneapi::mkl::rng::device::bernoulli< + std::uint16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); rng_device_test, oneapi::mkl::rng::device::bernoulli< - uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + std::uint16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); rng_device_test, oneapi::mkl::rng::device::bernoulli< - uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + std::uint16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); } From 6cfaee0c787023c5e9b704002e1330c95ab93cb0 Mon Sep 17 00:00:00 2001 From: "Fedorov, Andrey" Date: Thu, 26 Sep 2024 03:33:25 -0700 Subject: [PATCH 07/14] make changes for exponential using hipsycl --- .../rng/device/detail/exponential_impl.hpp | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp b/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp index 2550dba45..713784ecb 100644 --- a/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp @@ -79,7 +79,19 @@ class distribution_base> } res = a_ - res * beta_; if constexpr (std::is_same::value) { +#ifndef __HIPSYCL__ res = sycl::fmax(res, a_); +#else + // a workaround for hipSYCL (AdaptiveCpp) + if constexpr (EngineType::vec_size == 1) { + res = std::fmax(res, a_); + } + else{ + for (int i = 0; i < EngineType::vec_size; i++) { + res[i] = std::fmax(res[i], a_); + } + } +#endif } return res; } @@ -90,7 +102,19 @@ class distribution_base> res = ln_wrapper(res); res = a_ - res * beta_; if constexpr (std::is_same::value) { +#ifndef __HIPSYCL__ res = sycl::fmax(res, a_); +#else + // a workaround for hipSYCL (AdaptiveCpp) + if constexpr (EngineType::vec_size == 1) { + res = std::fmax(res, a_); + } + else{ + for (int i = 0; i < EngineType::vec_size; i++) { + res[i] = std::fmax(res[i], a_); + } + } +#endif } return res; } From b59c371a8a5652b7811c2e22e71d3e89d786fedd Mon Sep 17 00:00:00 2001 From: "Fedorov, Andrey" Date: Thu, 26 Sep 2024 08:18:26 -0700 Subject: [PATCH 08/14] added specialization for integer 64 types checks --- .../device/include/rng_device_test_common.hpp | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/unit_tests/rng/device/include/rng_device_test_common.hpp b/tests/unit_tests/rng/device/include/rng_device_test_common.hpp index 69846eb4d..aa2a54e09 100644 --- a/tests/unit_tests/rng/device/include/rng_device_test_common.hpp +++ b/tests/unit_tests/rng/device/include/rng_device_test_common.hpp @@ -226,6 +226,44 @@ struct statistics_device +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::mkl::rng::device::uniform& distr) { + double tM, tD, tQ; + double a = distr.a(); + double b = distr.b(); + + // Theoretical moments + tM = (a + b - 1.0) / 2.0; + tD = ((b - a) * (b - a) - 1.0) / 12.0; + tQ = (((b - a) * (b - a)) * ((1.0 / 80.0) * (b - a) * (b - a) - (1.0 / 24.0))) + + (7.0 / 240.0); + + return compare_moments(r, tM, tD, tQ); + } +}; + +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::mkl::rng::device::uniform& distr) { + double tM, tD, tQ; + double a = distr.a(); + double b = distr.b(); + + // Theoretical moments + tM = (a + b - 1.0) / 2.0; + tD = ((b - a) * (b - a) - 1.0) / 12.0; + tQ = (((b - a) * (b - a)) * ((1.0 / 80.0) * (b - a) * (b - a) - (1.0 / 24.0))) + + (7.0 / 240.0); + + return compare_moments(r, tM, tD, tQ); + } +}; + template struct statistics_device> { template From b89743669bd5a869c64a09d970bd468fb5862d68 Mon Sep 17 00:00:00 2001 From: "Fedorov, Andrey" Date: Mon, 30 Sep 2024 04:32:39 -0700 Subject: [PATCH 09/14] revert changes for host api --- include/oneapi/mkl/rng/distributions.hpp | 43 ++++++++++++++++++------ 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/include/oneapi/mkl/rng/distributions.hpp b/include/oneapi/mkl/rng/distributions.hpp index 526fd4ab4..88d1e46e7 100644 --- a/include/oneapi/mkl/rng/distributions.hpp +++ b/include/oneapi/mkl/rng/distributions.hpp @@ -61,24 +61,17 @@ template class uniform { public: static_assert(std::is_same::value || - std::is_same::value, + (std::is_same::value && + !std::is_same::value), "rng uniform distribution method is incorrect"); - static_assert(std::is_same::value || std::is_same::value || - std::is_same::value || - std::is_same::value, + static_assert(std::is_same::value || std::is_same::value, "rng uniform distribution type is not supported"); using method_type = Method; using result_type = Type; - uniform() - : uniform(static_cast(0.0f), - std::is_integral::value - ? (std::is_same::value - ? (1 << 23) - : (std::numeric_limits::max)()) - : static_cast(1.0f)) {} + uniform() : uniform(static_cast(0.0f), static_cast(1.0f)) {} explicit uniform(Type a, Type b) : a_(a), b_(b) { if (a >= b) { @@ -100,6 +93,34 @@ class uniform { Type b_; }; +template +class uniform { +public: + using method_type = Method; + using result_type = std::int32_t; + + uniform() : uniform(0, std::numeric_limits::max()) {} + + explicit uniform(std::int32_t a, std::int32_t b) : a_(a), b_(b) { + if (a >= b) { + throw oneapi::mkl::invalid_argument("rng", "uniform", + "parameters are incorrect, a >= b"); + } + } + + std::int32_t a() const { + return a_; + } + + std::int32_t b() const { + return b_; + } + +private: + std::int32_t a_; + std::int32_t b_; +}; + // Class template oneapi::mkl::rng::gaussian // // Represents continuous normal random number distribution From 21dd603c291b600673019b65ad737778535ebee9 Mon Sep 17 00:00:00 2001 From: "Fedorov, Andrey" Date: Mon, 30 Sep 2024 05:31:14 -0700 Subject: [PATCH 10/14] added check for doubles --- tests/unit_tests/rng/device/moments/moments.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unit_tests/rng/device/moments/moments.cpp b/tests/unit_tests/rng/device/moments/moments.cpp index 6093579e4..5fc14b7fd 100644 --- a/tests/unit_tests/rng/device/moments/moments.cpp +++ b/tests/unit_tests/rng/device/moments/moments.cpp @@ -110,6 +110,8 @@ TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, UnsignedIntegerPrecision) { } TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, Integer64Precision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_device_test< moments_test, oneapi::mkl::rng::device::uniform< @@ -131,6 +133,8 @@ TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, Integer64Precision) { } TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, UnsignedInteger64Precision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_device_test< moments_test, oneapi::mkl::rng::device::uniform< From 2bb30f5ab92b75813ac920dd4d08d9eaec16e59f Mon Sep 17 00:00:00 2001 From: "Fedorov, Andrey" Date: Mon, 30 Sep 2024 07:56:35 -0700 Subject: [PATCH 11/14] applied feedback --- .../device/include/rng_device_test_common.hpp | 16 ++-- .../unit_tests/rng/device/moments/moments.cpp | 90 +++++++++++-------- 2 files changed, 59 insertions(+), 47 deletions(-) diff --git a/tests/unit_tests/rng/device/include/rng_device_test_common.hpp b/tests/unit_tests/rng/device/include/rng_device_test_common.hpp index aa2a54e09..33533255e 100644 --- a/tests/unit_tests/rng/device/include/rng_device_test_common.hpp +++ b/tests/unit_tests/rng/device/include/rng_device_test_common.hpp @@ -191,8 +191,8 @@ struct statistics_device> { template struct statistics_device> { template - bool check(const std::vector& r, - const oneapi::mkl::rng::device::uniform& distr) { + bool check(const std::vector& r, + const oneapi::mkl::rng::device::uniform& distr) { double tM, tD, tQ; double a = distr.a(); double b = distr.b(); @@ -210,8 +210,8 @@ struct statistics_device template struct statistics_device> { template - bool check(const std::vector& r, - const oneapi::mkl::rng::device::uniform& distr) { + bool check(const std::vector& r, + const oneapi::mkl::rng::device::uniform& distr) { double tM, tD, tQ; double a = distr.a(); double b = distr.b(); @@ -229,8 +229,8 @@ struct statistics_device struct statistics_device> { template - bool check(const std::vector& r, - const oneapi::mkl::rng::device::uniform& distr) { + bool check(const std::vector& r, + const oneapi::mkl::rng::device::uniform& distr) { double tM, tD, tQ; double a = distr.a(); double b = distr.b(); @@ -248,8 +248,8 @@ struct statistics_device template struct statistics_device> { template - bool check(const std::vector& r, - const oneapi::mkl::rng::device::uniform& distr) { + bool check(const std::vector& r, + const oneapi::mkl::rng::device::uniform& distr) { double tM, tD, tQ; double a = distr.a(); double b = distr.b(); diff --git a/tests/unit_tests/rng/device/moments/moments.cpp b/tests/unit_tests/rng/device/moments/moments.cpp index 5fc14b7fd..8e5e55239 100644 --- a/tests/unit_tests/rng/device/moments/moments.cpp +++ b/tests/unit_tests/rng/device/moments/moments.cpp @@ -1055,9 +1055,7 @@ class Philox4x32x10BetaCjaDeviceMomentsTests class Philox4x32x10BetaCjaAccDeviceMomentsTests : public ::testing::TestWithParam {}; -// implementation uses double precision for accuracy -TEST_P(Philox4x32x10BetaCjaDeviceMomentsTests, RealDoublePrecision) { - CHECK_DOUBLE_ON_DEVICE(GetParam()); +TEST_P(Philox4x32x10BetaCjaDeviceMomentsTests, RealSinglePrecision) { rng_device_test, oneapi::mkl::rng::device::beta< @@ -1074,26 +1072,29 @@ TEST_P(Philox4x32x10BetaCjaDeviceMomentsTests, RealDoublePrecision) { float, oneapi::mkl::rng::device::beta_method::cja>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10BetaCjaDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_device_test, oneapi::mkl::rng::device::beta< double, oneapi::mkl::rng::device::beta_method::cja>>> - test4; - EXPECT_TRUEORSKIP((test4(GetParam()))); + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); rng_device_test, oneapi::mkl::rng::device::beta< double, oneapi::mkl::rng::device::beta_method::cja>>> - test5; - EXPECT_TRUEORSKIP((test5(GetParam()))); + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); rng_device_test, oneapi::mkl::rng::device::beta< double, oneapi::mkl::rng::device::beta_method::cja>>> - test6; - EXPECT_TRUEORSKIP((test6(GetParam()))); + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); } -// implementation uses double precision for accuracy -TEST_P(Philox4x32x10BetaCjaAccDeviceMomentsTests, RealDoublePrecision) { - CHECK_DOUBLE_ON_DEVICE(GetParam()); +TEST_P(Philox4x32x10BetaCjaAccDeviceMomentsTests, RealSinglePrecision) { rng_device_test< moments_test, @@ -1113,24 +1114,29 @@ TEST_P(Philox4x32x10BetaCjaAccDeviceMomentsTests, RealDoublePrecision) { float, oneapi::mkl::rng::device::beta_method::cja_accurate>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10BetaCjaAccDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_device_test< moments_test, oneapi::mkl::rng::device::beta< double, oneapi::mkl::rng::device::beta_method::cja_accurate>>> - test4; - EXPECT_TRUEORSKIP((test4(GetParam()))); + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); rng_device_test< moments_test, oneapi::mkl::rng::device::beta< double, oneapi::mkl::rng::device::beta_method::cja_accurate>>> - test5; - EXPECT_TRUEORSKIP((test5(GetParam()))); + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); rng_device_test< moments_test, oneapi::mkl::rng::device::beta< double, oneapi::mkl::rng::device::beta_method::cja_accurate>>> - test6; - EXPECT_TRUEORSKIP((test6(GetParam()))); + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); } INSTANTIATE_TEST_SUITE_P(Philox4x32x10BetaCjaDeviceMomentsTestsSuite, @@ -1147,9 +1153,7 @@ class Philox4x32x10GammaMarsagliaDeviceMomentsTests class Philox4x32x10GammaMarsagliaAccDeviceMomentsTests : public ::testing::TestWithParam {}; -// implementation uses double precision for accuracy -TEST_P(Philox4x32x10GammaMarsagliaDeviceMomentsTests, RealDoublePrecision) { - CHECK_DOUBLE_ON_DEVICE(GetParam()); +TEST_P(Philox4x32x10GammaMarsagliaDeviceMomentsTests, RealSinglePrecision) { rng_device_test, oneapi::mkl::rng::device::gamma< @@ -1166,26 +1170,29 @@ TEST_P(Philox4x32x10GammaMarsagliaDeviceMomentsTests, RealDoublePrecision) { float, oneapi::mkl::rng::device::gamma_method::marsaglia>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10GammaMarsagliaDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_device_test, oneapi::mkl::rng::device::gamma< double, oneapi::mkl::rng::device::gamma_method::marsaglia>>> - test4; - EXPECT_TRUEORSKIP((test4(GetParam()))); + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); rng_device_test, oneapi::mkl::rng::device::gamma< double, oneapi::mkl::rng::device::gamma_method::marsaglia>>> - test5; - EXPECT_TRUEORSKIP((test5(GetParam()))); + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); rng_device_test, oneapi::mkl::rng::device::gamma< double, oneapi::mkl::rng::device::gamma_method::marsaglia>>> - test6; - EXPECT_TRUEORSKIP((test6(GetParam()))); + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); } -// implementation uses double precision for accuracy -TEST_P(Philox4x32x10GammaMarsagliaAccDeviceMomentsTests, RealDoublePrecision) { - CHECK_DOUBLE_ON_DEVICE(GetParam()); +TEST_P(Philox4x32x10GammaMarsagliaAccDeviceMomentsTests, RealSinglePrecision) { rng_device_test< moments_test, @@ -1205,24 +1212,29 @@ TEST_P(Philox4x32x10GammaMarsagliaAccDeviceMomentsTests, RealDoublePrecision) { float, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10GammaMarsagliaAccDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_device_test< moments_test, oneapi::mkl::rng::device::gamma< double, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> - test4; - EXPECT_TRUEORSKIP((test4(GetParam()))); + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); rng_device_test< moments_test, oneapi::mkl::rng::device::gamma< double, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> - test5; - EXPECT_TRUEORSKIP((test5(GetParam()))); + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); rng_device_test< moments_test, oneapi::mkl::rng::device::gamma< double, oneapi::mkl::rng::device::gamma_method::marsaglia_accurate>>> - test6; - EXPECT_TRUEORSKIP((test6(GetParam()))); + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); } INSTANTIATE_TEST_SUITE_P(Philox4x32x10GammaMarsagliaDeviceMomentsTestsSuite, @@ -1257,17 +1269,17 @@ TEST_P(Philox4x32x10PoissonDevroyeDeviceMomentsTests, IntegerPrecision) { TEST_P(Philox4x32x10PoissonDevroyeDeviceMomentsTests, UnsignedIntegerPrecision) { rng_device_test, oneapi::mkl::rng::device::poisson< - uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> + std::uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); rng_device_test, oneapi::mkl::rng::device::poisson< - uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> + std::uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); rng_device_test, oneapi::mkl::rng::device::poisson< - uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> + std::uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); } From e4855c497815956c1f7bdffd8699a1d586deda6f Mon Sep 17 00:00:00 2001 From: "Fedorov, Andrey" Date: Mon, 30 Sep 2024 08:15:20 -0700 Subject: [PATCH 12/14] added more double skip --- tests/unit_tests/rng/device/moments/moments.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unit_tests/rng/device/moments/moments.cpp b/tests/unit_tests/rng/device/moments/moments.cpp index 8e5e55239..a525dd758 100644 --- a/tests/unit_tests/rng/device/moments/moments.cpp +++ b/tests/unit_tests/rng/device/moments/moments.cpp @@ -236,6 +236,8 @@ TEST_P(Philox4x32x10UniformAccDeviceMomentsTests, UnsignedIntegerPrecision) { } TEST_P(Philox4x32x10UniformAccDeviceMomentsTests, Integer64Precision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_device_test< moments_test, oneapi::mkl::rng::device::uniform< @@ -257,6 +259,8 @@ TEST_P(Philox4x32x10UniformAccDeviceMomentsTests, Integer64Precision) { } TEST_P(Philox4x32x10UniformAccDeviceMomentsTests, UnsignedInteger64Precision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_device_test< moments_test, oneapi::mkl::rng::device::uniform< From 0247c0d4de1ab3cfc855e5279f4ff7c1436fcd34 Mon Sep 17 00:00:00 2001 From: "Fedorov, Andrey" Date: Mon, 30 Sep 2024 08:51:24 -0700 Subject: [PATCH 13/14] clang format --- .../mkl/rng/device/detail/beta_impl.hpp | 4 +- .../rng/device/detail/exponential_impl.hpp | 4 +- .../mkl/rng/device/detail/gamma_impl.hpp | 2 +- .../mkl/rng/device/detail/uniform_impl.hpp | 58 ++-- .../oneapi/mkl/rng/device/distributions.hpp | 8 +- .../unit_tests/rng/device/moments/moments.cpp | 279 +++++++++--------- 6 files changed, 185 insertions(+), 170 deletions(-) mode change 100755 => 100644 include/oneapi/mkl/rng/device/detail/beta_impl.hpp mode change 100755 => 100644 include/oneapi/mkl/rng/device/detail/gamma_impl.hpp diff --git a/include/oneapi/mkl/rng/device/detail/beta_impl.hpp b/include/oneapi/mkl/rng/device/detail/beta_impl.hpp old mode 100755 new mode 100644 index 1336279b9..e412ee157 --- a/include/oneapi/mkl/rng/device/detail/beta_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/beta_impl.hpp @@ -393,7 +393,7 @@ class distribution_base> { res = acc_rej_kernel(res, engine); } if constexpr (std::is_same_v) { - for(std::int32_t i = 0; i < EngineType::vec_size; i++) { + for (std::int32_t i = 0; i < EngineType::vec_size; i++) { if (res[i] < a_) res[i] = a_; if (res[i] > a_ + b_) @@ -416,7 +416,7 @@ class distribution_base> { res = acc_rej_kernel<1>(z, engine); } if constexpr (std::is_same_v) { - for(std::int32_t i = 0; i < EngineType::vec_size; i++) { + for (std::int32_t i = 0; i < EngineType::vec_size; i++) { if (res[i] < a_) res[i] = a_; if (res[i] > a_ + b_) diff --git a/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp b/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp index 713784ecb..9419fc154 100644 --- a/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp @@ -86,7 +86,7 @@ class distribution_base> if constexpr (EngineType::vec_size == 1) { res = std::fmax(res, a_); } - else{ + else { for (int i = 0; i < EngineType::vec_size; i++) { res[i] = std::fmax(res[i], a_); } @@ -109,7 +109,7 @@ class distribution_base> if constexpr (EngineType::vec_size == 1) { res = std::fmax(res, a_); } - else{ + else { for (int i = 0; i < EngineType::vec_size; i++) { res[i] = std::fmax(res[i], a_); } diff --git a/include/oneapi/mkl/rng/device/detail/gamma_impl.hpp b/include/oneapi/mkl/rng/device/detail/gamma_impl.hpp old mode 100755 new mode 100644 index fa46b2a88..11397a69d --- a/include/oneapi/mkl/rng/device/detail/gamma_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/gamma_impl.hpp @@ -237,7 +237,7 @@ class distribution_base> { } auto res = a_ + beta_ * z; if constexpr (std::is_same_v) { - for(std::int32_t i = 0; i < EngineType::vec_size; i++) { + for (std::int32_t i = 0; i < EngineType::vec_size; i++) { if (res[i] < a_) res[i] = a_; } diff --git a/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp b/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp index 632941662..2427a6866 100644 --- a/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp @@ -37,20 +37,21 @@ static inline std::uint64_t umul_hi_64(const std::uint64_t a, const std::uint64_ const std::uint64_t ab_md = a_hi * b_lo; const std::uint64_t ba_md = b_hi * a_lo; - const std::uint64_t bias = ((ab_md & 0xFFFFFFFFULL) + (ba_md & 0xFFFFFFFFULL) + (ab_lo >> 32)) >> 32; + const std::uint64_t bias = + ((ab_md & 0xFFFFFFFFULL) + (ba_md & 0xFFFFFFFFULL) + (ab_lo >> 32)) >> 32; return ab_hi + (ab_md >> 32) + (ba_md >> 32) + bias; } template -static inline void generate_leftover(std::uint64_t range, Generator generate, - std::uint64_t& res_64, std::uint64_t& leftover) { +static inline void generate_leftover(std::uint64_t range, Generator generate, std::uint64_t& res_64, + std::uint64_t& leftover) { if constexpr (std::is_same_v>) { std::uint32_t res_1 = generate(); std::uint32_t res_2 = generate(); std::uint32_t res_3 = generate(); res_64 = (static_cast(res_3) << 62) + - (static_cast(res_2) << 31) + res_1; + (static_cast(res_2) << 31) + res_1; } else { std::uint32_t res_1 = generate(); @@ -121,7 +122,8 @@ class distribution_base> { float>::type; OutType res; if constexpr (std::is_integral::value) { - if constexpr (std::is_same_v || std::is_same_v) { + if constexpr (std::is_same_v || + std::is_same_v) { return generate_single_int(engine); } else { @@ -141,15 +143,15 @@ class distribution_base> { std::uint32_t res_1, res_2; std::uint64_t res_64, leftover; - generate_leftover(range, [&engine](){return engine.generate();}, - res_64, leftover); + generate_leftover( + range, [&engine]() { return engine.generate(); }, res_64, leftover); if (range == uint_max64) return res_64; while (leftover < threshold) { - generate_leftover(range, [&engine](){return engine.generate();}, - res_64, leftover); + generate_leftover( + range, [&engine]() { return engine.generate(); }, res_64, leftover); } res = a_ + umul_hi_64(res_64, range); @@ -168,23 +170,25 @@ class distribution_base> { for (int i = 0; i < EngineType::vec_size; i++) { res_64[i] = (static_cast(res_3[i]) << 62) + - (static_cast(res_2[i]) << 31) + res_1[i]; + (static_cast(res_2[i]) << 31) + res_1[i]; } } else { if constexpr (EngineType::vec_size == 3) { res_64[0] = (static_cast(res_1[1]) << 32) + - static_cast(res_1[0]); + static_cast(res_1[0]); res_64[1] = (static_cast(res_2[0]) << 32) + - static_cast(res_1[2]); + static_cast(res_1[2]); res_64[2] = (static_cast(res_2[2]) << 32) + - static_cast(res_2[1]); - } else { + static_cast(res_2[1]); + } + else { for (int i = 0; i < EngineType::vec_size / 2; i++) { res_64[i] = (static_cast(res_1[2 * i + 1]) << 32) + - static_cast(res_1[2 * i]); - res_64[i + EngineType::vec_size / 2] = (static_cast(res_2[2 * i + 1]) << 32) + - static_cast(res_2[2 * i]); + static_cast(res_1[2 * i]); + res_64[i + EngineType::vec_size / 2] = + (static_cast(res_2[2 * i + 1]) << 32) + + static_cast(res_2[2 * i]); } } } @@ -196,8 +200,9 @@ class distribution_base> { leftover = res_64[i] * range; while (leftover < threshold) { - generate_leftover(range, [&engine](){return engine.generate_single();}, - res_64[i], leftover); + generate_leftover( + range, [&engine]() { return engine.generate_single(); }, res_64[i], + leftover); } res[i] = a_ + umul_hi_64(res_64[i], range); @@ -219,7 +224,7 @@ class distribution_base> { res = std::fmax(res, a_); res = std::fmin(res, b_); } - else{ + else { for (int i = 0; i < EngineType::vec_size; i++) { res[i] = std::fmax(res[i], a_); res[i] = std::fmin(res[i], b_); @@ -239,7 +244,8 @@ class distribution_base> { float>::type; Type res; if constexpr (std::is_integral::value) { - if constexpr (std::is_same_v || std::is_same_v) { + if constexpr (std::is_same_v || + std::is_same_v) { FpType res_fp = engine.generate_single(static_cast(a_), static_cast(b_)); res_fp = sycl::floor(res_fp); @@ -267,15 +273,15 @@ class distribution_base> { std::uint32_t res_1, res_2; std::uint64_t res_64, leftover; - generate_leftover(range, [&engine](){return engine.generate_single();}, - res_64, leftover); + generate_leftover( + range, [&engine]() { return engine.generate_single(); }, res_64, leftover); if (range == uint_max64) return res_64; while (leftover < threshold) { - generate_leftover(range, [&engine](){return engine.generate_single();}, - res_64, leftover); + generate_leftover( + range, [&engine]() { return engine.generate_single(); }, res_64, leftover); } res = a_ + umul_hi_64(res_64, range); @@ -295,7 +301,7 @@ class distribution_base> { res = std::fmax(res, a_); res = std::fmin(res, b_); } - else{ + else { for (int i = 0; i < EngineType::vec_size; i++) { res[i] = std::fmax(res[i], a_); res[i] = std::fmin(res[i], b_); diff --git a/include/oneapi/mkl/rng/device/distributions.hpp b/include/oneapi/mkl/rng/device/distributions.hpp index 5051a670a..121e81aa3 100644 --- a/include/oneapi/mkl/rng/device/distributions.hpp +++ b/include/oneapi/mkl/rng/device/distributions.hpp @@ -76,10 +76,10 @@ class uniform : detail::distribution_base> { Type(0.0), std::is_integral::value ? ((std::is_same_v || std::is_same_v) - ? (std::numeric_limits::max)() - : (std::is_same::value - ? (1 << 23) - : (std::numeric_limits::max)())) + ? (std::numeric_limits::max)() + : (std::is_same::value + ? (1 << 23) + : (std::numeric_limits::max)())) : Type(1.0)) {} explicit uniform(Type a, Type b) : detail::distribution_base>(a, b) {} diff --git a/tests/unit_tests/rng/device/moments/moments.cpp b/tests/unit_tests/rng/device/moments/moments.cpp index a525dd758..3ae45d657 100644 --- a/tests/unit_tests/rng/device/moments/moments.cpp +++ b/tests/unit_tests/rng/device/moments/moments.cpp @@ -1053,27 +1053,24 @@ INSTANTIATE_TEST_SUITE_P(Philox4x32x10ExponentialIcdfAccDeviceMomentsTestsSuite, Philox4x32x10ExponentialIcdfAccDeviceMomentsTests, ::testing::ValuesIn(devices), ::DeviceNamePrint()); -class Philox4x32x10BetaCjaDeviceMomentsTests - : public ::testing::TestWithParam {}; +class Philox4x32x10BetaCjaDeviceMomentsTests : public ::testing::TestWithParam {}; -class Philox4x32x10BetaCjaAccDeviceMomentsTests - : public ::testing::TestWithParam {}; +class Philox4x32x10BetaCjaAccDeviceMomentsTests : public ::testing::TestWithParam {}; TEST_P(Philox4x32x10BetaCjaDeviceMomentsTests, RealSinglePrecision) { - - rng_device_test, - oneapi::mkl::rng::device::beta< - float, oneapi::mkl::rng::device::beta_method::cja>>> + rng_device_test, + oneapi::mkl::rng::device::beta>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::beta< - float, oneapi::mkl::rng::device::beta_method::cja>>> + rng_device_test, + oneapi::mkl::rng::device::beta>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::beta< - float, oneapi::mkl::rng::device::beta_method::cja>>> + rng_device_test, + oneapi::mkl::rng::device::beta>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); } @@ -1081,41 +1078,37 @@ TEST_P(Philox4x32x10BetaCjaDeviceMomentsTests, RealSinglePrecision) { TEST_P(Philox4x32x10BetaCjaDeviceMomentsTests, RealDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(GetParam()); - rng_device_test, - oneapi::mkl::rng::device::beta< - double, oneapi::mkl::rng::device::beta_method::cja>>> + rng_device_test, + oneapi::mkl::rng::device::beta>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::beta< - double, oneapi::mkl::rng::device::beta_method::cja>>> + rng_device_test, + oneapi::mkl::rng::device::beta>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::beta< - double, oneapi::mkl::rng::device::beta_method::cja>>> + rng_device_test, + oneapi::mkl::rng::device::beta>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); } TEST_P(Philox4x32x10BetaCjaAccDeviceMomentsTests, RealSinglePrecision) { - - rng_device_test< - moments_test, - oneapi::mkl::rng::device::beta< - float, oneapi::mkl::rng::device::beta_method::cja_accurate>>> + rng_device_test, + oneapi::mkl::rng::device::beta>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); - rng_device_test< - moments_test, - oneapi::mkl::rng::device::beta< - float, oneapi::mkl::rng::device::beta_method::cja_accurate>>> + rng_device_test, + oneapi::mkl::rng::device::beta>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); - rng_device_test< - moments_test, - oneapi::mkl::rng::device::beta< - float, oneapi::mkl::rng::device::beta_method::cja_accurate>>> + rng_device_test, + oneapi::mkl::rng::device::beta>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); } @@ -1123,33 +1116,30 @@ TEST_P(Philox4x32x10BetaCjaAccDeviceMomentsTests, RealSinglePrecision) { TEST_P(Philox4x32x10BetaCjaAccDeviceMomentsTests, RealDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(GetParam()); - rng_device_test< - moments_test, - oneapi::mkl::rng::device::beta< - double, oneapi::mkl::rng::device::beta_method::cja_accurate>>> + rng_device_test, + oneapi::mkl::rng::device::beta< + double, oneapi::mkl::rng::device::beta_method::cja_accurate>>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); - rng_device_test< - moments_test, - oneapi::mkl::rng::device::beta< - double, oneapi::mkl::rng::device::beta_method::cja_accurate>>> + rng_device_test, + oneapi::mkl::rng::device::beta< + double, oneapi::mkl::rng::device::beta_method::cja_accurate>>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); - rng_device_test< - moments_test, - oneapi::mkl::rng::device::beta< - double, oneapi::mkl::rng::device::beta_method::cja_accurate>>> + rng_device_test, + oneapi::mkl::rng::device::beta< + double, oneapi::mkl::rng::device::beta_method::cja_accurate>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); } INSTANTIATE_TEST_SUITE_P(Philox4x32x10BetaCjaDeviceMomentsTestsSuite, - Philox4x32x10BetaCjaDeviceMomentsTests, - ::testing::ValuesIn(devices), ::DeviceNamePrint()); + Philox4x32x10BetaCjaDeviceMomentsTests, ::testing::ValuesIn(devices), + ::DeviceNamePrint()); INSTANTIATE_TEST_SUITE_P(Philox4x32x10BetaCjaAccDeviceMomentsTestsSuite, - Philox4x32x10BetaCjaAccDeviceMomentsTests, - ::testing::ValuesIn(devices), ::DeviceNamePrint()); + Philox4x32x10BetaCjaAccDeviceMomentsTests, ::testing::ValuesIn(devices), + ::DeviceNamePrint()); class Philox4x32x10GammaMarsagliaDeviceMomentsTests : public ::testing::TestWithParam {}; @@ -1158,20 +1148,19 @@ class Philox4x32x10GammaMarsagliaAccDeviceMomentsTests : public ::testing::TestWithParam {}; TEST_P(Philox4x32x10GammaMarsagliaDeviceMomentsTests, RealSinglePrecision) { - - rng_device_test, - oneapi::mkl::rng::device::gamma< - float, oneapi::mkl::rng::device::gamma_method::marsaglia>>> + rng_device_test, + oneapi::mkl::rng::device::gamma>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::gamma< - float, oneapi::mkl::rng::device::gamma_method::marsaglia>>> + rng_device_test, + oneapi::mkl::rng::device::gamma>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::gamma< - float, oneapi::mkl::rng::device::gamma_method::marsaglia>>> + rng_device_test, + oneapi::mkl::rng::device::gamma>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); } @@ -1179,25 +1168,24 @@ TEST_P(Philox4x32x10GammaMarsagliaDeviceMomentsTests, RealSinglePrecision) { TEST_P(Philox4x32x10GammaMarsagliaDeviceMomentsTests, RealDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(GetParam()); - rng_device_test, - oneapi::mkl::rng::device::gamma< - double, oneapi::mkl::rng::device::gamma_method::marsaglia>>> + rng_device_test, + oneapi::mkl::rng::device::gamma>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::gamma< - double, oneapi::mkl::rng::device::gamma_method::marsaglia>>> + rng_device_test, + oneapi::mkl::rng::device::gamma>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::gamma< - double, oneapi::mkl::rng::device::gamma_method::marsaglia>>> + rng_device_test, + oneapi::mkl::rng::device::gamma>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); } TEST_P(Philox4x32x10GammaMarsagliaAccDeviceMomentsTests, RealSinglePrecision) { - rng_device_test< moments_test, oneapi::mkl::rng::device::gamma< @@ -1271,19 +1259,22 @@ TEST_P(Philox4x32x10PoissonDevroyeDeviceMomentsTests, IntegerPrecision) { } TEST_P(Philox4x32x10PoissonDevroyeDeviceMomentsTests, UnsignedIntegerPrecision) { - rng_device_test, - oneapi::mkl::rng::device::poisson< - std::uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::poisson< + std::uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::poisson< - std::uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::poisson< + std::uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::poisson< - std::uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::poisson< + std::uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); } @@ -1296,109 +1287,127 @@ class Philox4x32x10BernoulliIcdfDeviceMomentsTests : public ::testing::TestWithParam {}; TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, IntegerPrecision) { - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - std::int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - std::int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - std::int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); } TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, UnsignedIntegerPrecision) { - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - std::uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - std::uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - std::uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); } TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, Integer8Precision) { - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - std::int8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::int8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - std::int8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::int8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - std::int8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::int8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); } TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, UnsignedInteger8Precision) { - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - std::uint8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::uint8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - std::uint8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::uint8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - std::uint8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::uint8_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); } TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, Integer16Precision) { - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - std::int16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::int16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - std::int16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::int16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - std::int16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::int16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); } TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, UnsignedInteger16Precision) { - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - std::uint16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::uint16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test1; EXPECT_TRUEORSKIP((test1(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - std::uint16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::uint16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test2; EXPECT_TRUEORSKIP((test2(GetParam()))); - rng_device_test, - oneapi::mkl::rng::device::bernoulli< - std::uint16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + rng_device_test< + moments_test, + oneapi::mkl::rng::device::bernoulli< + std::uint16_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> test3; EXPECT_TRUEORSKIP((test3(GetParam()))); } From 7fbf5969ea6f496bd355670f0085f8bb81bc5ce6 Mon Sep 17 00:00:00 2001 From: "Fedorov, Andrey" Date: Tue, 1 Oct 2024 03:31:05 -0700 Subject: [PATCH 14/14] return sycl for fmax back --- .../mkl/rng/device/detail/uniform_impl.hpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp b/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp index 2427a6866..ec50eb8fc 100644 --- a/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp +++ b/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp @@ -221,13 +221,13 @@ class distribution_base> { #else // a workaround for hipSYCL (AdaptiveCpp) if constexpr (EngineType::vec_size == 1) { - res = std::fmax(res, a_); - res = std::fmin(res, b_); + res = sycl::fmax(res, a_); + res = sycl::fmin(res, b_); } else { for (int i = 0; i < EngineType::vec_size; i++) { - res[i] = std::fmax(res[i], a_); - res[i] = std::fmin(res[i], b_); + res[i] = sycl::fmax(res[i], a_); + res[i] = sycl::fmin(res[i], b_); } } #endif @@ -298,13 +298,13 @@ class distribution_base> { #else // a workaround for hipSYCL (AdaptiveCpp) if constexpr (EngineType::vec_size == 1) { - res = std::fmax(res, a_); - res = std::fmin(res, b_); + res = sycl::fmax(res, a_); + res = sycl::fmin(res, b_); } else { for (int i = 0; i < EngineType::vec_size; i++) { - res[i] = std::fmax(res[i], a_); - res[i] = std::fmin(res[i], b_); + res[i] = sycl::fmax(res[i], a_); + res[i] = sycl::fmin(res[i], b_); } } #endif