Skip to content

Commit

Permalink
Feature: Data type support in read function (#2731)
Browse files Browse the repository at this point in the history
  • Loading branch information
DDJHB authored Apr 22, 2024
1 parent 01060a3 commit a6e75ea
Show file tree
Hide file tree
Showing 12 changed files with 253 additions and 84 deletions.
48 changes: 27 additions & 21 deletions cpp/oneapi/dal/io/csv/backend/cpu/read_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,32 @@ namespace oneapi::dal::csv::backend {
namespace interop = dal::backend::interop;
namespace daal_dm = daal::data_management;

/// Reads a CSV file into a oneDAL table through the DAAL CSV file data source (CPU path).
///
/// @param ctx  CPU execution context; not used by this kernel.
/// @param ds   Data source supplying the file name, the delimiter and the header-parsing flag.
/// @param args Read arguments; not used by this kernel.
/// @return     Result of `convert_from_daal_homogen_table<DAAL_DATA_TYPE>` applied to the
///             numeric table loaded by the DAAL data source.
template <>
table read_kernel_cpu<table>::operator()(const dal::backend::context_cpu& ctx,
                                         const detail::data_source_base& ds,
                                         const read_args<table>& args) const {
    using options_t = daal_dm::CsvDataSourceOptions;

    // The header is parsed only when the data source asks for it.
    const auto header_flag =
        ds.get_parse_header() ? options_t::parseHeader : options_t::byDefault;
    const auto storage_flags = daal_dm::operator|(options_t::allocateNumericTable,
                                                  options_t::createDictionaryFromContext);
    options_t csv_options(daal_dm::operator|(storage_flags, header_flag));

    daal_dm::FileDataSource<daal_dm::CSVFeatureManager> csv_source(ds.get_file_name().c_str(),
                                                                   csv_options);
    // Status is checked once after construction and once after loading.
    interop::status_to_exception(csv_source.status());

    csv_source.getFeatureManager().setDelimiter(ds.get_delimiter());
    csv_source.loadDataBlock();
    interop::status_to_exception(csv_source.status());

    return interop::convert_from_daal_homogen_table<DAAL_DATA_TYPE>(csv_source.getNumericTable());
}
/// CPU CSV read kernel specialized for `table` results with element type `Float`.
template <typename Float>
struct read_kernel_cpu<table, Float> {
    /// Reads a CSV file into a oneDAL table through the DAAL CSV file data source.
    ///
    /// @param ctx  CPU execution context; not used by this kernel.
    /// @param ds   Data source supplying the file name, the delimiter and the
    ///             header-parsing flag.
    /// @param args Read arguments; not used by this kernel.
    /// @return     Result of `convert_from_daal_homogen_table<Float>` applied to the
    ///             numeric table loaded by the DAAL data source.
    table operator()(const dal::backend::context_cpu& ctx,
                     const detail::data_source_base& ds,
                     const read_args<table>& args) const {
        using options_t = daal_dm::CsvDataSourceOptions;

        // The header is parsed only when the data source asks for it.
        const auto header_flag =
            ds.get_parse_header() ? options_t::parseHeader : options_t::byDefault;
        const auto storage_flags = daal_dm::operator|(options_t::allocateNumericTable,
                                                      options_t::createDictionaryFromContext);
        options_t csv_options(daal_dm::operator|(storage_flags, header_flag));

        daal_dm::FileDataSource<daal_dm::CSVFeatureManager> csv_source(
            ds.get_file_name().c_str(),
            csv_options);
        // Status is checked once after construction and once after loading.
        interop::status_to_exception(csv_source.status());

        csv_source.getFeatureManager().setDelimiter(ds.get_delimiter());
        csv_source.loadDataBlock();
        interop::status_to_exception(csv_source.status());

        return interop::convert_from_daal_homogen_table<Float>(csv_source.getNumericTable());
    }
};

// Explicitly instantiate the CPU read kernel for float and double element types.
template struct read_kernel_cpu<table, float>;
template struct read_kernel_cpu<table, double>;

} // namespace oneapi::dal::csv::backend
2 changes: 1 addition & 1 deletion cpp/oneapi/dal/io/csv/backend/cpu/read_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

namespace oneapi::dal::csv::backend {

template <typename Object>
template <typename Object, typename Float>
struct read_kernel_cpu {
Object operator()(const dal::backend::context_cpu& ctx,
const detail::data_source_base& ds,
Expand Down
2 changes: 1 addition & 1 deletion cpp/oneapi/dal/io/csv/backend/gpu/read_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

namespace oneapi::dal::csv::backend {

template <typename Object>
template <typename Object, typename Float>
struct read_kernel_gpu {
Object operator()(const dal::backend::context_gpu& ctx,
const detail::data_source_base& ds,
Expand Down
85 changes: 45 additions & 40 deletions cpp/oneapi/dal/io/csv/backend/gpu/read_kernel_dpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,45 +36,50 @@ namespace oneapi::dal::csv::backend {
namespace interop = dal::backend::interop;
namespace daal_dm = daal::data_management;

/// Reads a CSV file on the host through the DAAL CSV file data source and copies the parsed
/// values into USM memory allocated with `sycl::usm::alloc::device`.
///
/// @param ctx  GPU execution context; supplies the SYCL queue used for allocation and copy.
/// @param ds   Data source supplying the file name, the delimiter and the header-parsing flag.
/// @param args Read arguments; not used by this kernel.
/// @return     Homogeneous table built over the allocated array with the row and column
///             counts reported by the DAAL numeric table.
template <>
table read_kernel_gpu<table>::operator()(const dal::backend::context_gpu& ctx,
                                         const detail::data_source_base& ds,
                                         const read_args<table>& args) const {
    auto& queue = ctx.get_queue();

    daal_dm::CsvDataSourceOptions csv_options(daal_dm::operator|(
        daal_dm::operator|(daal_dm::CsvDataSourceOptions::allocateNumericTable,
                           daal_dm::CsvDataSourceOptions::createDictionaryFromContext),
        (ds.get_parse_header() ? daal_dm::CsvDataSourceOptions::parseHeader
                               : daal_dm::CsvDataSourceOptions::byDefault)));

    daal_dm::FileDataSource<daal_dm::CSVFeatureManager> daal_data_source(ds.get_file_name().c_str(),
                                                                         csv_options);
    interop::status_to_exception(daal_data_source.status());

    daal_data_source.getFeatureManager().setDelimiter(ds.get_delimiter());
    daal_data_source.loadDataBlock();
    interop::status_to_exception(daal_data_source.status());

    auto nt = daal_data_source.getNumericTable();

    const std::int64_t row_count = nt->getNumberOfRows();
    const std::int64_t column_count = nt->getNumberOfColumns();
    const std::int64_t element_count = row_count * column_count;

    // Allocate the destination before acquiring the DAAL block, so that an allocation
    // failure cannot leave the block acquired.
    auto arr = array<DAAL_DATA_TYPE>::empty(queue, element_count, sycl::usm::alloc::device);

    daal_dm::BlockDescriptor<DAAL_DATA_TYPE> block;
    interop::status_to_exception(nt->getBlockOfRows(0, row_count, daal_dm::readOnly, block));

    try {
        dal::detail::memcpy_host2usm(queue,
                                     arr.get_mutable_data(),
                                     block.getBlockPtr(),
                                     sizeof(DAAL_DATA_TYPE) * element_count);
    }
    catch (...) {
        // Release the block on the error path as well. The release status is dropped on
        // purpose: the exception already in flight is the one to report.
        static_cast<void>(nt->releaseBlockOfRows(block));
        throw;
    }

    interop::status_to_exception(nt->releaseBlockOfRows(block));

    return dal::detail::homogen_table_builder{}.reset(arr, row_count, column_count).build();
}
/// GPU CSV read kernel specialized for `table` results with element type `Float`.
template <typename Float>
struct read_kernel_gpu<table, Float> {
    /// Reads a CSV file on the host through the DAAL CSV file data source and copies the
    /// parsed values into USM memory allocated with `sycl::usm::alloc::device`.
    ///
    /// @param ctx  GPU execution context; supplies the SYCL queue used for allocation
    ///             and copy.
    /// @param ds   Data source supplying the file name, the delimiter and the
    ///             header-parsing flag.
    /// @param args Read arguments; not used by this kernel.
    /// @return     Homogeneous table built over the allocated array with the row and
    ///             column counts reported by the DAAL numeric table.
    table operator()(const dal::backend::context_gpu& ctx,
                     const detail::data_source_base& ds,
                     const read_args<table>& args) const {
        auto& queue = ctx.get_queue();

        daal_dm::CsvDataSourceOptions csv_options(daal_dm::operator|(
            daal_dm::operator|(daal_dm::CsvDataSourceOptions::allocateNumericTable,
                               daal_dm::CsvDataSourceOptions::createDictionaryFromContext),
            (ds.get_parse_header() ? daal_dm::CsvDataSourceOptions::parseHeader
                                   : daal_dm::CsvDataSourceOptions::byDefault)));

        daal_dm::FileDataSource<daal_dm::CSVFeatureManager> daal_data_source(
            ds.get_file_name().c_str(),
            csv_options);
        interop::status_to_exception(daal_data_source.status());

        daal_data_source.getFeatureManager().setDelimiter(ds.get_delimiter());
        daal_data_source.loadDataBlock();
        interop::status_to_exception(daal_data_source.status());

        auto nt = daal_data_source.getNumericTable();

        const std::int64_t row_count = nt->getNumberOfRows();
        const std::int64_t column_count = nt->getNumberOfColumns();
        const std::int64_t element_count = row_count * column_count;

        // Allocate the destination before acquiring the DAAL block, so that an allocation
        // failure cannot leave the block acquired.
        auto arr = array<Float>::empty(queue, element_count, sycl::usm::alloc::device);

        daal_dm::BlockDescriptor<Float> block;
        interop::status_to_exception(nt->getBlockOfRows(0, row_count, daal_dm::readOnly, block));

        try {
            dal::detail::memcpy_host2usm(queue,
                                         arr.get_mutable_data(),
                                         block.getBlockPtr(),
                                         sizeof(Float) * element_count);
        }
        catch (...) {
            // Release the block on the error path as well. The release status is dropped
            // on purpose: the exception already in flight is the one to report.
            static_cast<void>(nt->releaseBlockOfRows(block));
            throw;
        }

        interop::status_to_exception(nt->releaseBlockOfRows(block));

        return dal::detail::homogen_table_builder{}.reset(arr, row_count, column_count).build();
    }
};

// Explicitly instantiate the GPU read kernel for float and double element types.
template struct read_kernel_gpu<table, float>;
template struct read_kernel_gpu<table, double>;

} // namespace oneapi::dal::csv::backend
19 changes: 19 additions & 0 deletions cpp/oneapi/dal/io/csv/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ namespace oneapi::dal::csv {
namespace detail {
namespace v1 {

/// Compile-time check of the `Float` template argument: forwards to
/// `dal::detail::is_one_of_v` with the candidate types `float` and `double`.
template <typename Float>
constexpr bool is_valid_float_v = dal::detail::is_one_of_v<Float, float, double>;

struct data_source_tag {};
class data_source_impl;

Expand Down Expand Up @@ -65,32 +68,48 @@ class ONEDAL_EXPORT data_source_base : public base {
using v1::data_source_tag;
using v1::data_source_impl;
using v1::data_source_base;
using v1::is_valid_float_v;

} // namespace detail

namespace v1 {

/// Used for the specification of data source configuration.
///
/// @tparam Float The floating-point type that the data source will operate with.
/// Must be `float` or `double`.
template <typename Float = float>
class data_source : public detail::data_source_base {
static_assert(detail::is_valid_float_v<Float>);

public:
using float_t = Float;

/// Constructs a data_source object from a C-style string file name.
explicit data_source(const char* file_name) : data_source_base(file_name) {}

/// Constructs a data_source object from a C++-style std::string file name.
explicit data_source(const std::string& file_name) : data_source_base(file_name.c_str()) {}

/// Sets the delimiter character for parsing the data source file.
auto& set_delimiter(char value) {
set_delimiter_impl(value);
return *this;
}

/// Specifies whether to parse the header of the data source file.
auto& set_parse_header(bool value) {
set_parse_header_impl(value);
return *this;
}

/// Sets the file name for the data source from a C-style string.
auto& set_file_name(const char* value) {
set_file_name_impl(value);
return *this;
}

/// Sets the file name for the data source from a C++-style std::string.
auto& set_file_name(const std::string& value) {
set_file_name_impl(value.c_str());
return *this;
Expand Down
15 changes: 11 additions & 4 deletions cpp/oneapi/dal/io/csv/detail/read_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,20 @@ namespace oneapi::dal::csv::detail {
namespace v1 {

using dal::detail::host_policy;
table read_ops_dispatcher<table, host_policy>::operator()(const host_policy& policy,
const data_source_base& ds,
const read_args<table>& args) const {

template <typename Float>
table read_ops_dispatcher<table, Float, host_policy>::operator()(
const host_policy& policy,
const data_source_base& ds,
const read_args<table>& args) const {
using kernel_dispatcher_t = dal::backend::kernel_dispatcher< //
KERNEL_SINGLE_NODE_CPU(backend::read_kernel_cpu<table>)>;
KERNEL_SINGLE_NODE_CPU(backend::read_kernel_cpu<table, Float>)>;
return kernel_dispatcher_t()(policy, ds, args);
}

// Explicitly instantiate and export the host-policy dispatcher for each element type.
#define INSTANTIATE(F) template struct ONEDAL_EXPORT read_ops_dispatcher<table, F, host_policy>;
INSTANTIATE(float)
INSTANTIATE(double)
// Keep the helper macro local to this instantiation list.
#undef INSTANTIATE

} // namespace v1
} // namespace oneapi::dal::csv::detail
28 changes: 15 additions & 13 deletions cpp/oneapi/dal/io/csv/detail/read_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
namespace oneapi::dal::csv::detail {
namespace v1 {

template <typename Object, typename Policy, typename... Options>
template <typename Object, typename Float, typename Policy, typename... Options>
struct read_ops_dispatcher;

template <typename Object>
struct read_ops_dispatcher<Object, dal::detail::host_policy> {
template <typename Object, typename Float>
struct read_ops_dispatcher<Object, Float, dal::detail::host_policy> {
Object operator()(const dal::detail::host_policy& policy,
const data_source_base& ds,
const dal::preview::csv::read_args<Object>& args) const {
Expand All @@ -40,17 +40,17 @@ struct read_ops_dispatcher<Object, dal::detail::host_policy> {
}
};

template <>
struct ONEDAL_EXPORT read_ops_dispatcher<table, dal::detail::host_policy> {
/// Host-policy dispatcher for reading a `table`, parameterized by the element type `Float`.
/// `operator()` is only declared here; it takes the policy, the data source and the read
/// arguments and returns the resulting `table`.
template <typename Float>
struct read_ops_dispatcher<table, Float, dal::detail::host_policy> {
table operator()(const dal::detail::host_policy& policy,
const data_source_base& ds,
const dal::csv::read_args<table>& args) const;
};

#ifdef ONEDAL_DATA_PARALLEL

template <>
struct ONEDAL_EXPORT read_ops_dispatcher<table, dal::detail::data_parallel_policy> {
template <typename Float>
struct read_ops_dispatcher<table, Float, dal::detail::data_parallel_policy> {
table operator()(const dal::detail::data_parallel_policy& ctx,
const data_source_base& ds,
const dal::csv::read_args<table>& args) const;
Expand All @@ -61,8 +61,9 @@ struct ONEDAL_EXPORT read_ops_dispatcher<table, dal::detail::data_parallel_polic
template <typename Object, typename DataSource>
struct read_ops;

template <typename Object>
struct read_ops<Object, data_source> {
template <typename Object, typename DataSource>
struct read_ops {
using float_t = typename DataSource::float_t;
using input_t = dal::preview::csv::read_args<Object>;
using result_t = Object;

Expand All @@ -75,14 +76,15 @@ struct read_ops<Object, data_source> {
template <typename Policy>
auto operator()(const Policy& ctx, const data_source_base& ds, const input_t& args) const {
check_preconditions(ds, args);
auto result = read_ops_dispatcher<Object, Policy>()(ctx, ds, args);
auto result = read_ops_dispatcher<Object, float_t, Policy>()(ctx, ds, args);
check_postconditions(ds, args, result);
return result;
}
};

template <>
struct read_ops<table, data_source> {
template <typename DataSource>
struct read_ops<table, DataSource> {
using float_t = typename DataSource::float_t;
using input_t = read_args<table>;
using result_t = table;

Expand All @@ -95,7 +97,7 @@ struct read_ops<table, data_source> {
template <typename Policy>
auto operator()(const Policy& ctx, const data_source_base& ds, const input_t& args) const {
check_preconditions(ds, args);
const auto result = read_ops_dispatcher<table, Policy>()(ctx, ds, args);
const auto result = read_ops_dispatcher<table, float_t, Policy>()(ctx, ds, args);
check_postconditions(ds, args, result);
return result;
}
Expand Down
14 changes: 10 additions & 4 deletions cpp/oneapi/dal/io/csv/detail/read_ops_dpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,21 @@ namespace oneapi::dal::csv::detail {
namespace v1 {

using dal::detail::data_parallel_policy;
table read_ops_dispatcher<table, data_parallel_policy>::operator()(
template <typename Float>
table read_ops_dispatcher<table, Float, data_parallel_policy>::operator()(
const data_parallel_policy& ctx,
const data_source_base& ds,
const read_args<table>& args) const {
using kernel_dispatcher_t =
dal::backend::kernel_dispatcher<KERNEL_SINGLE_NODE_CPU(backend::read_kernel_cpu<table>),
KERNEL_SINGLE_NODE_GPU(backend::read_kernel_gpu<table>)>;
using kernel_dispatcher_t = dal::backend::kernel_dispatcher<
KERNEL_SINGLE_NODE_CPU(backend::read_kernel_cpu<table, Float>),
KERNEL_SINGLE_NODE_GPU(backend::read_kernel_gpu<table, Float>)>;
return kernel_dispatcher_t{}(ctx, ds, args);
}

// Explicitly instantiate and export the data-parallel dispatcher for each element type.
#define INSTANTIATE(F) \
    template struct ONEDAL_EXPORT read_ops_dispatcher<table, F, data_parallel_policy>;
INSTANTIATE(float)
INSTANTIATE(double)
// Keep the helper macro local to this instantiation list.
#undef INSTANTIATE

} // namespace v1
} // namespace oneapi::dal::csv::detail
11 changes: 11 additions & 0 deletions examples/oneapi/cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ dal_example_suite(
extra_deps = _TEST_DEPS,
)

# Example suite "misc": compiles every .cpp file under source/misc as C++ and
# depends on the oneDAL io target.
dal_example_suite(
name = "misc",
compile_as = [ "c++" ],
srcs = glob(["source/misc/*.cpp"]),
dal_deps = [
"@onedal//cpp/oneapi/dal:io",
],
data = _DATA_DEPS,
extra_deps = _TEST_DEPS,
)

dal_example_suite(
name = "jaccard",
compile_as = [ "c++" ],
Expand Down
Loading

0 comments on commit a6e75ea

Please sign in to comment.