From bee0243676e4dccc7ca06be320ed1ef47d585bed Mon Sep 17 00:00:00 2001 From: "Kruglov, Oleg" Date: Tue, 31 Oct 2023 14:41:17 -0700 Subject: [PATCH] Add user_cpu_context and ability to provide host_policy via it --- cpp/oneapi/dal/compute.hpp | 6 ++++ cpp/oneapi/dal/detail/policy.hpp | 1 + cpp/oneapi/dal/detail/user_policy.cpp | 47 +++++++++++++++++++++++++++ cpp/oneapi/dal/detail/user_policy.hpp | 32 ++++++++++++++++++ cpp/oneapi/dal/train.hpp | 6 ++++ 5 files changed, 92 insertions(+) create mode 100644 cpp/oneapi/dal/detail/user_policy.cpp create mode 100644 cpp/oneapi/dal/detail/user_policy.hpp diff --git a/cpp/oneapi/dal/compute.hpp b/cpp/oneapi/dal/compute.hpp index f793fa939a0..a5e3618471d 100644 --- a/cpp/oneapi/dal/compute.hpp +++ b/cpp/oneapi/dal/compute.hpp @@ -17,6 +17,7 @@ #pragma once #include "oneapi/dal/detail/compute_ops.hpp" +#include "oneapi/dal/detail/user_policy.hpp" #include "oneapi/dal/detail/spmd_policy.hpp" #include "oneapi/dal/spmd/communicator.hpp" @@ -28,6 +29,11 @@ auto compute(Args&&... args) { return dal::detail::compute_dispatch(std::forward(args)...); } +template +auto compute(detail::user_cpu_context uctx, Args&&... args) { + return dal::detail::compute_dispatch(uctx.get_host_policy(), std::forward(args)...); +} + #ifdef ONEDAL_DATA_PARALLEL template auto compute(sycl::queue& queue, Args&&... args) { diff --git a/cpp/oneapi/dal/detail/policy.hpp b/cpp/oneapi/dal/detail/policy.hpp index aacce445263..e9fa0fa4e62 100644 --- a/cpp/oneapi/dal/detail/policy.hpp +++ b/cpp/oneapi/dal/detail/policy.hpp @@ -103,6 +103,7 @@ class ONEDAL_EXPORT host_policy : public base { } host_policy(const host_policy&) = default; host_policy(host_policy&&) = default; + host_policy& operator=(const host_policy&) = default; static host_policy get_default() { return host_policy(make_default_impl()); diff --git a/cpp/oneapi/dal/detail/user_policy.cpp b/cpp/oneapi/dal/detail/user_policy.cpp new file mode 100644 index 00000000000..f8d3ee819dc --- /dev/null +++ b/cpp/oneapi/dal/detail/user_policy.cpp @@ -0,0 +1,47 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ +#include "oneapi/dal/detail/user_policy.hpp" + +namespace oneapi::dal::detail { +class user_cpu_context_impl { +public: + user_cpu_context_impl() : policy_(host_policy::get_default()) {} + user_cpu_context_impl(const host_policy& policy) : policy_(policy) {} + void set_host_policy(const host_policy& policy) { + policy_ = policy; + } + host_policy get_host_policy() { + return policy_; + } + +private: + detail::host_policy policy_; +}; + +user_cpu_context::user_cpu_context() : impl_(new user_cpu_context_impl()) {} + +user_cpu_context::user_cpu_context(const host_policy& policy) + : impl_(new user_cpu_context_impl(policy)) {} + +void user_cpu_context::set_host_policy(const host_policy& policy) { + impl_->set_host_policy(policy); +} + +host_policy user_cpu_context::get_host_policy() { + return impl_->get_host_policy(); +} + +} // namespace oneapi::dal::detail \ No newline at end of file diff --git a/cpp/oneapi/dal/detail/user_policy.hpp b/cpp/oneapi/dal/detail/user_policy.hpp new file mode 100644 index 00000000000..e8cb09e31eb --- /dev/null +++ b/cpp/oneapi/dal/detail/user_policy.hpp @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ +#include "oneapi/dal/detail/policy.hpp" + +namespace oneapi::dal::detail { +class user_cpu_context_impl; + +class user_cpu_context { +public: + user_cpu_context(); + user_cpu_context(const host_policy& policy); + host_policy get_host_policy(); + void set_host_policy(const host_policy& policy); + +private: + pimpl impl_; +}; + +} //namespace oneapi::dal::detail \ No newline at end of file diff --git a/cpp/oneapi/dal/train.hpp b/cpp/oneapi/dal/train.hpp index 8d33c25af84..2e4731387fc 100644 --- a/cpp/oneapi/dal/train.hpp +++ b/cpp/oneapi/dal/train.hpp @@ -17,6 +17,7 @@ #pragma once #include "oneapi/dal/detail/train_ops.hpp" +#include "oneapi/dal/detail/user_policy.hpp" #include "oneapi/dal/detail/spmd_policy.hpp" #include "oneapi/dal/spmd/communicator.hpp" @@ -28,6 +29,11 @@ auto train(Args&&... args) { return dal::detail::train_dispatch(std::forward(args)...); } +template +auto train(detail::user_cpu_context uctx, Args&&... args) { + return dal::detail::train_dispatch(uctx.get_host_policy(), std::forward(args)...); +} + #ifdef ONEDAL_DATA_PARALLEL template auto train(sycl::queue& queue, Args&&... args) {