Skip to content

Commit

Permalink
Add user_cpu_context and ability to provide host_policy via it
Browse files Browse the repository at this point in the history
  • Loading branch information
olegkkruglov committed Nov 2, 2023
1 parent 4eb47a5 commit bee0243
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 0 deletions.
6 changes: 6 additions & 0 deletions cpp/oneapi/dal/compute.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include "oneapi/dal/detail/compute_ops.hpp"
#include "oneapi/dal/detail/user_policy.hpp"
#include "oneapi/dal/detail/spmd_policy.hpp"
#include "oneapi/dal/spmd/communicator.hpp"

Expand All @@ -28,6 +29,11 @@ auto compute(Args&&... args) {
return dal::detail::compute_dispatch(std::forward<Args>(args)...);
}

template <typename... Args>
auto compute(detail::user_cpu_context uctx, Args&&... args) {
return dal::detail::compute_dispatch(uctx.get_host_policy(), std::forward<Args>(args)...);
}

#ifdef ONEDAL_DATA_PARALLEL
template <typename... Args>
auto compute(sycl::queue& queue, Args&&... args) {
Expand Down
1 change: 1 addition & 0 deletions cpp/oneapi/dal/detail/policy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class ONEDAL_EXPORT host_policy : public base {
}
host_policy(const host_policy&) = default;
host_policy(host_policy&&) = default;
host_policy& operator=(const host_policy&) = default;

static host_policy get_default() {
return host_policy(make_default_impl());
Expand Down
47 changes: 47 additions & 0 deletions cpp/oneapi/dal/detail/user_policy.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*******************************************************************************
* Copyright 2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "oneapi/dal/detail/user_policy.hpp"

namespace oneapi::dal::detail {
class user_cpu_context_impl {
public:
user_cpu_context_impl() : policy_(host_policy::get_default()) {}
user_cpu_context_impl(const host_policy& policy) : policy_(policy) {}
void set_host_policy(const host_policy& policy) {
policy_ = policy;
}
host_policy get_host_policy() {
return policy_;
}

private:
detail::host_policy policy_;
};

user_cpu_context::user_cpu_context() : impl_(new user_cpu_context_impl()) {}

user_cpu_context::user_cpu_context(const host_policy& policy)
: impl_(new user_cpu_context_impl(policy)) {}

void user_cpu_context::set_host_policy(const host_policy& policy) {
impl_->set_host_policy(policy);
}

host_policy user_cpu_context::get_host_policy() {
return impl_->get_host_policy();
}

} // namespace oneapi::dal::detail
32 changes: 32 additions & 0 deletions cpp/oneapi/dal/detail/user_policy.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*******************************************************************************
* Copyright 2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "oneapi/dal/detail/policy.hpp"

namespace oneapi::dal::detail {
class user_cpu_context_impl;

class user_cpu_context {
public:
user_cpu_context();
user_cpu_context(const host_policy& policy);
host_policy get_host_policy();
void set_host_policy(const host_policy& policy);

private:
pimpl<user_cpu_context_impl> impl_;
};

} //namespace oneapi::dal::detail
6 changes: 6 additions & 0 deletions cpp/oneapi/dal/train.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include "oneapi/dal/detail/train_ops.hpp"
#include "oneapi/dal/detail/user_policy.hpp"
#include "oneapi/dal/detail/spmd_policy.hpp"
#include "oneapi/dal/spmd/communicator.hpp"

Expand All @@ -28,6 +29,11 @@ auto train(Args&&... args) {
return dal::detail::train_dispatch(std::forward<Args>(args)...);
}

template <typename... Args>
auto train(detail::user_cpu_context uctx, Args&&... args) {
return dal::detail::train_dispatch(uctx.get_host_policy(), std::forward<Args>(args)...);
}

#ifdef ONEDAL_DATA_PARALLEL
template <typename... Args>
auto train(sycl::queue& queue, Args&&... args) {
Expand Down

0 comments on commit bee0243

Please sign in to comment.