Skip to content

Commit

Permalink
changed tests and make changes for sources
Browse files Browse the repository at this point in the history
  • Loading branch information
andreyfe1 committed Sep 25, 2024
1 parent 731bfb1 commit 71dddf3
Show file tree
Hide file tree
Showing 6 changed files with 270 additions and 54 deletions.
20 changes: 12 additions & 8 deletions include/oneapi/mkl/rng/device/detail/beta_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,10 +393,12 @@ class distribution_base<oneapi::mkl::rng::device::beta<RealType, Method>> {
res = acc_rej_kernel<EngineType::vec_size>(res, engine);
}
if constexpr (std::is_same_v<Method, beta_method::cja_accurate>) {
if (res < a_)
res = a_;
if (res > a_ + b_)
res = a_ + b_;
for(std::int32_t i = 0; i < EngineType::vec_size; i++) {
if (res[i] < a_)
res[i] = a_;
if (res[i] > a_ + b_)
res[i] = a_ + b_;
}
}
return res;
}
Expand All @@ -414,10 +416,12 @@ class distribution_base<oneapi::mkl::rng::device::beta<RealType, Method>> {
res = acc_rej_kernel<1>(z, engine);
}
if constexpr (std::is_same_v<Method, beta_method::cja_accurate>) {
if (res < a_)
res = a_;
if (res > a_ + b_)
res = a_ + b_;
for(std::int32_t i = 0; i < EngineType::vec_size; i++) {
if (res[i] < a_)
res[i] = a_;
if (res[i] > a_ + b_)
res[i] = a_ + b_;
}
}
return res;
}
Expand Down
6 changes: 4 additions & 2 deletions include/oneapi/mkl/rng/device/detail/gamma_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,10 @@ class distribution_base<oneapi::mkl::rng::device::gamma<RealType, Method>> {
}
auto res = a_ + beta_ * z;
if constexpr (std::is_same_v<Method, gamma_method::marsaglia_accurate>) {
if (res < a_)
res = a_;
for(std::int32_t i = 0; i < EngineType::vec_size; i++) {
if (res[i] < a_)
res[i] = a_;
}
}
return res;
}
Expand Down
25 changes: 13 additions & 12 deletions include/oneapi/mkl/rng/device/detail/uniform_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#define _MKL_RNG_DEVICE_UNIFORM_IMPL_HPP_

#include <limits>
#include "engine_base.hpp"

namespace oneapi::mkl::rng::device::detail {

Expand All @@ -41,13 +42,13 @@ static inline std::uint64_t umul_hi_64(const std::uint64_t a, const std::uint64_
}

template <typename EngineType, typename Generator>
static inline void generate_leftover(std::uint64_t range, Generator generate,
static inline void generate_leftover(std::uint64_t range, Generator generate,
std::uint64_t& res_64, std::uint64_t& leftover) {
if constexpr (std::is_same_v<EngineType, mcg31m1<EngineType::vec_size>>) {
std::uint32_t res_1 = generate();
std::uint32_t res_2 = generate();
std::uint32_t res_3 = generate();
res_64 = (static_cast<std::uint64_t>(res_3) << 62) +
res_64 = (static_cast<std::uint64_t>(res_3) << 62) +
(static_cast<std::uint64_t>(res_2) << 31) + res_1;
}
else {
Expand Down Expand Up @@ -125,7 +126,7 @@ class distribution_base<oneapi::mkl::rng::device::uniform<Type, Method>> {
else {
// Lemire's sample rejection method to exclude bias for uniform numbers
// https://arxiv.org/abs/1805.10941

constexpr std::uint64_t uint_max64 = std::numeric_limits<std::uint64_t>::max();
constexpr std::uint64_t uint_max32 = std::numeric_limits<std::uint32_t>::max();

Expand All @@ -139,14 +140,14 @@ class distribution_base<oneapi::mkl::rng::device::uniform<Type, Method>> {
std::uint32_t res_1, res_2;
std::uint64_t res_64, leftover;

generate_leftover<EngineType>(range, [&engine](){return engine.generate();},
generate_leftover<EngineType>(range, [&engine](){return engine.generate();},
res_64, leftover);

if (range == uint_max64)
return res_64;

while (leftover < threshold) {
generate_leftover<EngineType>(range, [&engine](){return engine.generate();},
generate_leftover<EngineType>(range, [&engine](){return engine.generate();},
res_64, leftover);
}

Expand All @@ -160,12 +161,12 @@ class distribution_base<oneapi::mkl::rng::device::uniform<Type, Method>> {
sycl::vec<std::uint32_t, EngineType::vec_size> res_1 = engine.generate();
sycl::vec<std::uint32_t, EngineType::vec_size> res_2 = engine.generate();
sycl::vec<std::uint64_t, EngineType::vec_size> res_64;

if constexpr (std::is_same_v<EngineType, mcg31m1<EngineType::vec_size>>) {
sycl::vec<std::uint32_t, EngineType::vec_size> res_3 = engine.generate();

for (int i = 0; i < EngineType::vec_size; i++) {
res_64[i] = (static_cast<std::uint64_t>(res_3[i]) << 62) +
res_64[i] = (static_cast<std::uint64_t>(res_3[i]) << 62) +
(static_cast<std::uint64_t>(res_2[i]) << 31) + res_1[i];
}
}
Expand All @@ -186,15 +187,15 @@ class distribution_base<oneapi::mkl::rng::device::uniform<Type, Method>> {
}
}
}

if (range == uint_max64)
return res_64.template convert<Type>();

for (int i = 0; i < EngineType::vec_size; i++) {
leftover = res_64[i] * range;

while (leftover < threshold) {
generate_leftover<EngineType>(range, [&engine](){return engine.generate_single();},
generate_leftover<EngineType>(range, [&engine](){return engine.generate_single();},
res_64[i], leftover);
}

Expand Down Expand Up @@ -233,7 +234,7 @@ class distribution_base<oneapi::mkl::rng::device::uniform<Type, Method>> {
else {
// Lemire's sample rejection method to exclude bias for uniform numbers
// https://arxiv.org/abs/1805.10941

constexpr std::uint64_t uint_max64 = std::numeric_limits<std::uint64_t>::max();
constexpr std::uint64_t uint_max32 = std::numeric_limits<std::uint32_t>::max();

Expand All @@ -251,14 +252,14 @@ class distribution_base<oneapi::mkl::rng::device::uniform<Type, Method>> {
std::uint32_t res_1, res_2;
std::uint64_t res_64, leftover;

generate_leftover<EngineType>(range, [&engine](){return engine.generate_single();},
generate_leftover<EngineType>(range, [&engine](){return engine.generate_single();},
res_64, leftover);

if (range == uint_max64)
return res_64;

while (leftover < threshold) {
generate_leftover<EngineType>(range, [&engine](){return engine.generate_single();},
generate_leftover<EngineType>(range, [&engine](){return engine.generate_single();},
res_64, leftover);
}

Expand Down
43 changes: 11 additions & 32 deletions include/oneapi/mkl/rng/distributions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,24 @@ template <typename Type = float, typename Method = uniform_method::by_default>
class uniform {
public:
static_assert(std::is_same<Method, uniform_method::standard>::value ||
(std::is_same<Method, uniform_method::accurate>::value &&
!std::is_same<Type, std::int32_t>::value),
std::is_same<Method, uniform_method::accurate>::value,
"rng uniform distribution method is incorrect");

static_assert(std::is_same<Type, float>::value || std::is_same<Type, double>::value,
static_assert(std::is_same<Type, float>::value || std::is_same<Type, double>::value ||
std::is_same<Type, std::int32_t>::value ||
std::is_same<Type, std::uint32_t>::value,
"rng uniform distribution type is not supported");

using method_type = Method;
using result_type = Type;

uniform() : uniform(static_cast<Type>(0.0f), static_cast<Type>(1.0f)) {}
uniform()
: uniform(static_cast<Type>(0.0f),
std::is_integral<Type>::value
? (std::is_same<Method, uniform_method::standard>::value
? (1 << 23)
: (std::numeric_limits<Type>::max)())
: static_cast<Type>(1.0f)) {}

explicit uniform(Type a, Type b) : a_(a), b_(b) {
if (a >= b) {
Expand All @@ -93,34 +100,6 @@ class uniform {
Type b_;
};

template <typename Method>
class uniform<std::int32_t, Method> {
public:
using method_type = Method;
using result_type = std::int32_t;

uniform() : uniform(0, std::numeric_limits<std::int32_t>::max()) {}

explicit uniform(std::int32_t a, std::int32_t b) : a_(a), b_(b) {
if (a >= b) {
throw oneapi::mkl::invalid_argument("rng", "uniform",
"parameters are incorrect, a >= b");
}
}

std::int32_t a() const {
return a_;
}

std::int32_t b() const {
return b_;
}

private:
std::int32_t a_;
std::int32_t b_;
};

// Class template oneapi::mkl::rng::gaussian
//
// Represents continuous normal random number distribution
Expand Down
46 changes: 46 additions & 0 deletions tests/unit_tests/rng/device/include/rng_device_test_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,52 @@ struct statistics_device<oneapi::mkl::rng::device::bernoulli<Fp, Method>> {
}
};

template <typename Fp, typename Method>
struct statistics_device<oneapi::mkl::rng::device::beta<Fp, Method>> {
template <typename AllocType>
bool check(const std::vector<Fp, AllocType>& r,
const oneapi::mkl::rng::device::beta<Fp, Method>& distr) {
double tM, tD, tQ;
double b, c, d, e, e2, b2, sum_pq;
Fp p = distr.p();
Fp q = distr.q();
Fp a = distr.a();
Fp beta = distr.b();

b2 = beta * beta;
sum_pq = p + q;
b = (p + 1.0) / (sum_pq + 1.0);
c = (p + 2.0) / (sum_pq + 2.0);
d = (p + 3.0) / (sum_pq + 3.0);
e = p / sum_pq;
e2 = e * e;

tM = a + e * beta;
tD = b2 * p * q / (sum_pq * sum_pq * (sum_pq + 1.0));
tQ = b2 * b2 * (e * b * c * d - 4.0 * e2 * b * c + 6.0 * e2 * e * b - 3.0 * e2 * e2);

return compare_moments(r, tM, tD, tQ);
}
};

template <typename Fp, typename Method>
struct statistics_device<oneapi::mkl::rng::device::gamma<Fp, Method>> {
template <typename AllocType>
bool check(const std::vector<Fp, AllocType>& r,
const oneapi::mkl::rng::device::gamma<Fp, Method>& distr) {
double tM, tD, tQ;
Fp a = distr.a();
Fp alpha = distr.alpha();
Fp beta = distr.beta();

tM = a + beta * alpha;
tD = beta * beta * alpha;
tQ = beta * beta * beta * beta * 3 * alpha * (alpha + 2);

return compare_moments(r, tM, tD, tQ);
}
};

template <typename Fp>
struct statistics_device<oneapi::mkl::rng::device::bits<Fp>> {
template <typename AllocType>
Expand Down
Loading

0 comments on commit 71dddf3

Please sign in to comment.