Skip to content

Commit

Permalink
[RNG] Updated Device API (#574)
Browse files Browse the repository at this point in the history
  • Loading branch information
andreyfe1 authored Oct 1, 2024
1 parent afb9d5c commit b2324f1
Show file tree
Hide file tree
Showing 12 changed files with 1,692 additions and 73 deletions.
468 changes: 468 additions & 0 deletions include/oneapi/mkl/rng/device/detail/beta_impl.hpp

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions include/oneapi/mkl/rng/device/detail/distribution_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ class bits;
template <typename RealType = float, typename Method = exponential_method::by_default>
class exponential;

template <typename RealType = float, typename Method = beta_method::by_default>
class beta;

template <typename RealType = float, typename Method = gamma_method::by_default>
class gamma;

template <typename IntType = std::int32_t, typename Method = poisson_method::by_default>
class poisson;

Expand All @@ -69,5 +75,7 @@ class bernoulli;
#include "oneapi/mkl/rng/device/detail/exponential_impl.hpp"
#include "oneapi/mkl/rng/device/detail/poisson_impl.hpp"
#include "oneapi/mkl/rng/device/detail/bernoulli_impl.hpp"
#include "oneapi/mkl/rng/device/detail/beta_impl.hpp"
#include "oneapi/mkl/rng/device/detail/gamma_impl.hpp"

#endif // _MKL_RNG_DISTRIBUTION_BASE_HPP_
38 changes: 33 additions & 5 deletions include/oneapi/mkl/rng/device/detail/exponential_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,7 @@ class distribution_base<oneapi::mkl::rng::device::exponential<RealType, Method>>
auto generate(EngineType& engine) ->
typename std::conditional<EngineType::vec_size == 1, RealType,
sycl::vec<RealType, EngineType::vec_size>>::type {
using OutType = typename std::conditional<EngineType::vec_size == 1, RealType,
sycl::vec<RealType, EngineType::vec_size>>::type;

OutType res = engine.generate(RealType(0), RealType(1));
auto res = engine.generate(RealType(0), RealType(1));
if constexpr (EngineType::vec_size == 1) {
res = ln_wrapper(res);
}
Expand All @@ -82,7 +79,19 @@ class distribution_base<oneapi::mkl::rng::device::exponential<RealType, Method>>
}
res = a_ - res * beta_;
if constexpr (std::is_same<Method, exponential_method::icdf_accurate>::value) {
res = sycl::fmax(res, OutType{ a_ });
#ifndef __HIPSYCL__
res = sycl::fmax(res, a_);
#else
// a workaround for hipSYCL (AdaptiveCpp)
if constexpr (EngineType::vec_size == 1) {
res = std::fmax(res, a_);
}
else {
for (int i = 0; i < EngineType::vec_size; i++) {
res[i] = std::fmax(res[i], a_);
}
}
#endif
}
return res;
}
Expand All @@ -93,7 +102,19 @@ class distribution_base<oneapi::mkl::rng::device::exponential<RealType, Method>>
res = ln_wrapper(res);
res = a_ - res * beta_;
if constexpr (std::is_same<Method, exponential_method::icdf_accurate>::value) {
#ifndef __HIPSYCL__
res = sycl::fmax(res, a_);
#else
// a workaround for hipSYCL (AdaptiveCpp)
if constexpr (EngineType::vec_size == 1) {
res = std::fmax(res, a_);
}
else {
for (int i = 0; i < EngineType::vec_size; i++) {
res[i] = std::fmax(res[i], a_);
}
}
#endif
}
return res;
}
Expand All @@ -105,6 +126,13 @@ class distribution_base<oneapi::mkl::rng::device::exponential<RealType, Method>>
oneapi::mkl::rng::device::poisson<std::int32_t, poisson_method::devroye>>;
friend class distribution_base<
oneapi::mkl::rng::device::poisson<std::uint32_t, poisson_method::devroye>>;
friend class distribution_base<oneapi::mkl::rng::device::gamma<float, gamma_method::marsaglia>>;
friend class distribution_base<
oneapi::mkl::rng::device::gamma<double, gamma_method::marsaglia>>;
friend class distribution_base<
oneapi::mkl::rng::device::gamma<float, gamma_method::marsaglia_accurate>>;
friend class distribution_base<
oneapi::mkl::rng::device::gamma<double, gamma_method::marsaglia_accurate>>;
};

} // namespace oneapi::mkl::rng::device::detail
Expand Down
287 changes: 287 additions & 0 deletions include/oneapi/mkl/rng/device/detail/gamma_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
/*******************************************************************************
* Copyright 2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions
* and limitations under the License.
*
*
* SPDX-License-Identifier: Apache-2.0
*******************************************************************************/

#ifndef _MKL_RNG_DEVICE_GAMMA_IMPL_HPP_
#define _MKL_RNG_DEVICE_GAMMA_IMPL_HPP_

#include "vm_wrappers.hpp"

namespace oneapi::mkl::rng::device::detail {

// Sampling strategy used by the gamma distribution. The value is chosen from
// the shape parameter alpha in distribution_base<gamma>::set_algorithm():
//   alpha == 1        -> Exponential
//   0.6 < alpha < 1   -> Vaduva
//   alpha <= 0.6      -> EPD_Transform
//   alpha > 1         -> Marsaglia
enum class gamma_algorithm {
    Exponential = 0,
    Vaduva = 1,
    EPD_Transform = 2,
    Marsaglia = 3
};

// Returns the constant 1/3 rounded to the requested precision: the double
// literal when DataType is double, otherwise the float literal converted to
// DataType. Used by the Marsaglia branch of the gamma generator as
// d = alpha - 1/3.
template <typename DataType>
inline DataType gamma_c1() {
    constexpr bool wants_double = std::is_same<DataType, double>::value;
    if constexpr (wants_double) {
        constexpr double one_third = 0x1.5555555555555p-2;
        return one_third;
    }
    else {
        constexpr float one_third = 0x1.555556p-2f;
        return one_third;
    }
}

// Returns the constant 0.0331 rounded to the requested precision: the double
// literal when DataType is double, otherwise the float literal converted to
// DataType. Used by the Marsaglia branch of the gamma generator in the quick
// acceptance test u < 1 - 0.0331 * x^4.
template <typename DataType>
inline DataType gamma_c2() {
    constexpr bool wants_double = std::is_same<DataType, double>::value;
    if constexpr (wants_double) {
        constexpr double squeeze_coefficient = 0x1.0f27bb2fec56dp-5;
        return squeeze_coefficient;
    }
    else {
        constexpr float squeeze_coefficient = 0x1.0f27bcp-5f;
        return squeeze_coefficient;
    }
}

// Returns the constant 0.6 rounded to the requested precision: the double
// literal when DataType is double, otherwise the float literal converted to
// DataType. It is the alpha threshold that separates the Vaduva algorithm
// (alpha > 0.6) from the EPD transform (alpha <= 0.6) for alpha < 1.
template <typename DataType>
inline DataType gamma_c06() {
    constexpr bool wants_double = std::is_same<DataType, double>::value;
    if constexpr (wants_double) {
        constexpr double alpha_threshold = 0x1.3333333333333p-1;
        return alpha_threshold;
    }
    else {
        constexpr float alpha_threshold = 0x1.333334p-1f;
        return alpha_threshold;
    }
}

// Device-side implementation of the gamma distribution with shape `alpha`,
// displacement `a` and scale `beta`. Every variate is produced as
// a + beta * z, where z is drawn by one of the gamma_algorithm strategies
// selected from alpha (see set_algorithm()).
template <typename RealType, typename Method>
class distribution_base<oneapi::mkl::rng::device::gamma<RealType, Method>> {
public:
    // Plain bundle of the three distribution parameters.
    struct param_type {
        param_type(RealType alpha, RealType a, RealType beta) : alpha_(alpha), a_(a), beta_(beta) {}
        RealType alpha_;
        RealType a_;
        RealType beta_;
    };

    // Stores the parameters and selects the sampling algorithm.
    // In host compilation throws oneapi::mkl::invalid_argument when
    // alpha <= 0 or beta <= 0; the check is compiled out for device code.
    distribution_base(RealType alpha, RealType a, RealType beta)
            : alpha_(alpha),
              a_(a),
              beta_(beta),
              count_(0) {
        set_algorithm();
#ifndef __SYCL_DEVICE_ONLY__
        if (alpha <= RealType(0.0)) {
            throw oneapi::mkl::invalid_argument("rng", "gamma", "alpha <= 0");
        }
        else if (beta <= RealType(0.0)) {
            throw oneapi::mkl::invalid_argument("rng", "gamma", "beta <= 0");
        }
#endif
    }

    // Shape parameter.
    RealType alpha() const {
        return alpha_;
    }

    // Displacement parameter.
    RealType a() const {
        return a_;
    }

    // Scale parameter.
    RealType beta() const {
        return beta_;
    }

    // Number of rejected candidates during the most recent call of the
    // acceptance-rejection kernel (the counter is reset on every call).
    std::size_t count_rejected_numbers() const {
        return count_;
    }

    // Returns the current parameters as a param_type.
    param_type param() const {
        return param_type(alpha_, a_, beta_);
    }

    // Replaces the parameters and re-selects the sampling algorithm.
    // In host compilation throws oneapi::mkl::invalid_argument when
    // pt.alpha_ <= 0 or pt.beta_ <= 0, leaving the object unchanged.
    void param(const param_type& pt) {
#ifndef __SYCL_DEVICE_ONLY__
        if (pt.alpha_ <= RealType(0.0)) {
            throw oneapi::mkl::invalid_argument("rng", "gamma", "alpha <= 0");
        }
        else if (pt.beta_ <= RealType(0.0)) {
            throw oneapi::mkl::invalid_argument("rng", "gamma", "beta <= 0");
        }
#endif
        alpha_ = pt.alpha_;
        a_ = pt.a_;
        beta_ = pt.beta_;
        set_algorithm();
    }

protected:
    // Chooses the sampling algorithm from alpha_:
    //   alpha == 1 -> Exponential, 0.6 < alpha < 1 -> Vaduva,
    //   alpha <= 0.6 -> EPD_Transform, alpha > 1 -> Marsaglia.
    void set_algorithm() {
        if (alpha_ <= RealType(1.0)) {
            if (alpha_ == RealType(1.0)) {
                algorithm_ = gamma_algorithm::Exponential;
            }
            else if (alpha_ > gamma_c06<RealType>()) {
                algorithm_ = gamma_algorithm::Vaduva;
            }
            else {
                algorithm_ = gamma_algorithm::EPD_Transform;
            }
        }
        else {
            algorithm_ = gamma_algorithm::Marsaglia;
        }
        // flag_ is only toggled by the Marsaglia path. If it were left set
        // while another algorithm is active, acc_rej_kernel() would skip
        // drawing z1/z2 and operate on values that were never generated.
        if (algorithm_ != gamma_algorithm::Marsaglia) {
            flag_ = false;
        }
    }

    // Builds a pair of values tmp * sin and tmp * cos, where
    // tmp = sqrt(-2 * ln(vec[0])) and sin/cos come from
    // sincospi_wrapper(2 * vec[2], cos). Only lanes 0 and 2 of `vec` are read.
    // NOTE(review): presumably a Box-Muller transform yielding two Gaussian
    // variates, with sincospi_wrapper returning sin(pi*x) and storing
    // cos(pi*x) in its second argument - confirm against vm_wrappers.hpp.
    template <typename T, int vecSize>
    inline std::pair<T, T> gauss_BM2_for_Marsaglia(const sycl::vec<T, vecSize>& vec) {
        T tmp, sin, cos, gauss_1, gauss_2;
        tmp = ln_wrapper(vec[0]);
        tmp = sqrt_wrapper(T(-2.0) * tmp);
        sin = sincospi_wrapper(T(2) * vec[2], cos);
        gauss_1 = (tmp * sin);
        gauss_2 = (tmp * cos);
        return { gauss_1, gauss_2 };
    }

    // Acceptance-rejection kernel: fills the first `n` lanes of `z` with
    // accepted candidates and returns a_ + beta_ * z.
    //
    // n       number of variates to produce; `z` must have at least n lanes.
    // z       output buffer for the unscaled variates (modified in place).
    // engine  source of uniform numbers via generate_single(0, 1).
    //
    // count_ is reset to zero and then incremented once per rejected
    // candidate. For Method == gamma_method::marsaglia_accurate every
    // returned lane is clamped from below to a_.
    template <std::int32_t n, typename T, typename EngineType>
    T acc_rej_kernel(T& z, EngineType& engine) {
        // Per-algorithm constants, computed once per call.
        RealType flC = RealType(0.0);
        RealType flD = RealType(0.0);
        if (algorithm_ == gamma_algorithm::Vaduva) {
            flC = RealType(1.0) / alpha_;
            flD = (RealType(1.0) - alpha_) *
                  exp_wrapper(ln_wrapper(alpha_) * alpha_ / (RealType(1.0) - alpha_));
        }
        else if (algorithm_ == gamma_algorithm::EPD_Transform) {
            flC = RealType(1.0) / alpha_;
            flD = (RealType(1.0) - alpha_);
        }
        else if (algorithm_ == gamma_algorithm::Marsaglia) {
            flD = alpha_ - gamma_c1<RealType>();
            flC = sqrt_wrapper(RealType(1.0) / (RealType(9.0) * alpha_ - RealType(3.0)));
        }

        count_ = 0;
        RealType z1 = RealType(0.0);
        RealType z2 = RealType(0.0);
        RealType z3 = RealType(0.0);
        RealType z4 = RealType(0.0);
        for (int i = 0; i < n; i++) {
            while (1) { // looping until satisfied
                // When flag_ is set (Marsaglia only) the values saved by the
                // previous candidate are consumed instead of fresh uniforms.
                if (!flag_) {
                    z1 = engine.generate_single(RealType(0), RealType(1));
                    z2 = engine.generate_single(RealType(0), RealType(1));
                }

                if (algorithm_ == gamma_algorithm::Vaduva) {
                    z1 = -ln_wrapper(z1);
                    z2 = -ln_wrapper(z2);
                    z[i] = powr_wrapper(z1, flC);
                    if (z1 + z2 >= z[i] + flD) {
                        break;
                    }
                }
                if (algorithm_ == gamma_algorithm::EPD_Transform) {
                    z2 = -ln_wrapper(z2);
                    if (z1 <= flD) {
                        z[i] = powr_wrapper(z1, flC);
                        if (z[i] <= z2) {
                            break;
                        }
                    }
                    else {
                        z1 = -ln_wrapper((RealType(1.0) - z1) * flC);
                        z[i] = powr_wrapper(flD + alpha_ * z1, flC);
                        if (z[i] <= z2 + z1) {
                            break;
                        }
                    }
                }
                if (algorithm_ == gamma_algorithm::Marsaglia) {
                    RealType local_uniform_2, local_gauss;
                    if (!flag_) {
                        // Four uniforms yield two (gauss, uniform) pairs: the
                        // first is used now, the second is stored for the
                        // next candidate.
                        z3 = engine.generate_single(RealType(0), RealType(1));
                        z4 = engine.generate_single(RealType(0), RealType(1));
                        auto gauss =
                            gauss_BM2_for_Marsaglia(sycl::vec<RealType, 4>{ z1, z2, z3, z4 });
                        local_uniform_2 = z2;
                        local_gauss = gauss.first;

                        saved_uniform_2_ = z4;
                        saved_gauss_ = gauss.second;
                    }
                    else {
                        local_uniform_2 = saved_uniform_2_;
                        local_gauss = saved_gauss_;
                    }
                    flag_ = !flag_;
                    z[i] = RealType(1.0) + flC * local_gauss;
                    if (z[i] > RealType(0.0)) {
                        z[i] = z[i] * z[i] * z[i];
                        // From here on local_gauss holds the squared value.
                        local_gauss = local_gauss * local_gauss;
                        // Cheap test first: u < 1 - 0.0331 * x^4.
                        if (local_uniform_2 <
                            RealType(1.0) - gamma_c2<RealType>() * local_gauss * local_gauss) {
                            z[i] = flD * z[i];
                            break;
                        }
                        else {
                            // Logarithmic test:
                            // ln(u) < x^2 / 2 + d * (1 - v + ln(v)).
                            RealType local_uniform_1 = ln_wrapper(z[i]);
                            local_uniform_2 = ln_wrapper(local_uniform_2);
                            if (local_uniform_2 <
                                RealType(0.5) * local_gauss +
                                    flD * (RealType(1.0) - z[i] + local_uniform_1)) {
                                z[i] = flD * z[i];
                                break;
                            }
                        }
                    }
                }
                ++count_;
            }
        }
        auto res = a_ + beta_ * z;
        if constexpr (std::is_same_v<Method, gamma_method::marsaglia_accurate>) {
            // Iterate over the n produced lanes, not EngineType::vec_size:
            // generate_single() runs this kernel with n == 1 on engines whose
            // vec_size may be larger, which would index past the result.
            for (std::int32_t i = 0; i < n; i++) {
                if (res[i] < a_)
                    res[i] = a_;
            }
        }
        return res;
    }

    // Generates EngineType::vec_size gamma variates. Returns RealType when
    // vec_size == 1, otherwise sycl::vec<RealType, vec_size>. For alpha == 1
    // the work is delegated to the exponential distribution with (a_, beta_).
    template <typename EngineType>
    auto generate(EngineType& engine) ->
        typename std::conditional<EngineType::vec_size == 1, RealType,
                                  sycl::vec<RealType, EngineType::vec_size>>::type {
        if (algorithm_ == gamma_algorithm::Exponential) {
            distribution_base<oneapi::mkl::rng::device::exponential<RealType>> distr_exp(a_, beta_);
            return distr_exp.generate(engine);
        }
        sycl::vec<RealType, EngineType::vec_size> res{};
        res = acc_rej_kernel<EngineType::vec_size>(res, engine);

        return res;
    }

    // Generates a single gamma variate regardless of the engine's vec_size.
    // For alpha == 1 the work is delegated to the exponential distribution
    // with (a_, beta_).
    template <typename EngineType>
    RealType generate_single(EngineType& engine) {
        if (algorithm_ == gamma_algorithm::Exponential) {
            distribution_base<oneapi::mkl::rng::device::exponential<RealType>> distr_exp(a_, beta_);
            RealType z = distr_exp.generate_single(engine);
            return z;
        }
        sycl::vec<RealType, 1> res{};
        res = acc_rej_kernel<1>(res, engine);

        return res[0];
    }

    RealType alpha_; // shape
    RealType a_; // displacement
    RealType beta_; // scale
    RealType saved_gauss_ = RealType(0.0); // second Gaussian kept for the next Marsaglia candidate
    RealType saved_uniform_2_ = RealType(0.0); // uniform paired with saved_gauss_
    bool flag_ = false; // true when saved_gauss_/saved_uniform_2_ hold unused values
    std::size_t count_; // rejections during the last acc_rej_kernel() call
    gamma_algorithm algorithm_; // strategy selected by set_algorithm()
};

} // namespace oneapi::mkl::rng::device::detail

#endif // _MKL_RNG_DEVICE_GAMMA_IMPL_HPP_
4 changes: 2 additions & 2 deletions include/oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class mcg31m1;

namespace detail {

template <std::int32_t VecSize>
template <std::uint64_t VecSize>
constexpr sycl::vec<std::uint64_t, VecSize> select_vector_a_mcg31m1() {
if constexpr (VecSize == 1)
return sycl::vec<std::uint64_t, 1>(UINT64_C(1));
Expand Down Expand Up @@ -56,7 +56,7 @@ constexpr sycl::vec<std::uint64_t, VecSize> select_vector_a_mcg31m1() {
// hipSYCL (AdaptiveCpp) doesn't support constexpr sycl::vec constructor
// that's why in case of hipSYCL backend sycl::vec is created as a local variable
#ifndef __HIPSYCL__
template <std::int32_t VecSize>
template <std::uint64_t VecSize>
struct mcg31m1_vector_a {
static constexpr sycl::vec<std::uint64_t, VecSize> vector_a =
select_vector_a_mcg31m1<VecSize>(); // powers of a
Expand Down
Loading

0 comments on commit b2324f1

Please sign in to comment.