From 4efe375e1435eaeb004c7b8da5682f31c81a32a2 Mon Sep 17 00:00:00 2001 From: Samir Nasibli Date: Thu, 28 Mar 2024 08:45:21 +0100 Subject: [PATCH] ENH: dtype dispatcher for Dataframe Intercahnge Protocol in Python (#2696) * ENH: dtype dispatcher for Dataframe Intercahnge Protocol in Python --------- Co-authored-by: Nikita Kulikov --- cpp/oneapi/dal/backend/dispatcher.hpp | 99 +--------------- cpp/oneapi/dal/detail/dtype_dispatcher.hpp | 128 +++++++++++++++++++++ 2 files changed, 131 insertions(+), 96 deletions(-) create mode 100644 cpp/oneapi/dal/detail/dtype_dispatcher.hpp diff --git a/cpp/oneapi/dal/backend/dispatcher.hpp b/cpp/oneapi/dal/backend/dispatcher.hpp index d9048c0811a..5ee667958e9 100644 --- a/cpp/oneapi/dal/backend/dispatcher.hpp +++ b/cpp/oneapi/dal/backend/dispatcher.hpp @@ -20,6 +20,7 @@ #include "oneapi/dal/detail/global_context.hpp" #include "oneapi/dal/detail/policy.hpp" #include "oneapi/dal/detail/spmd_policy.hpp" +#include "oneapi/dal/detail/dtype_dispatcher.hpp" #include "oneapi/dal/backend/common.hpp" #include "oneapi/dal/backend/communicator.hpp" @@ -308,101 +309,7 @@ inline constexpr auto dispatch_by_cpu(const context_cpu& ctx, Op&& op) { return op(cpu_dispatch_default{}); } -template -inline constexpr auto dispatch_by_data_type(data_type dtype, Op&& op, OnUnknown&& on_unknown) { - switch (dtype) { - case data_type::int8: return op(std::int8_t{}); - case data_type::uint8: return op(std::uint8_t{}); - case data_type::int16: return op(std::int16_t{}); - case data_type::uint16: return op(std::uint16_t{}); - case data_type::int32: return op(std::int32_t{}); - case data_type::uint32: return op(std::uint32_t{}); - case data_type::int64: return op(std::int64_t{}); - case data_type::uint64: return op(std::uint64_t{}); - case data_type::float32: return op(float{}); - case data_type::float64: return op(double{}); - default: return on_unknown(dtype); - } -} - -template > -inline constexpr ResultType dispatch_by_data_type(data_type dtype, Op&& op) { - // Necessary to make the return type conformant with - // other dispatch branches - const auto on_unknown = [](data_type) -> ResultType { - using msg = dal::detail::error_messages; - throw unimplemented{ msg::unsupported_conversion_types() }; - }; - - return dispatch_by_data_type(dtype, std::forward(op), on_unknown); -} - -namespace impl { - -template -struct type_holder { - using result_t = Result; - - template - using add_tail = type_holder; - - template - constexpr static inline Result evaluate(Op&& op) { - return op(Types{}...); - } -}; - -template -inline constexpr auto multi_dispatch_by_data_type(Op&& op) { - return TypeHolder::evaluate(std::forward(op)); -} - -template -inline constexpr auto multi_dispatch_by_data_type(Op&& op, Head&& head, Tail&&... tail) { - using result_t = typename TypeHolder::result_t; - const auto functor = [&](auto arg) -> result_t { - using type_t = std::decay_t; - using holder_t = typename TypeHolder::template add_tail; - return multi_dispatch_by_data_type( // - std::forward(op), - std::forward(tail)...); - }; - return dispatch_by_data_type(head, functor); -} - -template -struct invoke_result_multiple_impl { - using next_t = invoke_result_multiple_impl; - using type = typename next_t::type; -}; - -template -struct invoke_result_multiple_impl<0ul, DefaultType, Op, Types...> { - using type = std::invoke_result_t; -}; - -template -using invoke_result_multiple_t = typename invoke_result_multiple_impl::type; - -} // namespace impl - -// Signature of this function is slightly different from -// a simple `dispatch_by_data_type` due to inconsistency -// with a `std::visit` which it heavily resembles -template -inline constexpr ResultType multi_dispatch_by_data_type(Op&& op, Types&&... types) { - using holder_t = impl::type_holder; - return impl::multi_dispatch_by_data_type( // - std::forward(op), - std::forward(types)...); -} - -template -inline constexpr auto multi_dispatch_by_data_type(Op&& op, Types&&... types) { - using result_t = impl::invoke_result_multiple_t; - return multi_dispatch_by_data_type( // - std::forward(op), - std::forward(types)...); -} +using detail::v1::dispatch_by_data_type; +using detail::v1::multi_dispatch_by_data_type; } // namespace oneapi::dal::backend diff --git a/cpp/oneapi/dal/detail/dtype_dispatcher.hpp b/cpp/oneapi/dal/detail/dtype_dispatcher.hpp new file mode 100644 index 00000000000..6141b68df14 --- /dev/null +++ b/cpp/oneapi/dal/detail/dtype_dispatcher.hpp @@ -0,0 +1,128 @@ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#pragma once + +#include + +#include "oneapi/dal/common.hpp" + +namespace oneapi::dal::detail { +namespace v1 { + +template +inline constexpr auto dispatch_by_data_type(data_type dtype, Op&& op, OnUnknown&& on_unknown) { + switch (dtype) { + case data_type::int8: return op(std::int8_t{}); + case data_type::uint8: return op(std::uint8_t{}); + case data_type::int16: return op(std::int16_t{}); + case data_type::uint16: return op(std::uint16_t{}); + case data_type::int32: return op(std::int32_t{}); + case data_type::uint32: return op(std::uint32_t{}); + case data_type::int64: return op(std::int64_t{}); + case data_type::uint64: return op(std::uint64_t{}); + case data_type::float32: return op(float{}); + case data_type::float64: return op(double{}); + default: return on_unknown(dtype); + } +} + +template > +inline constexpr ResultType dispatch_by_data_type(data_type dtype, Op&& op) { + // Necessary to make the return type conformant with + // other dispatch branches + const auto on_unknown = [](data_type) -> ResultType { + using msg = dal::detail::error_messages; + throw unimplemented{ msg::unsupported_conversion_types() }; + }; + + return dispatch_by_data_type(dtype, std::forward(op), on_unknown); +} + +namespace impl { + +template +struct type_holder { + using result_t = Result; + + template + using add_tail = type_holder; + + template + constexpr static inline Result evaluate(Op&& op) { + return op(Types{}...); + } +}; + +template +inline constexpr auto multi_dispatch_by_data_type(Op&& op) { + return TypeHolder::evaluate(std::forward(op)); +} + +template +inline constexpr auto multi_dispatch_by_data_type(Op&& op, Head&& head, Tail&&... tail) { + using result_t = typename TypeHolder::result_t; + const auto functor = [&](auto arg) -> result_t { + using type_t = std::decay_t; + using holder_t = typename TypeHolder::template add_tail; + return multi_dispatch_by_data_type( // + std::forward(op), + std::forward(tail)...); + }; + return dispatch_by_data_type(head, functor); +} + +template +struct invoke_result_multiple_impl { + using next_t = invoke_result_multiple_impl; + using type = typename next_t::type; +}; + +template +struct invoke_result_multiple_impl<0ul, DefaultType, Op, Types...> { + using type = std::invoke_result_t; +}; + +template +using invoke_result_multiple_t = typename invoke_result_multiple_impl::type; + +} // namespace impl + +// Signature of this function is slightly different from +// a simple `dispatch_by_data_type` due to inconsistency +// with a `std::visit` which it heavily resembles +template +inline constexpr ResultType multi_dispatch_by_data_type(Op&& op, Types&&... types) { + using holder_t = impl::type_holder; + return impl::multi_dispatch_by_data_type( // + std::forward(op), + std::forward(types)...); +} + +template +inline constexpr auto multi_dispatch_by_data_type(Op&& op, Types&&... types) { + using result_t = impl::invoke_result_multiple_t; + return multi_dispatch_by_data_type( // + std::forward(op), + std::forward(types)...); +} + +} // namespace v1 + +using v1::dispatch_by_data_type; +using v1::multi_dispatch_by_data_type; + +} // namespace oneapi::dal::detail